Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d1b6fe00
Unverified
Commit
d1b6fe00
authored
Jan 08, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 08, 2026
Browse files
[Chore] Further cleanup pooler (#31951)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
04a49669
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
47 additions
and
62 deletions
+47
-62
tests/model_executor/test_model_load_with_params.py
tests/model_executor/test_model_load_with_params.py
+3
-8
tests/test_config.py
tests/test_config.py
+1
-2
vllm/config/pooler.py
vllm/config/pooler.py
+1
-2
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+38
-44
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+1
-2
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+1
-2
vllm/v1/outputs.py
vllm/v1/outputs.py
+2
-2
No files found.
tests/model_executor/test_model_load_with_params.py
View file @
d1b6fe00
...
...
@@ -5,12 +5,7 @@ import os
import
pytest
from
vllm.model_executor.layers.pooler
import
(
CLSPool
,
DispatchPooler
,
MeanPool
,
PoolingType
,
)
from
vllm.model_executor.layers.pooler
import
CLSPool
,
DispatchPooler
,
MeanPool
from
vllm.model_executor.models.bert
import
BertEmbeddingModel
from
vllm.model_executor.models.roberta
import
RobertaEmbeddingModel
from
vllm.platforms
import
current_platform
...
...
@@ -50,7 +45,7 @@ def test_model_loading_with_params(vllm_runner, monkeypatch):
assert
model_config
.
encoder_config
[
"do_lower_case"
]
# asserts on the pooling config files
assert
model_config
.
pooler_config
.
pooling_type
==
PoolingType
.
CLS
.
name
assert
model_config
.
pooler_config
.
pooling_type
==
"CLS"
assert
model_config
.
pooler_config
.
normalize
# asserts on the tokenizer loaded
...
...
@@ -94,7 +89,7 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
assert
not
model_config
.
encoder_config
[
"do_lower_case"
]
# asserts on the pooling config files
assert
model_config
.
pooler_config
.
pooling_type
==
PoolingType
.
MEAN
.
name
assert
model_config
.
pooler_config
.
pooling_type
==
"MEAN"
assert
model_config
.
pooler_config
.
normalize
# asserts on the tokenizer loaded
...
...
tests/test_config.py
View file @
d1b6fe00
...
...
@@ -25,7 +25,6 @@ from vllm.config.vllm import (
OPTIMIZATION_LEVEL_TO_CONFIG
,
OptimizationLevel
,
)
from
vllm.model_executor.layers.pooler
import
PoolingType
from
vllm.platforms
import
current_platform
...
...
@@ -162,7 +161,7 @@ def test_get_pooling_config():
assert
model_config
.
pooler_config
is
not
None
assert
model_config
.
pooler_config
.
normalize
assert
model_config
.
pooler_config
.
pooling_type
==
PoolingType
.
MEAN
.
name
assert
model_config
.
pooler_config
.
pooling_type
==
"MEAN"
@
pytest
.
mark
.
skipif
(
...
...
vllm/config/pooler.py
View file @
d1b6fe00
...
...
@@ -21,8 +21,7 @@ class PoolerConfig:
pooling_type
:
PoolingTypeStr
|
None
=
None
"""
The pooling method of the pooling model. This should be a key in
[`vllm.model_executor.layers.pooler.PoolingType`][].
The pooling method of the pooling model.
"""
## for embeddings models
...
...
vllm/model_executor/layers/pooler.py
View file @
d1b6fe00
...
...
@@ -3,7 +3,6 @@
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Callable
,
Mapping
,
Set
from
dataclasses
import
dataclass
from
enum
import
IntEnum
from
itertools
import
groupby
from
typing
import
TypeAlias
,
TypeVar
...
...
@@ -12,13 +11,14 @@ import torch.nn as nn
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
vllm.config
import
ModelConfig
,
PoolerConfig
,
get_current_vllm_config
from
vllm.config
import
ModelConfig
,
get_current_vllm_config
from
vllm.config.pooler
import
PoolerConfig
,
PoolingTypeStr
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.pooling_params
import
PoolingParams
from
vllm.tasks
import
PoolingTask
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.v1.outputs
import
PoolerOutput
,
TokenPoolerOutput
,
Token
s
PoolerOutput
from
vllm.v1.outputs
import
PoolerOutput
,
TokenPoolerOutput
,
Token
wise
PoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
logger
=
init_logger
(
__name__
)
...
...
@@ -31,27 +31,17 @@ ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
TokenPoolingMethodOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
Token
s
PoolingMethodOutput
:
TypeAlias
=
list
[
torch
.
Tensor
]
|
list
[
torch
.
Tensor
|
None
]
Token
s
PoolingMethodOutputItem
:
TypeAlias
=
torch
.
Tensor
|
None
PoolingMethodOutput
:
TypeAlias
=
TokenPoolingMethodOutput
|
Token
s
PoolingMethodOutput
Token
wise
PoolingMethodOutput
:
TypeAlias
=
list
[
torch
.
Tensor
]
|
list
[
torch
.
Tensor
|
None
]
Token
wise
PoolingMethodOutputItem
:
TypeAlias
=
torch
.
Tensor
|
None
PoolingMethodOutput
:
TypeAlias
=
TokenPoolingMethodOutput
|
Token
wise
PoolingMethodOutput
TokenPoolerHeadOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
TokensPoolerHeadOutput
:
TypeAlias
=
torch
.
Tensor
|
None
class
PoolingType
(
IntEnum
):
"""Enumeration for different types of pooling methods."""
LAST
=
0
ALL
=
1
CLS
=
2
STEP
=
3
MEAN
=
4
TokenwisePoolerHeadOutput
:
TypeAlias
=
torch
.
Tensor
|
None
@
dataclass
(
frozen
=
True
)
class
ResolvedPoolingConfig
:
pooling_type
:
PoolingType
pooling_type
:
PoolingType
Str
task
:
PoolingTask
@
classmethod
...
...
@@ -61,7 +51,7 @@ class ResolvedPoolingConfig:
pooler_config
:
PoolerConfig
,
)
->
"ResolvedPoolingConfig"
:
assert
pooler_config
.
pooling_type
is
not
None
return
cls
(
task
=
task
,
pooling_type
=
PoolingType
[
pooler_config
.
pooling_type
]
)
return
cls
(
task
=
task
,
pooling_type
=
pooler_config
.
pooling_type
)
@
dataclass
(
frozen
=
True
)
...
...
@@ -112,17 +102,22 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
class
PoolingMethod
(
nn
.
Module
,
ABC
):
@
staticmethod
def
from_pooling_type
(
pooling_type
:
PoolingType
)
->
"PoolingMethod"
:
if
pooling_type
==
PoolingType
.
LAST
:
def
from_pooling_type
(
pooling_type
:
PoolingType
Str
)
->
"PoolingMethod"
:
if
pooling_type
==
"
LAST
"
:
return
LastPool
()
if
pooling_type
==
PoolingType
.
ALL
:
if
pooling_type
==
"
ALL
"
:
return
AllPool
()
if
pooling_type
==
PoolingType
.
CLS
:
if
pooling_type
==
"
CLS
"
:
return
CLSPool
()
if
pooling_type
==
PoolingType
.
MEAN
:
if
pooling_type
==
"
MEAN
"
:
return
MeanPool
()
if
pooling_type
==
"STEP"
:
raise
ValueError
(
"'STEP' pooling is handled by StepPooler "
"and is not a standalone PoolingMethod."
)
raise
NotImplementedError
(
f
"Unsupported method:
{
pooling_type
}
"
)
raise
NotImplementedError
(
f
"Unsupported method:
{
pooling_type
!
r
}
"
)
@
abstractmethod
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
...
...
@@ -186,13 +181,12 @@ class AllPool(PoolingMethod):
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Token
s
PoolingMethodOutput
:
)
->
Token
wise
PoolingMethodOutput
:
pooling_cursor
=
pooling_metadata
.
get_pooling_cursor
()
is_finished
=
pooling_cursor
.
is_finished
()
hidden_states_lst
=
list
(
hidden_states
.
split
(
pooling_cursor
.
num_scheduled_tokens_cpu
.
tolist
())
hidden_states_all
=
hidden_states
.
split
(
pooling_cursor
.
num_scheduled_tokens_cpu
.
tolist
()
)
hidden_states_lst
=
[
hidden_states_
lst
[
i
]
for
i
in
pooling_cursor
.
index
]
hidden_states_lst
=
[
hidden_states_
all
[
i
]
for
i
in
pooling_cursor
.
index
]
if
not
self
.
enable_chunked_prefill
:
return
hidden_states_lst
...
...
@@ -206,7 +200,7 @@ class AllPool(PoolingMethod):
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
output_list
=
list
[
torch
.
Tensor
|
None
]()
for
p
,
finished
in
zip
(
pooling_states
,
is_finished
):
for
p
,
finished
in
zip
(
pooling_states
,
pooling_cursor
.
is_finished
()
):
if
finished
:
hidden_states_cache
=
p
.
hidden_states_cache
if
len
(
hidden_states_cache
)
==
1
:
...
...
@@ -620,19 +614,19 @@ class ClassifierPooler(Pooler):
return
scores
class
Token
s
PoolerHead
(
nn
.
Module
,
ABC
):
class
Token
wise
PoolerHead
(
nn
.
Module
,
ABC
):
"""Applicable to pooling strategies that output multiple tokens."""
@
abstractmethod
def
forward
(
self
,
pooled_data
:
Token
s
PoolingMethodOutputItem
,
pooled_data
:
Token
wise
PoolingMethodOutputItem
,
pooling_param
:
PoolingParams
,
)
->
Token
s
PoolerHeadOutput
:
)
->
Token
wise
PoolerHeadOutput
:
raise
NotImplementedError
class
TokenEmbeddingPoolerHead
(
Token
s
PoolerHead
):
class
TokenEmbeddingPoolerHead
(
Token
wise
PoolerHead
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
...
...
@@ -647,9 +641,9 @@ class TokenEmbeddingPoolerHead(TokensPoolerHead):
def
forward
(
self
,
pooled_data
:
Token
s
PoolingMethodOutputItem
,
pooled_data
:
Token
wise
PoolingMethodOutputItem
,
pooling_param
:
PoolingParams
,
)
->
Token
s
PoolerHeadOutput
:
)
->
Token
wise
PoolerHeadOutput
:
# for unfinished chunked prefill
if
pooled_data
is
None
:
return
None
...
...
@@ -673,7 +667,7 @@ class TokenEmbeddingPoolerHead(TokensPoolerHead):
return
pooled_data
class
TokenClassifierPoolerHead
(
Token
s
PoolerHead
):
class
TokenClassifierPoolerHead
(
Token
wise
PoolerHead
):
def
__init__
(
self
,
classifier
:
ClassifierFn
|
None
,
...
...
@@ -695,9 +689,9 @@ class TokenClassifierPoolerHead(TokensPoolerHead):
def
forward
(
self
,
pooled_data
:
Token
s
PoolingMethodOutputItem
,
pooled_data
:
Token
wise
PoolingMethodOutputItem
,
pooling_param
:
PoolingParams
,
)
->
Token
s
PoolerHeadOutput
:
)
->
Token
wise
PoolerHeadOutput
:
# for unfinished chunked prefill
if
pooled_data
is
None
:
return
None
...
...
@@ -722,7 +716,7 @@ class TokenClassifierPoolerHead(TokensPoolerHead):
class
AllPooler
(
Pooler
):
def
__init__
(
self
,
head
:
Token
s
PoolerHead
)
->
None
:
def
__init__
(
self
,
head
:
Token
wise
PoolerHead
)
->
None
:
super
().
__init__
()
self
.
pooling
=
AllPool
()
...
...
@@ -735,7 +729,7 @@ class AllPooler(Pooler):
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Token
s
PoolerOutput
:
)
->
Token
wise
PoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooling_params
=
pooling_metadata
.
pooling_params
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
...
...
@@ -744,7 +738,7 @@ class AllPooler(Pooler):
class
StepPooler
(
Pooler
):
def
__init__
(
self
,
head
:
Token
s
PoolerHead
)
->
None
:
def
__init__
(
self
,
head
:
Token
wise
PoolerHead
)
->
None
:
super
().
__init__
()
self
.
pooling
=
AllPool
()
...
...
@@ -790,7 +784,7 @@ class StepPooler(Pooler):
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Token
s
PoolerOutput
:
)
->
Token
wise
PoolerOutput
:
pooled_data
=
self
.
extract_states
(
hidden_states
,
pooling_metadata
)
pooling_params
=
pooling_metadata
.
pooling_params
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
...
...
vllm/model_executor/models/bert.py
View file @
d1b6fe00
...
...
@@ -23,7 +23,6 @@ from vllm.model_executor.layers.pooler import (
Pooler
,
PoolingMethod
,
PoolingParamsUpdate
,
PoolingType
,
TokenPoolerHeadOutput
,
TokenPoolingMethodOutput
,
)
...
...
@@ -90,7 +89,7 @@ class BertPooler(Pooler):
def
__init__
(
self
,
config
:
BertConfig
):
super
().
__init__
()
self
.
pooling
=
PoolingMethod
.
from_pooling_type
(
PoolingType
.
CLS
)
self
.
pooling
=
PoolingMethod
.
from_pooling_type
(
"
CLS
"
)
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
activation
=
nn
.
Tanh
()
...
...
vllm/model_executor/models/modernbert.py
View file @
d1b6fe00
...
...
@@ -18,7 +18,6 @@ from vllm.model_executor.layers.pooler import (
Pooler
,
PoolingMethod
,
PoolingParamsUpdate
,
PoolingType
,
TokenPoolerHeadOutput
,
TokenPoolingMethodOutput
,
)
...
...
@@ -287,7 +286,7 @@ class ModernBertPooler(Pooler):
def
__init__
(
self
,
config
:
ModernBertConfig
):
super
().
__init__
()
pooling_type
=
PoolingType
[
config
.
classifier_pooling
.
upper
()
]
pooling_type
=
config
.
classifier_pooling
.
upper
()
self
.
pooling
=
PoolingMethod
.
from_pooling_type
(
pooling_type
)
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
,
config
.
classifier_bias
...
...
vllm/v1/outputs.py
View file @
d1b6fe00
...
...
@@ -92,8 +92,8 @@ class LogprobsTensors(NamedTuple):
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
TokenPoolerOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
Token
s
PoolerOutput
:
TypeAlias
=
list
[
torch
.
Tensor
]
|
list
[
torch
.
Tensor
|
None
]
PoolerOutput
:
TypeAlias
=
TokenPoolerOutput
|
Token
s
PoolerOutput
Token
wise
PoolerOutput
:
TypeAlias
=
list
[
torch
.
Tensor
]
|
list
[
torch
.
Tensor
|
None
]
PoolerOutput
:
TypeAlias
=
TokenPoolerOutput
|
Token
wise
PoolerOutput
@
dataclass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment