Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
84cf78ac
Unverified
Commit
84cf78ac
authored
Aug 12, 2025
by
wang.yuqi
Committed by
GitHub
Aug 11, 2025
Browse files
[Model] Pooling models default to using chunked prefill & prefix caching if supported. (#20930)
Signed-off-by:
wang.yuqi
<
noooop@126.com
>
parent
16fb668b
Changes
31
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
59 additions
and
31 deletions
+59
-31
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+2
-2
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+8
-8
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+3
-1
vllm/model_executor/models/interfaces.py
vllm/model_executor/models/interfaces.py
+14
-0
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+2
-1
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+1
-3
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+4
-2
vllm/model_executor/models/qwen2_rm.py
vllm/model_executor/models/qwen2_rm.py
+6
-10
vllm/model_executor/models/registry.py
vllm/model_executor/models/registry.py
+4
-2
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+3
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+12
-1
No files found.
vllm/model_executor/models/adapters.py
View file @
84cf78ac
...
...
@@ -182,8 +182,8 @@ def as_seq_cls_model(cls: _T) -> _T:
assert
pooler_config
is
not
None
pooling_type_str
=
pooler_config
.
pooling_type
pooling_type
=
(
PoolingType
.
LAST
if
pooling_type_str
is
None
els
e
PoolingType
[
pooling_type_str
]
)
assert
pooling_type_str
is
not
Non
e
pooling_type
=
PoolingType
[
pooling_type_str
]
self
.
pooler
=
DispatchPooler
({
"encode"
:
...
...
vllm/model_executor/models/bert.py
View file @
84cf78ac
...
...
@@ -28,7 +28,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
PoolingTask
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
from
.interfaces
import
(
SupportsCrossEncoding
,
SupportsQuant
,
default_pooling_type
)
from
.utils
import
AutoWeightsLoader
,
WeightsMapper
,
maybe_prefix
...
...
@@ -327,6 +328,7 @@ class BertOutput(nn.Module):
@
support_torch_compile
@
default_pooling_type
(
"CLS"
)
class
BertModel
(
nn
.
Module
,
SupportsQuant
):
is_pooling_model
=
True
...
...
@@ -401,6 +403,7 @@ class BertModel(nn.Module, SupportsQuant):
return
loaded_params
@
default_pooling_type
(
"ALL"
)
class
BertPoolingModel
(
BertModel
):
is_pooling_model
=
True
...
...
@@ -431,6 +434,7 @@ class BertPoolingModel(BertModel):
return
loaded_params
@
default_pooling_type
(
"CLS"
)
class
BertEmbeddingModel
(
nn
.
Module
,
SupportsQuant
):
"""A model that uses Bert to provide embedding functionalities.
...
...
@@ -486,13 +490,8 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
def
_build_pooler
(
self
,
pooler_config
:
PoolerConfig
)
->
Pooler
:
return
DispatchPooler
({
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
,
default_pooling_type
=
PoolingType
.
CLS
,
),
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
})
...
...
@@ -541,6 +540,7 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
return
token_type_ids
@
default_pooling_type
(
"CLS"
)
class
BertForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
,
SupportsQuant
):
"""A model that uses Bert to provide embedding functionalities.
...
...
vllm/model_executor/models/bert_with_rope.py
View file @
84cf78ac
...
...
@@ -27,7 +27,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsQuant
from
vllm.model_executor.models.interfaces
import
(
SupportsQuant
,
default_pooling_type
)
from
vllm.model_executor.models.utils
import
WeightsMapper
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
...
...
@@ -401,6 +402,7 @@ class BertWithRopeEncoder(nn.Module):
@
support_torch_compile
@
default_pooling_type
(
"CLS"
)
class
BertWithRope
(
nn
.
Module
,
SupportsQuant
):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"model."
:
""
})
...
...
vllm/model_executor/models/interfaces.py
View file @
84cf78ac
...
...
@@ -641,6 +641,20 @@ def supports_cross_encoding(
return
is_pooling_model
(
model
)
and
_supports_cross_encoding
(
model
)
def default_pooling_type(pooling_type: str) -> object:
    """Build a class decorator that records a model's default pooling type.

    The returned decorator stores ``pooling_type`` on the decorated object
    as its ``default_pooling_type`` attribute and returns that same object,
    so it can be stacked with other class decorators.

    Args:
        pooling_type: Name of the pooling type to record, e.g. ``"CLS"``,
            ``"ALL"`` or ``"STEP"``.

    Returns:
        A one-argument callable that tags the object it receives and
        returns it unchanged otherwise.

    Raises:
        TypeError: If ``pooling_type`` is not a string.
    """
    # Fail at decoration time rather than storing a value that only
    # breaks later, when the recorded name is finally looked up.
    if not isinstance(pooling_type, str):
        raise TypeError(
            "pooling_type must be a str, got "
            f"{type(pooling_type).__name__}")

    def func(model: object):
        model.default_pooling_type = pooling_type
        return model

    return func
def get_default_pooling_type(model: Union[type[object], object]) -> str:
    """Return the default pooling type recorded on a model class or instance.

    Args:
        model: A model class or an instance of one.

    Returns:
        The value of the ``default_pooling_type`` attribute on ``model``,
        or ``"LAST"`` when the attribute is missing or set to ``None``.
    """
    pooling_type = getattr(model, "default_pooling_type", None)
    # An explicit ``None`` is treated like a missing attribute so the
    # declared ``str`` return type always holds.
    if pooling_type is None:
        return "LAST"
    return pooling_type
class
SupportsQuant
:
"""The interface required for all models that support quantization."""
...
...
vllm/model_executor/models/internlm2.py
View file @
84cf78ac
...
...
@@ -31,7 +31,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
,
default_pooling_type
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
...
...
@@ -401,6 +401,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
return
loaded_params
@
default_pooling_type
(
"ALL"
)
class
InternLM2ForRewardModel
(
InternLM2ForCausalLM
):
is_pooling_model
=
True
...
...
vllm/model_executor/models/jamba.py
View file @
84cf78ac
...
...
@@ -22,8 +22,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.mamba.mamba_mixer
import
MambaMixer
from
vllm.model_executor.layers.mamba.mamba_utils
import
(
MambaStateShapeCalculator
)
from
vllm.model_executor.layers.pooler
import
(
DispatchPooler
,
Pooler
,
PoolingType
)
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
...
...
@@ -604,6 +603,5 @@ class JambaForSequenceClassification(JambaForCausalLM):
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
default_pooling_type
=
PoolingType
.
LAST
,
),
})
vllm/model_executor/models/modernbert.py
View file @
84cf78ac
...
...
@@ -26,7 +26,8 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
PoolingTask
from
.interfaces
import
SupportsCrossEncoding
,
SupportsV0Only
from
.interfaces
import
(
SupportsCrossEncoding
,
SupportsV0Only
,
default_pooling_type
)
from
.utils
import
WeightsMapper
,
maybe_prefix
...
...
@@ -201,6 +202,7 @@ class ModernBertEncoderLayer(nn.Module):
@
support_torch_compile
@
default_pooling_type
(
"CLS"
)
class
ModernBertModel
(
nn
.
Module
):
hf_to_vllm_mapper
=
WeightsMapper
(
orig_to_new_prefix
=
{
"layers."
:
"encoder_layer.layers."
})
...
...
@@ -264,7 +266,6 @@ class ModernBertPooler(Pooler):
self
.
pooling
=
PoolingMethod
.
from_pooling_type
(
pooling_type
)
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
,
config
.
classifier_bias
)
self
.
pooling_type
=
config
.
classifier_pooling
self
.
act
=
nn
.
GELU
()
self
.
norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
norm_eps
,
...
...
@@ -294,6 +295,7 @@ class ModernBertPooler(Pooler):
return
pooled_output
@
default_pooling_type
(
"CLS"
)
class
ModernBertForSequenceClassification
(
nn
.
Module
,
SupportsV0Only
,
SupportsCrossEncoding
):
...
...
vllm/model_executor/models/qwen2_rm.py
View file @
84cf78ac
...
...
@@ -15,11 +15,10 @@ from torch import nn
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.pooler
import
(
DispatchPooler
,
Pooler
,
PoolingType
)
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
from
.interfaces
import
SupportsLoRA
,
SupportsPP
,
default_pooling_type
from
.qwen2
import
Qwen2Model
from
.utils
import
AutoWeightsLoader
,
maybe_prefix
...
...
@@ -90,6 +89,7 @@ class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
return
loader
.
load_weights
(
weights
)
@
default_pooling_type
(
"ALL"
)
class
Qwen2ForRewardModel
(
Qwen2RewardBaseModel
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -103,6 +103,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
{
"encode"
:
Pooler
.
for_encode
(
pooler_config
)},
)
@
default_pooling_type
(
"STEP"
)
class
Qwen2ForProcessRewardModel
(
Qwen2RewardBaseModel
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -112,10 +113,5 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
({
"encode"
:
Pooler
.
for_encode
(
pooler_config
,
default_pooling_type
=
PoolingType
.
STEP
,
)
})
self
.
pooler
=
DispatchPooler
(
{
"encode"
:
Pooler
.
for_encode
(
pooler_config
)})
vllm/model_executor/models/registry.py
View file @
84cf78ac
...
...
@@ -25,8 +25,8 @@ from vllm.logger import init_logger
from
vllm.transformers_utils.dynamic_module
import
(
try_get_class_from_dynamic_module
)
from
.interfaces
import
(
has_inner_state
,
has_noops
,
is_attention_free
,
is_hybrid
,
supports_cross_encoding
,
from
.interfaces
import
(
get_default_pooling_type
,
has_inner_state
,
has_noops
,
is_attention_free
,
is_hybrid
,
supports_cross_encoding
,
supports_multimodal
,
supports_multimodal_raw_input
,
supports_pp
,
supports_transcription
,
supports_v0_only
)
from
.interfaces_base
import
is_pooling_model
,
is_text_generation_model
...
...
@@ -305,6 +305,7 @@ class _ModelInfo:
architecture
:
str
is_text_generation_model
:
bool
is_pooling_model
:
bool
default_pooling_type
:
str
supports_cross_encoding
:
bool
supports_multimodal
:
bool
supports_multimodal_raw_input
:
bool
...
...
@@ -323,6 +324,7 @@ class _ModelInfo:
architecture
=
model
.
__name__
,
is_text_generation_model
=
is_text_generation_model
(
model
),
is_pooling_model
=
is_pooling_model
(
model
),
default_pooling_type
=
get_default_pooling_type
(
model
),
supports_cross_encoding
=
supports_cross_encoding
(
model
),
supports_multimodal
=
supports_multimodal
(
model
),
supports_multimodal_raw_input
=
supports_multimodal_raw_input
(
model
),
...
...
vllm/model_executor/models/roberta.py
View file @
84cf78ac
...
...
@@ -23,7 +23,7 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
from
vllm.sequence
import
IntermediateTensors
from
.bert_with_rope
import
BertWithRope
,
JinaRobertaModel
from
.interfaces
import
SupportsCrossEncoding
from
.interfaces
import
SupportsCrossEncoding
,
default_pooling_type
class
RobertaEmbedding
(
nn
.
Module
):
...
...
@@ -86,6 +86,7 @@ class RobertaClassificationHead(nn.Module):
return
x
@
default_pooling_type
(
"CLS"
)
class
RobertaEmbeddingModel
(
BertEmbeddingModel
):
"""A model that uses Roberta to provide embedding functionalities.
...
...
@@ -149,6 +150,7 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
return
loader
.
load_weights
(
weights_list
,
mapper
=
mapper
)
@
default_pooling_type
(
"CLS"
)
class
RobertaForSequenceClassification
(
nn
.
Module
,
SupportsCrossEncoding
):
"""A model that uses Roberta to provide embedding functionalities.
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
84cf78ac
...
...
@@ -1272,7 +1272,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if
not
is_pooling_model
(
model
):
return
[]
return
list
(
model
.
pooler
.
get_supported_tasks
())
supported_tasks
=
list
(
model
.
pooler
.
get_supported_tasks
())
if
(
self
.
scheduler_config
.
chunked_prefill_enabled
and
"encode"
in
supported_tasks
):
supported_tasks
.
remove
(
"encode"
)
logger
.
info_once
(
"Chunked prefill is not supported with "
"encode task which using ALL pooling. "
"Please turn off chunked prefill by "
"`--no-enable-chunked-prefill` before using it."
)
return
supported_tasks
def
get_supported_tasks
(
self
)
->
tuple
[
SupportedTask
,
...]:
tasks
=
list
[
SupportedTask
]()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment