Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6c85da3a
Unverified
Commit
6c85da3a
authored
Feb 27, 2025
by
Roger Wang
Committed by
GitHub
Feb 27, 2025
Browse files
[V1] `SupportsV0Only` protocol for model definitions (#13959)
Signed-off-by:
Roger Wang
<
ywang@roblox.com
>
parent
67fc4268
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
93 additions
and
32 deletions
+93
-32
vllm/config.py
vllm/config.py
+5
-0
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+5
-2
vllm/model_executor/models/bamba.py
vllm/model_executor/models/bamba.py
+3
-2
vllm/model_executor/models/bart.py
vllm/model_executor/models/bart.py
+2
-1
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+2
-2
vllm/model_executor/models/florence2.py
vllm/model_executor/models/florence2.py
+2
-2
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+3
-1
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+26
-0
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+3
-2
vllm/model_executor/models/mamba.py
vllm/model_executor/models/mamba.py
+4
-2
vllm/model_executor/models/mamba2.py
vllm/model_executor/models/mamba2.py
+4
-2
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+4
-2
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+3
-2
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+2
-2
vllm/model_executor/models/prithvi_geospatial_mae.py
vllm/model_executor/models/prithvi_geospatial_mae.py
+4
-2
vllm/model_executor/models/qwen2_rm.py
vllm/model_executor/models/qwen2_rm.py
+3
-2
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+12
-2
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+3
-2
vllm/model_executor/models/whisper.py
vllm/model_executor/models/whisper.py
+3
-2
No files found.
vllm/config.py
View file @
6c85da3a
...
...
@@ -1039,6 +1039,11 @@ class ModelConfig:
def
runner_type
(
self
)
->
RunnerType
:
return
_TASK_RUNNER
[
self
.
task
]
@property
def is_v1_compatible(self) -> bool:
    """Report whether this model's architectures are usable with V1.

    The architecture names are read from ``hf_config.architectures``
    (an empty list is used when the attribute is absent) and the decision
    is delegated to ``ModelRegistry.is_v1_compatible``.
    """
    arch_names = getattr(self.hf_config, "architectures", [])
    v1_ok = ModelRegistry.is_v1_compatible(arch_names)
    return v1_ok
class
CacheConfig
:
"""Configuration for the KV cache.
...
...
vllm/model_executor/models/__init__.py
View file @
6c85da3a
# SPDX-License-Identifier: Apache-2.0
from
.interfaces
import
(
HasInnerState
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
,
has_inner_state
,
supports_lora
,
supports_multimodal
,
supports_pp
)
SupportsPP
,
SupportsV0Only
,
has_inner_state
,
supports_lora
,
supports_multimodal
,
supports_pp
,
supports_v0_only
)
from
.interfaces_base
import
(
VllmModelForPooling
,
VllmModelForTextGeneration
,
is_pooling_model
,
is_text_generation_model
)
from
.registry
import
ModelRegistry
...
...
@@ -21,4 +22,6 @@ __all__ = [
"supports_multimodal"
,
"SupportsPP"
,
"supports_pp"
,
"SupportsV0Only"
,
"supports_v0_only"
,
]
vllm/model_executor/models/bamba.py
View file @
6c85da3a
...
...
@@ -32,7 +32,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
LayerBlockType
from
.interfaces
import
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
from
.interfaces
import
(
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
,
SupportsV0Only
)
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -366,7 +367,7 @@ class BambaModel(nn.Module):
class
BambaForCausalLM
(
nn
.
Module
,
HasInnerState
,
SupportsLoRA
,
SupportsPP
,
IsHybrid
):
IsHybrid
,
SupportsV0Only
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
vllm/model_executor/models/bart.py
View file @
6c85da3a
...
...
@@ -43,6 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsV0Only
from
.utils
import
maybe_prefix
logger
=
logging
.
get_logger
(
__name__
)
...
...
@@ -776,7 +777,7 @@ class BartModel(nn.Module):
return
decoder_outputs
class
BartForConditionalGeneration
(
nn
.
Module
):
class
BartForConditionalGeneration
(
nn
.
Module
,
SupportsV0Only
):
base_model_prefix
=
"model"
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
vllm/model_executor/models/bert.py
View file @
6c85da3a
...
...
@@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from
vllm.transformers_utils.config
import
(
get_cross_encoder_activation_function
)
from
.interfaces
import
SupportsCrossEncoding
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
from
.utils
import
WeightsMapper
,
maybe_prefix
...
...
@@ -385,7 +385,7 @@ class BertModel(nn.Module):
return
loaded_params
class
BertEmbeddingModel
(
nn
.
Module
):
class
BertEmbeddingModel
(
nn
.
Module
,
SupportsV0Only
):
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
...
...
vllm/model_executor/models/florence2.py
View file @
6c85da3a
...
...
@@ -29,7 +29,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsMultiModal
from
.interfaces
import
SupportsMultiModal
,
SupportsV0Only
from
.utils
import
AutoWeightsLoader
,
flatten_bn
,
merge_multimodal_embeddings
...
...
@@ -651,7 +651,7 @@ class Florence2LanguageModel(nn.Module):
return
decoder_outputs
class
Florence2LanguageForConditionalGeneration
(
nn
.
Module
):
class
Florence2LanguageForConditionalGeneration
(
nn
.
Module
,
SupportsV0Only
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
()
...
...
vllm/model_executor/models/gritlm.py
View file @
6c85da3a
...
...
@@ -19,6 +19,8 @@ from vllm.sequence import (IntermediateTensors, PoolerOutput,
PoolingSequenceGroupOutput
)
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
.interfaces
import
SupportsV0Only
logger
=
init_logger
(
__name__
)
...
...
@@ -177,7 +179,7 @@ class GritLMPooler(nn.Module):
return
PoolerOutput
(
outputs
=
pooled_outputs
)
class
GritLM
(
LlamaForCausalLM
):
class
GritLM
(
LlamaForCausalLM
,
SupportsV0Only
):
"""This class implements the embedding model for parasail-ai/GritLM-7B-vllm.
The class inherits from LlamaForCausalLM and provides a custom pooling
...
...
vllm/model_executor/models/interfaces.py
View file @
6c85da3a
...
...
@@ -498,3 +498,29 @@ def supports_transcription(
return
isinstance
(
model
,
SupportsTranscription
)
return
isinstance
(
model
,
SupportsTranscription
)
@runtime_checkable
class SupportsV0Only(Protocol):
    """Marker interface for models that are not compatible with V1 vLLM.

    A model class opts in by listing this protocol among its bases, which
    makes the flag below visible on the class and on its instances.
    """

    # Always ``True``; the presence of this attribute is what the
    # runtime-checkable ``isinstance`` test looks for.
    supports_v0_only: ClassVar[Literal[True]] = True
@overload
def supports_v0_only(
        model: Type[object]) -> TypeIs[Type[SupportsV0Only]]:
    ...


@overload
def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
    ...


def supports_v0_only(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
    """Return whether ``model`` carries the ``SupportsV0Only`` marker.

    Args:
        model: Either a model class or a model instance.

    Returns:
        ``True`` when ``model`` satisfies the runtime-checkable
        ``SupportsV0Only`` protocol, ``False`` otherwise.
    """
    # The class case and the instance case previously had separate branches
    # that performed this exact same check; a single ``isinstance`` test
    # covers both because the marker is a class-level attribute and is
    # therefore reachable from the class object as well as from instances.
    return isinstance(model, SupportsV0Only)
vllm/model_executor/models/jamba.py
View file @
6c85da3a
...
...
@@ -30,7 +30,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
vllm.utils
import
LayerBlockType
from
.interfaces
import
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
from
.interfaces
import
(
HasInnerState
,
IsHybrid
,
SupportsLoRA
,
SupportsPP
,
SupportsV0Only
)
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -353,7 +354,7 @@ class JambaModel(nn.Module):
class
JambaForCausalLM
(
nn
.
Module
,
HasInnerState
,
SupportsLoRA
,
SupportsPP
,
IsHybrid
):
IsHybrid
,
SupportsV0Only
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
vllm/model_executor/models/mamba.py
View file @
6c85da3a
...
...
@@ -19,7 +19,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsAttentionFree
,
SupportsPP
)
IsAttentionFree
,
SupportsPP
,
SupportsV0Only
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -155,7 +156,8 @@ class MambaModel(nn.Module):
return
hidden_states
class
MambaForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsAttentionFree
,
SupportsPP
):
class
MambaForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsAttentionFree
,
SupportsPP
,
SupportsV0Only
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
=
vllm_config
.
model_config
.
hf_config
...
...
vllm/model_executor/models/mamba2.py
View file @
6c85da3a
...
...
@@ -22,7 +22,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
HasInnerState
,
IsAttentionFree
)
IsAttentionFree
,
SupportsV0Only
)
from
vllm.model_executor.models.mamba_cache
import
(
MambaCacheManager
,
MambaCacheParams
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
...
@@ -174,7 +175,8 @@ class Mamba2Model(nn.Module):
return
hidden_states
class
Mamba2ForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsAttentionFree
):
class
Mamba2ForCausalLM
(
nn
.
Module
,
HasInnerState
,
IsAttentionFree
,
SupportsV0Only
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
config
=
vllm_config
.
model_config
.
hf_config
...
...
vllm/model_executor/models/minicpmv.py
View file @
6c85da3a
...
...
@@ -63,7 +63,8 @@ from vllm.platforms import current_platform
from
vllm.sequence
import
IntermediateTensors
from
.idefics2_vision_model
import
Idefics2VisionTransformer
from
.interfaces
import
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
(
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
,
SupportsV0Only
)
from
.utils
import
AutoWeightsLoader
,
maybe_prefix
CPU_DEVICE
=
torch
.
device
(
"cpu"
)
...
...
@@ -804,7 +805,8 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
return
result
class
MiniCPMVBaseModel
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
class
MiniCPMVBaseModel
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsV0Only
):
"""
The abstract class of MiniCPMV can only be inherited, but cannot be
instantiated.
...
...
vllm/model_executor/models/mllama.py
View file @
6c85da3a
...
...
@@ -63,7 +63,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
.clip
import
CLIPMLP
from
.interfaces
import
SupportsMultiModal
from
.interfaces
import
SupportsMultiModal
,
SupportsV0Only
from
.llama
import
LlamaDecoderLayer
,
LlamaMLP
from
.utils
import
maybe_prefix
...
...
@@ -1128,7 +1128,8 @@ class MllamaForCausalLM(nn.Module):
@
MULTIMODAL_REGISTRY
.
register_processor
(
MllamaMultiModalProcessor
,
info
=
MllamaProcessingInfo
,
dummy_inputs
=
MllamaDummyInputsBuilder
)
class
MllamaForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
class
MllamaForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsV0Only
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
...
...
vllm/model_executor/models/paligemma.py
View file @
6c85da3a
...
...
@@ -18,7 +18,7 @@ from vllm.multimodal.inputs import NestedTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
SupportsMultiModal
,
SupportsPP
,
SupportsV0Only
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
)
from
.utils
import
(
AutoWeightsLoader
,
init_vllm_registered_model
,
...
...
@@ -136,7 +136,7 @@ class PaliGemmaMultiModalProjector(nn.Module):
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_paligemma
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_paligemma
)
class
PaliGemmaForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
SupportsPP
,
SupportsV0Only
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
vllm/model_executor/models/prithvi_geospatial_mae.py
View file @
6c85da3a
...
...
@@ -25,7 +25,8 @@ from transformers import BatchFeature
from
vllm.config
import
VllmConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
IsAttentionFree
,
SupportsMultiModal
)
SupportsMultiModal
,
SupportsV0Only
)
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
...
...
@@ -111,7 +112,8 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
PrithviGeoSpatialMAEMultiModalProcessor
,
info
=
PrithviGeoSpatialMAEProcessingInfo
,
dummy_inputs
=
PrithviGeoSpatialMAEInputBuilder
)
class
PrithviGeoSpatialMAE
(
nn
.
Module
,
IsAttentionFree
,
SupportsMultiModal
):
class
PrithviGeoSpatialMAE
(
nn
.
Module
,
IsAttentionFree
,
SupportsMultiModal
,
SupportsV0Only
):
""" Prithvi Masked Autoencoder"""
def
_instantiate_model
(
self
,
config
:
dict
)
->
Optional
[
nn
.
Module
]:
...
...
vllm/model_executor/models/qwen2_rm.py
View file @
6c85da3a
...
...
@@ -17,7 +17,7 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
,
PoolerOutput
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
,
SupportsV0Only
from
.qwen2
import
Qwen2Model
from
.utils
import
AutoWeightsLoader
,
maybe_prefix
...
...
@@ -33,7 +33,8 @@ class ReLU(nn.Module):
return
self
.
activation
(
input
)
class
Qwen2RewardBaseModel
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
):
class
Qwen2RewardBaseModel
(
nn
.
Module
,
SupportsLoRA
,
SupportsPP
,
SupportsV0Only
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
...
...
vllm/model_executor/models/registry.py
View file @
6c85da3a
...
...
@@ -22,7 +22,7 @@ from vllm.logger import init_logger
from
.interfaces
import
(
has_inner_state
,
is_attention_free
,
is_hybrid
,
supports_cross_encoding
,
supports_multimodal
,
supports_pp
,
supports_transcription
)
supports_pp
,
supports_transcription
,
supports_v0_only
)
from
.interfaces_base
import
is_text_generation_model
logger
=
init_logger
(
__name__
)
...
...
@@ -228,6 +228,7 @@ class _ModelInfo:
is_attention_free
:
bool
is_hybrid
:
bool
supports_transcription
:
bool
supports_v0_only
:
bool
@
staticmethod
def
from_model_cls
(
model
:
Type
[
nn
.
Module
])
->
"_ModelInfo"
:
...
...
@@ -241,7 +242,9 @@ class _ModelInfo:
has_inner_state
=
has_inner_state
(
model
),
is_attention_free
=
is_attention_free
(
model
),
is_hybrid
=
is_hybrid
(
model
),
supports_transcription
=
supports_transcription
(
model
))
supports_transcription
=
supports_transcription
(
model
),
supports_v0_only
=
supports_v0_only
(
model
),
)
class
_BaseRegisteredModel
(
ABC
):
...
...
@@ -504,6 +507,13 @@ class _ModelRegistry:
model_cls
,
_
=
self
.
inspect_model_cls
(
architectures
)
return
model_cls
.
supports_transcription
def is_v1_compatible(
    self,
    architectures: Union[str, List[str]],
) -> bool:
    """Tell whether the model resolved from ``architectures`` can run on V1.

    The architectures are resolved through ``inspect_model_cls``; the
    answer is the negation of the resolved info's ``supports_v0_only``
    flag.
    """
    resolved_info, _resolved_arch = self.inspect_model_cls(architectures)
    if resolved_info.supports_v0_only:
        return False
    return True
ModelRegistry
=
_ModelRegistry
({
model_arch
:
...
...
vllm/model_executor/models/roberta.py
View file @
6c85da3a
...
...
@@ -19,7 +19,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from
vllm.transformers_utils.config
import
(
get_cross_encoder_activation_function
)
from
.interfaces
import
SupportsCrossEncoding
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
def
roberta_task_weights_filter
(
...
...
@@ -191,7 +191,8 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
assert
len
(
loaded
),
"Unable to load RobertaEmbeddingModel"
class
RobertaForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
):
class
RobertaForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
,
SupportsV0Only
):
"""A model that uses Roberta to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
...
...
vllm/model_executor/models/whisper.py
View file @
6c85da3a
...
...
@@ -34,7 +34,8 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
.interfaces
import
SupportsMultiModal
,
SupportsTranscription
from
.interfaces
import
(
SupportsMultiModal
,
SupportsTranscription
,
SupportsV0Only
)
from
.utils
import
(
AutoWeightsLoader
,
WeightsMapper
,
cast_overflow_tensors
,
make_layers
)
...
...
@@ -643,7 +644,7 @@ class WhisperMultiModalProcessor(
info
=
WhisperProcessingInfo
,
dummy_inputs
=
WhisperDummyInputsBuilder
)
class
WhisperForConditionalGeneration
(
nn
.
Module
,
SupportsTranscription
,
SupportsMultiModal
):
SupportsMultiModal
,
SupportsV0Only
):
packed_modules_mapping
=
{
"self_attn.qkv_proj"
:
[
"self_attn.q_proj"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment