Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
583a90e0
Unverified
Commit
583a90e0
authored
Jan 10, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 10, 2026
Browse files
[Refactor] Separate sequence and token pooling types (#32026)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
52d42829
Changes
42
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
195 additions
and
100 deletions
+195
-100
vllm/config/model.py
vllm/config/model.py
+55
-39
vllm/config/pooler.py
vllm/config/pooler.py
+59
-7
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+1
-1
vllm/model_executor/layers/pooler/seqwise/methods.py
vllm/model_executor/layers/pooler/seqwise/methods.py
+4
-4
vllm/model_executor/layers/pooler/seqwise/poolers.py
vllm/model_executor/layers/pooler/seqwise/poolers.py
+2
-2
vllm/model_executor/layers/pooler/tokwise/methods.py
vllm/model_executor/layers/pooler/tokwise/methods.py
+2
-4
vllm/model_executor/layers/pooler/tokwise/poolers.py
vllm/model_executor/layers/pooler/tokwise/poolers.py
+2
-2
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+5
-5
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+2
-2
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+3
-3
vllm/model_executor/models/config.py
vllm/model_executor/models/config.py
+5
-4
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+1
-1
vllm/model_executor/models/interfaces_base.py
vllm/model_executor/models/interfaces_base.py
+33
-9
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+1
-1
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+3
-3
vllm/model_executor/models/qwen2_rm.py
vllm/model_executor/models/qwen2_rm.py
+2
-2
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+9
-5
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+2
-2
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+3
-3
vllm/pooling_params.py
vllm/pooling_params.py
+1
-1
No files found.
vllm/config/model.py
View file @
583a90e0
...
...
@@ -539,9 +539,12 @@ class ModelConfig:
if
getattr
(
self
.
pooler_config
,
k
)
is
None
:
setattr
(
self
.
pooler_config
,
k
,
v
)
default_pooling_type
=
self
.
_model_info
.
default_pooling_type
if
self
.
pooler_config
.
pooling_type
is
None
:
self
.
pooler_config
.
pooling_type
=
default_pooling_type
default_seq_pooling_type
=
self
.
_model_info
.
default_seq_pooling_type
if
self
.
pooler_config
.
seq_pooling_type
is
None
:
self
.
pooler_config
.
seq_pooling_type
=
default_seq_pooling_type
default_tok_pooling_type
=
self
.
_model_info
.
default_tok_pooling_type
if
self
.
pooler_config
.
tok_pooling_type
is
None
:
self
.
pooler_config
.
tok_pooling_type
=
default_tok_pooling_type
self
.
dtype
:
torch
.
dtype
=
_get_and_verify_dtype
(
self
.
model
,
...
...
@@ -1543,8 +1546,8 @@ class ModelConfig:
@
property
def
attn_type
(
self
)
->
AttnTypeStr
:
if
self
.
pooler_config
is
not
None
:
pooling_type
=
self
.
_model_info
.
default_pooling_type
.
lower
()
if
pooling_type
==
"
cls
"
:
seq_
pooling_type
=
self
.
_model_info
.
default_
seq_
pooling_type
if
seq_
pooling_type
==
"
CLS
"
:
return
"encoder_only"
else
:
is_causal
=
getattr
(
self
.
hf_config
,
"is_causal"
,
True
)
...
...
@@ -1561,89 +1564,102 @@ class ModelConfig:
@
property
def
is_chunked_prefill_supported
(
self
)
->
bool
:
attn_type
=
self
.
attn_type
if
self
.
pooler_config
is
not
None
:
if
pooler_config
:
=
self
.
pooler_config
:
# for pooling models
if
attn_type
==
"encoder_only"
:
logger
.
debug
(
"Pooling models with bidirectional attn
does not support
"
"chunked prefill."
"Pooling models with bidirectional attn "
"
do not support
chunked prefill."
)
return
False
elif
attn_type
==
"decoder"
:
pooling_type
=
self
.
pooler_config
.
pooling_type
.
lower
()
if
pooling_type
in
[
"mean"
,
"step"
,
"cls"
]:
if
attn_type
==
"decoder"
:
if
(
pooler_config
.
seq_pooling_type
in
(
"MEAN"
,
"CLS"
)
or
pooler_config
.
tok_pooling_type
==
"STEP"
):
logger
.
debug
(
"Pooling models with %s pooling does not "
"support chunked prefill."
,
pooling_type
,
"Pooling models with causal attn and %s/%s pooling "
"do not support chunked prefill."
,
pooler_config
.
seq_pooling_type
,
pooler_config
.
tok_pooling_type
,
)
return
False
el
if
pooling_type
in
[
"all"
,
"last"
]
:
el
se
:
logger
.
debug
(
"Pooling models with causal attn and %s pooling support "
"chunked prefill."
,
pooling_type
,
"Pooling models with causal attn and %s/%s pooling "
"support chunked prefill."
,
pooler_config
.
seq_pooling_type
,
pooler_config
.
tok_pooling_type
,
)
return
True
else
:
raise
ValueError
(
f
"
{
pooling_type
=
}
not supported."
)
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return
attn_type
!=
"encoder_decoder"
else
:
# for generative models
if
attn_type
==
"encoder_decoder"
:
logger
.
debug
(
"Encoder decoder models do
es
not support chunked prefill."
)
logger
.
debug
(
"Encoder decoder models do not support chunked prefill."
)
return
False
logger
.
debug
(
"Generative models support chunked prefill."
)
return
True
@
property
def
is_prefix_caching_supported
(
self
)
->
bool
:
attn_type
=
self
.
attn_type
if
self
.
pooler_config
is
not
None
:
if
pooler_config
:
=
self
.
pooler_config
:
# for pooling models
if
attn_type
==
"encoder_only"
:
logger
.
debug
(
"Pooling models with bidirectional attn
does not
"
"support prefix caching."
"Pooling models with bidirectional attn "
"
do not
support prefix caching."
)
return
False
elif
attn_type
==
"decoder"
:
pooling_type
=
self
.
pooler_config
.
pooling_type
.
lower
()
if
pooling_type
in
[
"mean"
,
"step"
,
"cls"
]:
if
attn_type
==
"decoder"
:
if
(
pooler_config
.
seq_pooling_type
in
(
"MEAN"
,
"CLS"
)
or
pooler_config
.
tok_pooling_type
==
"STEP"
):
logger
.
debug
(
"Pooling models with %s pooling does not "
"support prefix caching."
,
pooling_type
,
"Pooling models with causal attn and %s/%s pooling "
"do not support prefix caching."
,
pooler_config
.
seq_pooling_type
,
pooler_config
.
tok_pooling_type
,
)
return
False
el
if
pooling_type
in
[
"all"
,
"last"
]
:
el
se
:
logger
.
debug
(
"Pooling models with causal attn and %s pooling support "
"prefix caching."
,
pooling_type
,
"Pooling models with causal attn and %s/%s pooling "
"support prefix caching."
,
pooler_config
.
seq_pooling_type
,
pooler_config
.
tok_pooling_type
,
)
return
True
else
:
raise
ValueError
(
f
"
{
pooling_type
=
}
not supported."
)
# vllm currently does not have pooling models using hybrid,
# attention_free or encoder_decoder attn types.
return
False
else
:
# for generative models
if
attn_type
==
"hybrid"
:
logger
.
debug
(
"Hybrid models do
es
not support prefix caching since the feature "
"Hybrid models do not support prefix caching since the feature "
"is still experimental."
)
return
False
elif
attn_type
==
"attention_free"
:
logger
.
debug
(
"Attention free models do
es
not support prefix caching since the "
"Attention free models do not support prefix caching since the "
"feature is still experimental."
)
return
False
elif
attn_type
==
"encoder_decoder"
:
logger
.
debug
(
"Encoder decoder models do
es
not support prefix caching."
)
logger
.
debug
(
"Encoder decoder models do not support prefix caching."
)
return
False
else
:
# attn_type == "decoder"
logger
.
debug
(
"Generative models support prefix caching."
)
...
...
vllm/config/pooler.py
View file @
583a90e0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
,
Literal
from
typing
import
Any
,
Literal
,
get_args
from
pydantic.dataclasses
import
dataclass
...
...
@@ -11,7 +11,11 @@ from vllm.utils.hashing import safe_hash
logger
=
init_logger
(
__name__
)
PoolingTypeStr
=
Literal
[
"LAST"
,
"ALL"
,
"CLS"
,
"STEP"
,
"MEAN"
]
SequencePoolingType
=
Literal
[
"CLS"
,
"LAST"
,
"MEAN"
]
SEQ_POOLING_TYPES
:
tuple
[
SequencePoolingType
,
...]
=
get_args
(
SequencePoolingType
)
TokenPoolingType
=
Literal
[
"ALL"
,
"STEP"
]
TOK_POOLING_TYPES
:
tuple
[
TokenPoolingType
,
...]
=
get_args
(
TokenPoolingType
)
@
config
...
...
@@ -19,9 +23,26 @@ PoolingTypeStr = Literal["LAST", "ALL", "CLS", "STEP", "MEAN"]
class
PoolerConfig
:
"""Controls the behavior of output pooling in pooling models."""
pooling_type
:
PoolingTypeStr
|
None
=
None
pooling_type
:
SequencePoolingType
|
TokenPoolingType
|
None
=
None
"""
The pooling method used for pooling.
If set, `seq_pooling_type` or `tok_pooling_type` are automatically populated
with this field. Alternatively, users can set `seq_pooling_type` and
`tok_pooling_type` explicitly.
This field is mainly for user convenience. Internal code should always use
`seq_pooling_type` or `tok_pooling_type` instead of `pooling_type`.
"""
seq_pooling_type
:
SequencePoolingType
|
None
=
None
"""
The pooling method used for sequence pooling.
"""
tok_pooling_type
:
TokenPoolingType
|
None
=
None
"""
The pooling method
of th
e pooling
model
.
The pooling method
used for tokenwis
e pooling.
"""
## for embeddings models
...
...
@@ -88,9 +109,40 @@ class PoolerConfig:
# raise deprecated warning for softmax and activation
self
.
use_activation
=
get_use_activation
(
self
)
def
get_pooling_type
(
self
)
->
PoolingTypeStr
:
assert
self
.
pooling_type
is
not
None
,
"Should be resolved by ModelConfig"
return
self
.
pooling_type
if
pooling_type
:
=
self
.
pooling_type
:
if
self
.
seq_pooling_type
is
not
None
:
raise
ValueError
(
"Cannot set both `pooling_type` and `seq_pooling_type`"
)
if
self
.
tok_pooling_type
is
not
None
:
raise
ValueError
(
"Cannot set both `pooling_type` and `tok_pooling_type`"
)
if
pooling_type
in
SEQ_POOLING_TYPES
:
logger
.
debug
(
"Resolved `pooling_type=%r` to `seq_pooling_type=%r`."
,
pooling_type
,
pooling_type
,
)
self
.
seq_pooling_type
=
pooling_type
elif
pooling_type
in
TOK_POOLING_TYPES
:
logger
.
debug
(
"Resolved `pooling_type=%r` to `tok_pooling_type=%r`."
,
pooling_type
,
pooling_type
,
)
self
.
tok_pooling_type
=
pooling_type
else
:
raise
NotImplementedError
(
pooling_type
)
def
get_seq_pooling_type
(
self
)
->
SequencePoolingType
:
assert
self
.
seq_pooling_type
is
not
None
,
"Should be resolved by ModelConfig"
return
self
.
seq_pooling_type
def
get_tok_pooling_type
(
self
)
->
TokenPoolingType
:
assert
self
.
tok_pooling_type
is
not
None
,
"Should be resolved by ModelConfig"
return
self
.
tok_pooling_type
def
compute_hash
(
self
)
->
str
:
"""
...
...
vllm/entrypoints/llm.py
View file @
583a90e0
...
...
@@ -172,7 +172,7 @@ class LLM:
The available overrides depend on the model that is being run.
For example, for Phi-3-Vision: `{"num_crops": 4}`.
pooler_config: Initialize non-default pooling config for the pooling
model. e.g. `PoolerConfig(pooling_type="
mean
", normalize=False)`.
model. e.g. `PoolerConfig(
seq_
pooling_type="
MEAN
", normalize=False)`.
compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the mode of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration.
...
...
vllm/model_executor/layers/pooler/seqwise/methods.py
View file @
583a90e0
...
...
@@ -7,7 +7,7 @@ from typing import TypeAlias
import
torch
import
torch.nn
as
nn
from
vllm.config.pooler
import
PoolingType
Str
from
vllm.config.pooler
import
Sequence
PoolingType
from
vllm.model_executor.layers.pooler
import
PoolingParamsUpdate
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
...
...
@@ -82,11 +82,11 @@ class MeanPool(SequencePoolingMethod):
)
/
prompt_lens
.
unsqueeze
(
1
)
def
get_seq_pooling_method
(
pooling_type
:
PoolingTypeStr
|
str
):
if
pooling_type
==
"LAST"
:
return
LastPool
()
def
get_seq_pooling_method
(
pooling_type
:
SequencePoolingType
|
str
):
if
pooling_type
==
"CLS"
:
return
CLSPool
()
if
pooling_type
==
"LAST"
:
return
LastPool
()
if
pooling_type
==
"MEAN"
:
return
MeanPool
()
...
...
vllm/model_executor/layers/pooler/seqwise/poolers.py
View file @
583a90e0
...
...
@@ -85,7 +85,7 @@ class SequencePooler(Pooler):
def
pooler_for_embed
(
pooler_config
:
PoolerConfig
):
pooling
=
get_seq_pooling_method
(
pooler_config
.
get_pooling_type
())
pooling
=
get_seq_pooling_method
(
pooler_config
.
get_
seq_
pooling_type
())
head
=
EmbeddingPoolerHead
()
return
SequencePooler
(
pooling
=
pooling
,
head
=
head
)
...
...
@@ -99,7 +99,7 @@ def pooler_for_classify(
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
):
if
pooling
is
None
:
pooling
=
get_seq_pooling_method
(
pooler_config
.
get_pooling_type
())
pooling
=
get_seq_pooling_method
(
pooler_config
.
get_
seq_
pooling_type
())
head
=
ClassifierPoolerHead
(
classifier
=
classifier
,
act_fn
=
act_fn
)
...
...
vllm/model_executor/layers/pooler/tokwise/methods.py
View file @
583a90e0
...
...
@@ -8,7 +8,7 @@ import torch
import
torch.nn
as
nn
from
vllm.config
import
get_current_vllm_config
from
vllm.config.pooler
import
PoolingType
Str
from
vllm.config.pooler
import
Token
PoolingType
from
vllm.model_executor.layers.pooler
import
PoolingParamsUpdate
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
...
...
@@ -113,12 +113,10 @@ class StepPool(AllPool):
return
pooled_data
def
get_tok_pooling_method
(
pooling_type
:
PoolingType
Str
|
str
):
def
get_tok_pooling_method
(
pooling_type
:
Token
PoolingType
|
str
):
if
pooling_type
==
"ALL"
:
return
AllPool
()
if
pooling_type
==
"STEP"
:
return
StepPool
()
# TODO: Separate seq and tok pooling types so we don't need this fallback
return
AllPool
()
raise
NotImplementedError
(
f
"Unknown tokenwise pooling type:
{
pooling_type
!
r
}
"
)
vllm/model_executor/layers/pooler/tokwise/poolers.py
View file @
583a90e0
...
...
@@ -85,7 +85,7 @@ class TokenPooler(Pooler):
def
pooler_for_token_embed
(
pooler_config
:
PoolerConfig
):
pooling
=
get_tok_pooling_method
(
pooler_config
.
get_pooling_type
())
pooling
=
get_tok_pooling_method
(
pooler_config
.
get_
tok_
pooling_type
())
head
=
TokenEmbeddingPoolerHead
()
return
TokenPooler
(
pooling
=
pooling
,
head
=
head
)
...
...
@@ -99,7 +99,7 @@ def pooler_for_token_classify(
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
):
if
pooling
is
None
:
pooling
=
get_tok_pooling_method
(
pooler_config
.
get_pooling_type
())
pooling
=
get_tok_pooling_method
(
pooler_config
.
get_
tok_
pooling_type
())
head
=
TokenClassifierPoolerHead
(
classifier
=
classifier
,
act_fn
=
act_fn
)
...
...
vllm/model_executor/models/bert.py
View file @
583a90e0
...
...
@@ -357,7 +357,7 @@ class BertOutput(nn.Module):
@
support_torch_compile
@
default_pooling_type
(
"CLS"
)
@
default_pooling_type
(
seq_pooling_type
=
"CLS"
)
class
BertModel
(
nn
.
Module
,
SupportsQuant
):
is_pooling_model
=
True
...
...
@@ -461,7 +461,7 @@ class BertPoolingModel(BertModel):
return
loaded_params
@
default_pooling_type
(
"CLS"
)
@
default_pooling_type
(
seq_pooling_type
=
"CLS"
)
class
BertEmbeddingModel
(
nn
.
Module
,
SupportsQuant
):
"""A model that uses Bert to provide embedding functionalities.
...
...
@@ -675,7 +675,7 @@ class SPLADESparsePooler(Pooler):
return
torch
.
stack
(
pooled_list
,
dim
=
0
).
contiguous
()
@
default_pooling_type
(
"CLS"
)
@
default_pooling_type
(
seq_pooling_type
=
"CLS"
)
class
BertSpladeSparseEmbeddingModel
(
BertEmbeddingModel
):
"""
BertEmbeddingModel + SPLADE sparse embedding.
...
...
@@ -780,7 +780,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
return
loaded
@
default_pooling_type
(
"CLS"
)
@
default_pooling_type
(
seq_pooling_type
=
"CLS"
)
class
BertForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
,
SupportsQuant
):
"""A model that uses Bert to provide embedding functionalities.
...
...
@@ -849,7 +849,7 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
@
attn_type
(
"encoder_only"
)
@
default_pooling_type
(
"ALL"
)
@
default_pooling_type
(
tok_pooling_type
=
"ALL"
)
class
BertForTokenClassification
(
nn
.
Module
):
is_pooling_model
=
True
...
...
vllm/model_executor/models/bert_with_rope.py
View file @
583a90e0
...
...
@@ -441,7 +441,7 @@ class BertWithRopeEncoder(nn.Module):
@
support_torch_compile
@
default_pooling_type
(
"CLS"
)
@
default_pooling_type
(
seq_pooling_type
=
"CLS"
)
class
BertWithRope
(
nn
.
Module
,
SupportsQuant
):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
...
...
@@ -670,7 +670,7 @@ class JinaRobertaModel(BertWithRope):
return
super
().
load_weights
(
weights
)
@
default_pooling_type
(
"CLS"
)
@
default_pooling_type
(
seq_pooling_type
=
"CLS"
)
class
GteNewForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
):
is_pooling_model
=
True
...
...
vllm/model_executor/models/clip.py
View file @
583a90e0
...
...
@@ -145,7 +145,7 @@ class CLIPProcessingInfo(BaseProcessingInfo):
image_width
=
image_width
,
image_height
=
image_height
,
),
_get_vision_feature_select_strategy
(
pooler_config
.
pooling_type
),
_get_vision_feature_select_strategy
(
pooler_config
.
seq_
pooling_type
),
)
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
...
...
@@ -819,7 +819,7 @@ class CLIPVisionModel(nn.Module):
# Assume EOS token corresponds to LAST token in text model
@
default_pooling_type
(
"LAST"
)
@
default_pooling_type
(
seq_pooling_type
=
"LAST"
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
CLIPMultiModalProcessor
,
info
=
CLIPProcessingInfo
,
...
...
@@ -908,7 +908,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
)
->
torch
.
Tensor
:
if
feature_select_strategy
is
None
:
feature_select_strategy
=
_get_vision_feature_select_strategy
(
self
.
pooler_config
.
pooling_type
self
.
pooler_config
.
seq_
pooling_type
)
pooled_output
=
self
.
vision_model
(
...
...
vllm/model_executor/models/config.py
View file @
583a90e0
...
...
@@ -94,12 +94,12 @@ class JinaRobertaModelConfig(VerifyAndUpdateConfig):
class
LlamaBidirectionalConfig
(
VerifyAndUpdateConfig
):
@
staticmethod
def
verify_and_update_model_config
(
model_config
:
"ModelConfig"
)
->
None
:
from
vllm.config.pooler
import
PoolingType
Str
from
vllm.config.pooler
import
Sequence
PoolingType
hf_config
=
model_config
.
hf_config
hf_config
.
is_causal
=
False
pooling_type_map
:
dict
[
str
,
PoolingType
Str
]
=
{
pooling_type_map
:
dict
[
str
,
Sequence
PoolingType
]
=
{
"avg"
:
"MEAN"
,
"cls"
:
"CLS"
,
"last"
:
"LAST"
,
...
...
@@ -107,8 +107,9 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
pooling_type
=
pooling_type_map
.
get
(
hf_config
.
pooling
,
None
)
if
pooling_type
is
None
:
raise
ValueError
(
f
"pool_type
{
hf_config
.
pooling
}
not supported"
)
model_config
.
pooler_config
.
pooling_type
=
pooling_type
raise
ValueError
(
f
"pool_type
{
hf_config
.
pooling
!
r
}
not supported"
)
model_config
.
pooler_config
.
seq_pooling_type
=
pooling_type
class
NomicBertModelConfig
(
VerifyAndUpdateConfig
):
...
...
vllm/model_executor/models/gritlm.py
View file @
583a90e0
...
...
@@ -193,7 +193,7 @@ class GritLMPooler(SequencePooler):
return
self
.
activation
(
pooled_data
)
@
default_pooling_type
(
"MEAN"
)
@
default_pooling_type
(
seq_pooling_type
=
"MEAN"
)
class
GritLM
(
LlamaForCausalLM
):
"""This class implements the embedding model for parasail-ai/GritLM-7B-vllm.
...
...
vllm/model_executor/models/interfaces_base.py
View file @
583a90e0
...
...
@@ -20,12 +20,13 @@ from vllm.utils.func_utils import supports_kw
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config.model
import
AttnTypeStr
from
vllm.config.pooler
import
PoolingType
Str
from
vllm.config.pooler
import
Sequence
PoolingType
,
TokenPoolingType
from
vllm.model_executor.layers.pooler
import
Pooler
else
:
VllmConfig
=
Any
Pooler
=
Any
PoolingTypeStr
=
Any
SequencePoolingType
=
Any
TokenPoolingType
=
Any
AttnTypeStr
=
Any
logger
=
init_logger
(
__name__
)
...
...
@@ -155,9 +156,19 @@ class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
MRO of your model class.
"""
default_pooling_type
:
ClassVar
[
PoolingType
Str
]
=
"LAST"
default_
seq_
pooling_type
:
ClassVar
[
Sequence
PoolingType
]
=
"LAST"
"""
Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][]
Indicates the [vllm.config.pooler.PoolerConfig.seq_pooling_type][]
to use by default.
You can use the
[vllm.model_executor.models.interfaces_base.default_pooling_type][]
decorator to conveniently set this field.
"""
default_tok_pooling_type
:
ClassVar
[
TokenPoolingType
]
=
"ALL"
"""
Indicates the [vllm.config.pooler.PoolerConfig.tok_pooling_type][]
to use by default.
You can use the
...
...
@@ -200,18 +211,31 @@ def is_pooling_model(
_T
=
TypeVar
(
"_T"
,
bound
=
type
[
nn
.
Module
])
def
default_pooling_type
(
pooling_type
:
PoolingTypeStr
):
"""Decorator to set `VllmModelForPooling.default_pooling_type`."""
def
default_pooling_type
(
*
,
seq_pooling_type
:
SequencePoolingType
=
"LAST"
,
tok_pooling_type
:
TokenPoolingType
=
"ALL"
,
):
"""Decorator to set `VllmModelForPooling.default_*_pooling_type`."""
def
func
(
model
:
_T
)
->
_T
:
model
.
default_pooling_type
=
pooling_type
# type: ignore
model
.
default_seq_pooling_type
=
seq_pooling_type
# type: ignore
model
.
default_tok_pooling_type
=
tok_pooling_type
# type: ignore
return
model
return
func
def
get_default_pooling_type
(
model
:
type
[
object
]
|
object
)
->
PoolingTypeStr
:
return
getattr
(
model
,
"default_pooling_type"
,
"LAST"
)
def
get_default_seq_pooling_type
(
model
:
type
[
object
]
|
object
,
)
->
SequencePoolingType
:
return
getattr
(
model
,
"default_seq_pooling_type"
,
"LAST"
)
def
get_default_tok_pooling_type
(
model
:
type
[
object
]
|
object
,
)
->
TokenPoolingType
:
return
getattr
(
model
,
"default_tok_pooling_type"
,
"ALL"
)
def
attn_type
(
attn_type
:
AttnTypeStr
):
...
...
vllm/model_executor/models/internlm2.py
View file @
583a90e0
...
...
@@ -402,7 +402,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
return
loaded_params
@
default_pooling_type
(
"ALL"
)
@
default_pooling_type
(
tok_pooling_type
=
"ALL"
)
class
InternLM2ForRewardModel
(
InternLM2ForCausalLM
):
is_pooling_model
=
True
...
...
vllm/model_executor/models/modernbert.py
View file @
583a90e0
...
...
@@ -221,7 +221,7 @@ class ModernBertEncoderLayer(nn.Module):
@
support_torch_compile
@
default_pooling_type
(
"CLS"
)
@
default_pooling_type
(
seq_pooling_type
=
"CLS"
)
class
ModernBertModel
(
nn
.
Module
):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"layers."
:
"encoder_layer.layers."
}
...
...
@@ -308,7 +308,7 @@ class ModernBertPooler(SequencePooler):
return
self
.
norm
(
self
.
act
(
self
.
dense
(
pooled_data
)))
@
default_pooling_type
(
"CLS"
)
@
default_pooling_type
(
seq_pooling_type
=
"CLS"
)
class
ModernBertForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
):
is_pooling_model
=
True
...
...
@@ -395,7 +395,7 @@ class ModernBertPredictionHead(nn.Module):
@
attn_type
(
"encoder_only"
)
@
default_pooling_type
(
"ALL"
)
@
default_pooling_type
(
tok_pooling_type
=
"ALL"
)
class
ModernBertForTokenClassification
(
nn
.
Module
):
is_pooling_model
=
True
...
...
vllm/model_executor/models/qwen2_rm.py
View file @
583a90e0
...
...
@@ -96,7 +96,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
return
loader
.
load_weights
(
weights
)
@
default_pooling_type
(
"ALL"
)
@
default_pooling_type
(
tok_pooling_type
=
"ALL"
)
class
Qwen2ForRewardModel
(
Qwen2RewardBaseModel
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
vllm_config
.
model_config
.
hf_config
.
num_labels
=
1
...
...
@@ -108,7 +108,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
self
.
pooler
=
pooler_for_token_classify
(
pooler_config
)
@
default_pooling_type
(
"STEP"
)
@
default_pooling_type
(
tok_pooling_type
=
"STEP"
)
class
Qwen2ForProcessRewardModel
(
Qwen2RewardBaseModel
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
vllm_config
.
model_config
.
hf_config
.
num_labels
=
2
...
...
vllm/model_executor/models/registry.py
View file @
583a90e0
...
...
@@ -35,10 +35,11 @@ from vllm.utils.hashing import safe_hash
if
TYPE_CHECKING
:
from
vllm.config.model
import
AttnTypeStr
from
vllm.config.pooler
import
PoolingType
Str
from
vllm.config.pooler
import
Sequence
PoolingType
,
TokenPoolingType
else
:
AttnTypeStr
=
Any
PoolingTypeStr
=
Any
SequencePoolingType
=
Any
TokenPoolingType
=
Any
from
.interfaces
import
(
...
...
@@ -57,7 +58,8 @@ from .interfaces import (
)
from
.interfaces_base
import
(
get_attn_type
,
get_default_pooling_type
,
get_default_seq_pooling_type
,
get_default_tok_pooling_type
,
is_pooling_model
,
is_text_generation_model
,
)
...
...
@@ -548,7 +550,8 @@ class _ModelInfo:
is_text_generation_model
:
bool
is_pooling_model
:
bool
attn_type
:
AttnTypeStr
default_pooling_type
:
PoolingTypeStr
default_seq_pooling_type
:
SequencePoolingType
default_tok_pooling_type
:
TokenPoolingType
supports_cross_encoding
:
bool
supports_multimodal
:
bool
supports_multimodal_raw_input_only
:
bool
...
...
@@ -569,7 +572,8 @@ class _ModelInfo:
architecture
=
model
.
__name__
,
is_text_generation_model
=
is_text_generation_model
(
model
),
is_pooling_model
=
is_pooling_model
(
model
),
default_pooling_type
=
get_default_pooling_type
(
model
),
default_seq_pooling_type
=
get_default_seq_pooling_type
(
model
),
default_tok_pooling_type
=
get_default_tok_pooling_type
(
model
),
attn_type
=
get_attn_type
(
model
),
supports_cross_encoding
=
supports_cross_encoding
(
model
),
supports_multimodal
=
supports_multimodal
(
model
),
...
...
vllm/model_executor/models/roberta.py
View file @
583a90e0
...
...
@@ -93,7 +93,7 @@ class RobertaClassificationHead(nn.Module):
return
x
@
default_pooling_type
(
"CLS"
)
@
default_pooling_type
(
seq_pooling_type
=
"CLS"
)
class
RobertaEmbeddingModel
(
BertEmbeddingModel
):
"""A model that uses Roberta to provide embedding functionalities."""
...
...
@@ -150,7 +150,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
return
loader
.
load_weights
(
weights_list
,
mapper
=
mapper
)
@
default_pooling_type
(
"CLS"
)
@
default_pooling_type
(
seq_pooling_type
=
"CLS"
)
class
RobertaForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
):
"""A model that uses Roberta to provide embedding functionalities.
...
...
vllm/model_executor/models/siglip.py
View file @
583a90e0
...
...
@@ -129,7 +129,7 @@ class SiglipProcessingInfo(BaseProcessingInfo):
image_width
=
image_width
,
image_height
=
image_height
,
),
_get_vision_feature_select_strategy
(
pooler_config
.
pooling_type
),
_get_vision_feature_select_strategy
(
pooler_config
.
seq_
pooling_type
),
)
def
get_image_size_with_most_features
(
self
)
->
ImageSize
:
...
...
@@ -998,7 +998,7 @@ class SiglipTextEmbeddings(nn.Module):
# Assume EOS token corresponds to CLS token in text model
@
default_pooling_type
(
"CLS"
)
@
default_pooling_type
(
seq_pooling_type
=
"CLS"
)
@
MULTIMODAL_REGISTRY
.
register_processor
(
SiglipMultiModalProcessor
,
info
=
SiglipProcessingInfo
,
...
...
@@ -1125,7 +1125,7 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
)
->
torch
.
Tensor
:
if
feature_select_strategy
is
None
:
feature_select_strategy
=
_get_vision_feature_select_strategy
(
self
.
pooler_config
.
pooling_type
self
.
pooler_config
.
seq_
pooling_type
)
pooled_output
=
self
.
vision_model
(
...
...
vllm/pooling_params.py
View file @
583a90e0
...
...
@@ -140,7 +140,7 @@ class PoolingParams(
self
,
pooler_config
:
"PoolerConfig"
,
valid_parameters
:
list
[
str
]
):
step_pooling_parameters
=
[
"step_tag_id"
,
"returned_token_ids"
]
if
pooler_config
.
pooling_type
!=
"STEP"
:
if
pooler_config
.
tok_
pooling_type
!=
"STEP"
:
invalid_parameters
=
[]
for
k
in
step_pooling_parameters
:
if
getattr
(
self
,
k
,
None
)
is
not
None
:
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment