Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5eeef1b9
Unverified
Commit
5eeef1b9
authored
Aug 27, 2025
by
Cyrus Leung
Committed by
GitHub
Aug 27, 2025
Browse files
[Model] Explicit `default_pooling_type` interface (#23736)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
704432af
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
51 additions
and
33 deletions
+51
-33
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+2
-2
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+3
-2
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+1
-1
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+1
-18
vllm/model_executor/models/interfaces_base.py
vllm/model_executor/models/interfaces_base.py
+28
-0
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+2
-1
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+2
-1
vllm/model_executor/models/prithvi_geospatial_mae.py
vllm/model_executor/models/prithvi_geospatial_mae.py
+4
-3
vllm/model_executor/models/qwen2_rm.py
vllm/model_executor/models/qwen2_rm.py
+2
-1
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+4
-3
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+2
-1
No files found.
vllm/model_executor/models/bert.py
View file @
5eeef1b9
...
...
@@ -28,8 +28,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
PoolingTask
from
.interfaces
import
(
SupportsCrossEncoding
,
SupportsQuant
,
default_pooling_type
)
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
from
.interfaces_base
import
default_pooling_type
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
...
...
vllm/model_executor/models/bert_with_rope.py
View file @
5eeef1b9
...
...
@@ -27,13 +27,14 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
SupportsQuant
,
default_pooling_type
)
from
vllm.model_executor.models.utils
import
WeightsMapper
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsQuant
from
.interfaces_base
import
default_pooling_type
class
BertWithRopeEmbedding
(
nn
.
Module
):
...
...
vllm/model_executor/models/gritlm.py
View file @
5eeef1b9
...
...
@@ -20,7 +20,7 @@ from vllm.sequence import PoolerOutput
from
vllm.tasks
import
PoolingTask
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
.interfaces
import
default_pooling_type
from
.interfaces
_base
import
default_pooling_type
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/models/interfaces.py
View file @
5eeef1b9
...
...
@@ -3,7 +3,7 @@
from
collections.abc
import
Iterable
,
Mapping
,
MutableSequence
from
typing
import
(
TYPE_CHECKING
,
ClassVar
,
Literal
,
Optional
,
Protocol
,
TypeVar
,
Union
,
overload
,
runtime_checkable
)
Union
,
overload
,
runtime_checkable
)
import
numpy
as
np
import
torch
...
...
@@ -641,23 +641,6 @@ def supports_cross_encoding(
return
is_pooling_model
(
model
)
and
_supports_cross_encoding
(
model
)
_T
=
TypeVar
(
"_T"
,
bound
=
type
[
torch
.
nn
.
Module
])
def
default_pooling_type
(
pooling_type
:
str
):
"""Set default_pooling_type decorator. """
def
func
(
model
:
_T
)
->
_T
:
model
.
default_pooling_type
=
pooling_type
# type: ignore
return
model
return
func
def
get_default_pooling_type
(
model
:
Union
[
type
[
object
],
object
])
->
str
:
return
getattr
(
model
,
"default_pooling_type"
,
"LAST"
)
class
SupportsQuant
:
"""The interface required for all models that support quantization."""
...
...
vllm/model_executor/models/interfaces_base.py
View file @
5eeef1b9
...
...
@@ -144,6 +144,17 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
MRO of your model class.
"""
default_pooling_type
:
ClassVar
[
str
]
=
"LAST"
"""
Indicates the
[vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][]
to use by default.
You can use the
[vllm.model_executor.models.interfaces_base.default_pooling_type][]
decorator to conveniently set this field.
"""
pooler
:
Pooler
"""The pooler is only called on TP rank 0."""
...
...
@@ -165,3 +176,20 @@ def is_pooling_model(
return
False
return
getattr
(
model
,
"is_pooling_model"
,
False
)
_T
=
TypeVar
(
"_T"
,
bound
=
type
[
nn
.
Module
])
def
default_pooling_type
(
pooling_type
:
str
):
"""Decorator to set `VllmModelForPooling.default_pooling_type`."""
def
func
(
model
:
_T
)
->
_T
:
model
.
default_pooling_type
=
pooling_type
# type: ignore
return
model
return
func
def
get_default_pooling_type
(
model
:
Union
[
type
[
object
],
object
])
->
str
:
return
getattr
(
model
,
"default_pooling_type"
,
"LAST"
)
vllm/model_executor/models/internlm2.py
View file @
5eeef1b9
...
...
@@ -31,7 +31,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
,
default_pooling_type
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces_base
import
default_pooling_type
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
vllm/model_executor/models/modernbert.py
View file @
5eeef1b9
...
...
@@ -26,7 +26,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
PoolingTask
from
.interfaces
import
SupportsCrossEncoding
,
default_pooling_type
from
.interfaces
import
SupportsCrossEncoding
from
.interfaces_base
import
default_pooling_type
from
.utils
import
WeightsMapper
,
maybe_prefix
...
...
vllm/model_executor/models/prithvi_geospatial_mae.py
View file @
5eeef1b9
...
...
@@ -27,9 +27,6 @@ from transformers import BatchFeature
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
(
IsAttentionFree
,
MultiModalEmbeddings
,
SupportsMultiModalWithRawInput
,
default_pooling_type
)
from
vllm.model_executor.models.utils
import
AutoWeightsLoader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
ImageItem
,
ModalityData
,
...
...
@@ -43,6 +40,10 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
(
IsAttentionFree
,
MultiModalEmbeddings
,
SupportsMultiModalWithRawInput
)
from
.interfaces_base
import
default_pooling_type
def
_prithvi_field_config
(
hf_inputs
:
Mapping
[
str
,
torch
.
Tensor
]):
# This model receives in input a multi-dimensional tensor representing
...
...
vllm/model_executor/models/qwen2_rm.py
View file @
5eeef1b9
...
...
@@ -18,7 +18,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
,
default_pooling_type
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces_base
import
default_pooling_type
from
.qwen2
import
Qwen2Model
from
.utils
import
AutoWeightsLoader
,
maybe_prefix
...
...
vllm/model_executor/models/registry.py
View file @
5eeef1b9
...
...
@@ -25,11 +25,12 @@ from vllm.logger import init_logger
from
vllm.transformers_utils.dynamic_module
import
(
try_get_class_from_dynamic_module
)
from
.interfaces
import
(
get_default_pooling_type
,
has_inner_state
,
has_noops
,
is_attention_free
,
is_hybrid
,
supports_cross_encoding
,
from
.interfaces
import
(
has_inner_state
,
has_noops
,
is_attention_free
,
is_hybrid
,
supports_cross_encoding
,
supports_multimodal
,
supports_multimodal_raw_input
,
supports_pp
,
supports_transcription
,
supports_v0_only
)
from
.interfaces_base
import
is_pooling_model
,
is_text_generation_model
from
.interfaces_base
import
(
get_default_pooling_type
,
is_pooling_model
,
is_text_generation_model
)
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/models/roberta.py
View file @
5eeef1b9
...
...
@@ -22,7 +22,8 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
from
vllm.sequence
import
IntermediateTensors
from
.bert_with_rope
import
BertWithRope
,
JinaRobertaModel
from
.interfaces
import
SupportsCrossEncoding
,
default_pooling_type
from
.interfaces
import
SupportsCrossEncoding
from
.interfaces_base
import
default_pooling_type
class
RobertaEmbedding
(
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment