Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b7036c87
Unverified
Commit
b7036c87
authored
Jan 08, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 08, 2026
Browse files
[Refactor] Clean up pooler modules (#31897)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
cc6dafae
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
167 additions
and
120 deletions
+167
-120
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+102
-81
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+19
-14
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+17
-9
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+18
-13
vllm/v1/outputs.py
vllm/v1/outputs.py
+4
-2
vllm/v1/pool/metadata.py
vllm/v1/pool/metadata.py
+6
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-1
No files found.
vllm/model_executor/layers/pooler.py
View file @
b7036c87
...
@@ -5,7 +5,7 @@ from collections.abc import Callable, Mapping, Set
...
@@ -5,7 +5,7 @@ from collections.abc import Callable, Mapping, Set
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
IntEnum
from
enum
import
IntEnum
from
itertools
import
groupby
from
itertools
import
groupby
from
typing
import
TypeVar
from
typing
import
TypeAlias
,
TypeVar
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -18,8 +18,8 @@ from vllm.model_executor.models.adapters import _load_st_projector
...
@@ -18,8 +18,8 @@ from vllm.model_executor.models.adapters import _load_st_projector
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.v1.outputs
import
PoolerOutput
from
vllm.v1.outputs
import
PoolerOutput
,
TokenPoolerOutput
,
TokensPoolerOutput
from
vllm.v1.pool.metadata
import
PoolingCursor
,
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -30,6 +30,15 @@ PoolingFn = Callable[
...
@@ -30,6 +30,15 @@ PoolingFn = Callable[
ClassifierFn
=
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]
ClassifierFn
=
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]
TokenPoolingMethodOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
TokensPoolingMethodOutput
:
TypeAlias
=
list
[
torch
.
Tensor
]
|
list
[
torch
.
Tensor
|
None
]
TokensPoolingMethodOutputItem
:
TypeAlias
=
torch
.
Tensor
|
None
PoolingMethodOutput
:
TypeAlias
=
TokenPoolingMethodOutput
|
TokensPoolingMethodOutput
TokenPoolerHeadOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
TokensPoolerHeadOutput
:
TypeAlias
=
torch
.
Tensor
|
None
class
PoolingType
(
IntEnum
):
class
PoolingType
(
IntEnum
):
"""Enumeration for different types of pooling methods."""
"""Enumeration for different types of pooling methods."""
...
@@ -123,31 +132,24 @@ class PoolingMethod(nn.Module, ABC):
...
@@ -123,31 +132,24 @@ class PoolingMethod(nn.Module, ABC):
return
PoolingParamsUpdate
()
return
PoolingParamsUpdate
()
@
abstractmethod
@
abstractmethod
def
forward_all
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_cursor
:
PoolingCursor
,
)
->
PoolerOutput
:
raise
NotImplementedError
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
PoolingMethodOutput
:
pooling_cursor
=
pooling_metadata
.
pooling_cursor
raise
NotImplementedError
return
self
.
forward_all
(
hidden_states
,
pooling_cursor
)
class
CLSPool
(
PoolingMethod
):
class
CLSPool
(
PoolingMethod
):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
,
"embed"
,
"classify"
,
"score"
}
return
{
"token_embed"
,
"token_classify"
,
"embed"
,
"classify"
,
"score"
}
def
forward
_all
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
pooling_cursor
:
PoolingCursor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
TokenPoolingMethodOutput
:
pooling_cursor
=
pooling_metadata
.
get_pooling_cursor
()
assert
not
pooling_cursor
.
is_partial_prefill
(),
(
assert
not
pooling_cursor
.
is_partial_prefill
(),
(
"partial prefill not supported with CLS pooling"
"partial prefill not supported with CLS pooling"
)
)
...
@@ -159,11 +161,12 @@ class LastPool(PoolingMethod):
...
@@ -159,11 +161,12 @@ class LastPool(PoolingMethod):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
,
"embed"
,
"classify"
,
"score"
}
return
{
"token_embed"
,
"token_classify"
,
"embed"
,
"classify"
,
"score"
}
def
forward
_all
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
pooling_cursor
:
PoolingCursor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
TokenPoolingMethodOutput
:
pooling_cursor
=
pooling_metadata
.
get_pooling_cursor
()
return
hidden_states
[
pooling_cursor
.
last_token_indices_gpu
]
return
hidden_states
[
pooling_cursor
.
last_token_indices_gpu
]
...
@@ -179,19 +182,12 @@ class AllPool(PoolingMethod):
...
@@ -179,19 +182,12 @@ class AllPool(PoolingMethod):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
}
return
{
"token_embed"
,
"token_classify"
}
def
forward_all
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_cursor
:
PoolingCursor
)
->
PoolerOutput
:
raise
NotImplementedError
(
"forward_all is not implemented for AllPool. Use forward instead."
)
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
Pooler
Output
:
)
->
TokensPoolingMethod
Output
:
pooling_cursor
=
pooling_metadata
.
pooling_cursor
pooling_cursor
=
pooling_metadata
.
get_
pooling_cursor
()
is_finished
=
pooling_cursor
.
is_finished
()
is_finished
=
pooling_cursor
.
is_finished
()
hidden_states_lst
=
list
(
hidden_states_lst
=
list
(
hidden_states
.
split
(
pooling_cursor
.
num_scheduled_tokens_cpu
.
tolist
())
hidden_states
.
split
(
pooling_cursor
.
num_scheduled_tokens_cpu
.
tolist
())
...
@@ -209,7 +205,7 @@ class AllPool(PoolingMethod):
...
@@ -209,7 +205,7 @@ class AllPool(PoolingMethod):
p
.
hidden_states_cache
.
append
(
hs_chunk
)
p
.
hidden_states_cache
.
append
(
hs_chunk
)
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
output_list
:
PoolerOutput
=
[]
output_list
=
list
[
torch
.
Tensor
|
None
]()
for
p
,
finished
in
zip
(
pooling_states
,
is_finished
):
for
p
,
finished
in
zip
(
pooling_states
,
is_finished
):
if
finished
:
if
finished
:
hidden_states_cache
=
p
.
hidden_states_cache
hidden_states_cache
=
p
.
hidden_states_cache
...
@@ -228,11 +224,12 @@ class MeanPool(PoolingMethod):
...
@@ -228,11 +224,12 @@ class MeanPool(PoolingMethod):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
,
"embed"
,
"classify"
,
"score"
}
return
{
"token_embed"
,
"token_classify"
,
"embed"
,
"classify"
,
"score"
}
def
forward
_all
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
pooling_cursor
:
PoolingCursor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
TokenPoolingMethodOutput
:
pooling_cursor
=
pooling_metadata
.
get_pooling_cursor
()
assert
not
pooling_cursor
.
is_partial_prefill
(),
(
assert
not
pooling_cursor
.
is_partial_prefill
(),
(
"partial prefill not supported with MEAN pooling"
"partial prefill not supported with MEAN pooling"
)
)
...
@@ -410,7 +407,7 @@ class Pooler(nn.Module, ABC):
...
@@ -410,7 +407,7 @@ class Pooler(nn.Module, ABC):
@
abstractmethod
@
abstractmethod
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
list
[
torch
.
Tensor
]
|
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
PoolerOutput
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -422,41 +419,42 @@ class DummyPooler(Pooler):
...
@@ -422,41 +419,42 @@ class DummyPooler(Pooler):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
list
[
torch
.
Tensor
]
|
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
PoolerOutput
:
return
hidden_states
return
hidden_states
class
PoolerHead
(
nn
.
Module
):
class
TokenPoolerHead
(
nn
.
Module
,
ABC
):
def
__init__
(
self
,
activation
:
PoolerActivation
)
->
None
:
"""Applicable to pooling strategies that output one token."""
super
().
__init__
()
self
.
activation
=
activation
@
abstractmethod
def
forward
(
def
forward
(
self
,
self
,
pooled_data
:
list
[
torch
.
Tensor
]
|
torch
.
Tensor
,
pooled_data
:
TokenPoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
Token
Pooler
Head
Output
:
r
eturn
self
.
activation
(
pooled_data
)
r
aise
NotImplementedError
class
EmbeddingPoolerHead
(
PoolerHead
):
class
EmbeddingPoolerHead
(
Token
PoolerHead
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
super
().
__init__
(
activation
=
PoolerNormalize
()
)
super
().
__init__
()
# Load ST projector if available
# Load ST projector if available
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
self
.
projector
:
nn
.
Module
|
None
=
(
self
.
projector
=
(
_load_st_projector
(
vllm_config
.
model_config
)
if
vllm_config
else
None
_load_st_projector
(
vllm_config
.
model_config
)
if
vllm_config
else
None
)
)
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
self
.
activation
=
PoolerNormalize
()
def
forward
(
def
forward
(
self
,
self
,
pooled_data
:
list
[
torch
.
Tensor
]
|
torch
.
Tensor
,
pooled_data
:
TokenPoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
Token
Pooler
Head
Output
:
if
isinstance
(
pooled_data
,
list
):
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_dimension]
# pooled_data shape: [batchsize, hidden_dimension]
...
@@ -509,7 +507,7 @@ class SimplePooler(Pooler):
...
@@ -509,7 +507,7 @@ class SimplePooler(Pooler):
3. Returns structured results as `PoolerOutput`.
3. Returns structured results as `PoolerOutput`.
"""
"""
def
__init__
(
self
,
pooling
:
PoolingMethod
,
head
:
PoolerHead
)
->
None
:
def
__init__
(
self
,
pooling
:
PoolingMethod
,
head
:
Token
PoolerHead
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
pooling
=
pooling
self
.
pooling
=
pooling
...
@@ -523,9 +521,9 @@ class SimplePooler(Pooler):
...
@@ -523,9 +521,9 @@ class SimplePooler(Pooler):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
Token
Pooler
Head
Output
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
return
pooled_data
...
@@ -591,9 +589,9 @@ class ClassifierPooler(Pooler):
...
@@ -591,9 +589,9 @@ class ClassifierPooler(Pooler):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
Token
PoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
if
isinstance
(
pooled_data
,
list
):
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
pooled_data
=
torch
.
stack
(
pooled_data
)
...
@@ -622,10 +620,36 @@ class ClassifierPooler(Pooler):
...
@@ -622,10 +620,36 @@ class ClassifierPooler(Pooler):
return
scores
return
scores
class
TokenEmbeddingPoolerHead
(
EmbeddingPoolerHead
):
class
TokensPoolerHead
(
nn
.
Module
,
ABC
):
"""Applicable to pooling strategies that output multiple tokens."""
@
abstractmethod
def
forward
(
def
forward
(
self
,
pooled_data
:
torch
.
Tensor
|
None
,
pooling_param
:
PoolingParams
self
,
)
->
PoolerOutput
:
pooled_data
:
TokensPoolingMethodOutputItem
,
pooling_param
:
PoolingParams
,
)
->
TokensPoolerHeadOutput
:
raise
NotImplementedError
class
TokenEmbeddingPoolerHead
(
TokensPoolerHead
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
# Load ST projector if available
vllm_config
=
get_current_vllm_config
()
self
.
projector
=
(
_load_st_projector
(
vllm_config
.
model_config
)
if
vllm_config
else
None
)
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
self
.
activation
=
PoolerNormalize
()
def
forward
(
self
,
pooled_data
:
TokensPoolingMethodOutputItem
,
pooling_param
:
PoolingParams
,
)
->
TokensPoolerHeadOutput
:
# for unfinished chunked prefill
# for unfinished chunked prefill
if
pooled_data
is
None
:
if
pooled_data
is
None
:
return
None
return
None
...
@@ -649,57 +673,56 @@ class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
...
@@ -649,57 +673,56 @@ class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
return
pooled_data
return
pooled_data
class
TokenClassifierPoolerHead
(
nn
.
Module
):
class
TokenClassifierPoolerHead
(
TokensPoolerHead
):
def
__init__
(
def
__init__
(
self
,
self
,
classifier
:
ClassifierFn
|
None
,
classifier
:
ClassifierFn
|
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
self
.
classifier
=
classifier
self
.
classifier
=
classifier
self
.
act_fn
=
ClassifierPooler
.
resolve_act_fn
(
vllm_config
.
model_config
,
static_num_labels
=
False
,
act_fn
=
act_fn
)
self
.
logit_bias
:
float
|
None
=
(
self
.
logit_bias
:
float
|
None
=
(
vllm_config
.
model_config
.
pooler_config
.
logit_bias
vllm_config
.
model_config
.
pooler_config
.
logit_bias
)
)
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
self
.
activation
=
ClassifierPooler
.
resolve_act_fn
(
return
{
"token_classify"
}
vllm_config
.
model_config
,
static_num_labels
=
False
,
act_fn
=
act_fn
)
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
|
None
,
pooled_data
:
TokensPoolingMethodOutputItem
,
pooling_param
:
PoolingParams
,
pooling_param
:
PoolingParams
,
)
->
PoolerOutput
:
)
->
Tokens
Pooler
Head
Output
:
# for unfinished chunked prefill
# for unfinished chunked prefill
if
hidden_states
is
None
:
if
pooled_data
is
None
:
return
None
return
None
hidden_states
=
hidden_states
.
to
(
self
.
head_dtype
)
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# hidden_states shape: [n_token, hidden_size]
# hidden_states shape: [n_token, hidden_size]
if
self
.
classifier
is
not
None
:
if
self
.
classifier
is
not
None
:
scores
=
self
.
classifier
(
hidden_states
)
scores
=
self
.
classifier
(
pooled_data
)
else
:
else
:
scores
=
hidden_states
scores
=
pooled_data
# scores shape: [n_token, num_labels]
# scores shape: [n_token, num_labels]
if
self
.
logit_bias
is
not
None
:
if
self
.
logit_bias
is
not
None
:
scores
-=
self
.
logit_bias
scores
-=
self
.
logit_bias
if
pooling_param
.
use_activation
:
if
pooling_param
.
use_activation
:
scores
=
self
.
act
_f
n
(
scores
)
scores
=
self
.
act
ivatio
n
(
scores
)
# scores shape: [n_token, num_labels]
# scores shape: [n_token, num_labels]
return
scores
return
scores
class
AllPooler
(
Pooler
):
class
AllPooler
(
Pooler
):
def
__init__
(
self
,
head
:
nn
.
Module
|
PoolerHead
)
->
None
:
def
__init__
(
self
,
head
:
Tokens
PoolerHead
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
pooling
=
AllPool
()
self
.
pooling
=
AllPool
()
...
@@ -712,17 +735,16 @@ class AllPooler(Pooler):
...
@@ -712,17 +735,16 @@ class AllPooler(Pooler):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
Tokens
PoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooling_params
=
pooling_metadata
.
pooling_params
pooling_params
=
pooling_metadata
.
pooling_params
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
pooled_data
=
[
self
.
head
(
d
,
p
)
for
d
,
p
in
zip
(
pooled_data
,
pooling_params
)]
return
[
self
.
head
(
d
,
p
)
for
d
,
p
in
zip
(
pooled_data
,
pooling_params
)]
return
pooled_data
class
StepPooler
(
Pooler
):
class
StepPooler
(
Pooler
):
def
__init__
(
self
,
head
:
nn
.
Module
|
PoolerHead
)
->
None
:
def
__init__
(
self
,
head
:
Tokens
PoolerHead
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
pooling
=
AllPool
()
self
.
pooling
=
AllPool
()
...
@@ -730,14 +752,14 @@ class StepPooler(Pooler):
...
@@ -730,14 +752,14 @@ class StepPooler(Pooler):
def
extract_states
(
def
extract_states
(
self
,
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
list
[
torch
.
Tensor
|
None
]
:
pooled_data_lst
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data_lst
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
prompt_token_ids
=
pooling_metadata
.
get_prompt_token_ids
()
prompt_token_ids
=
pooling_metadata
.
get_prompt_token_ids
()
pooling_params
=
pooling_metadata
.
pooling_params
pooling_params
=
pooling_metadata
.
pooling_params
pooled_data
:
PoolerOutput
=
[]
pooled_data
=
list
[
torch
.
Tensor
|
None
]()
for
data
,
token_id
,
pooling_param
in
zip
(
for
data
,
token_id
,
pooling_param
in
zip
(
pooled_data_lst
,
prompt_token_ids
,
pooling_params
pooled_data_lst
,
prompt_token_ids
,
pooling_params
):
):
...
@@ -766,15 +788,14 @@ class StepPooler(Pooler):
...
@@ -766,15 +788,14 @@ class StepPooler(Pooler):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
Tokens
PoolerOutput
:
pooled_data
=
self
.
extract_states
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
extract_states
(
hidden_states
,
pooling_metadata
)
pooling_params
=
pooling_metadata
.
pooling_params
pooling_params
=
pooling_metadata
.
pooling_params
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
pooled_data
=
[
self
.
head
(
d
,
p
)
for
d
,
p
in
zip
(
pooled_data
,
pooling_params
)]
return
[
self
.
head
(
d
,
p
)
for
d
,
p
in
zip
(
pooled_data
,
pooling_params
)]
return
pooled_data
class
DispatchPooler
(
Pooler
):
class
DispatchPooler
(
Pooler
):
...
@@ -800,12 +821,12 @@ class DispatchPooler(Pooler):
...
@@ -800,12 +821,12 @@ class DispatchPooler(Pooler):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
PoolerOutput
:
poolers_by_task
=
self
.
poolers_by_task
poolers_by_task
=
self
.
poolers_by_task
outputs
=
list
[
torch
.
Tensor
]()
outputs
=
list
[
torch
.
Tensor
|
None
]()
offset
=
0
offset
=
0
for
task
,
group
in
groupby
(
pooling_metadata
.
tasks
):
for
task
,
group
in
groupby
(
pooling_metadata
.
tasks
):
if
not
(
pooler
:
=
poolers_by_task
.
get
(
task
)):
if
not
(
pooler
:
=
poolers_by_task
.
get
(
task
)):
...
...
vllm/model_executor/models/bert.py
View file @
b7036c87
...
@@ -24,11 +24,14 @@ from vllm.model_executor.layers.pooler import (
...
@@ -24,11 +24,14 @@ from vllm.model_executor.layers.pooler import (
PoolingMethod
,
PoolingMethod
,
PoolingParamsUpdate
,
PoolingParamsUpdate
,
PoolingType
,
PoolingType
,
TokenPoolerHeadOutput
,
TokenPoolingMethodOutput
,
)
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
from
vllm.v1.outputs
import
TokenPoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
...
@@ -97,24 +100,26 @@ class BertPooler(Pooler):
...
@@ -97,24 +100,26 @@ class BertPooler(Pooler):
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
self
.
pooling
.
get_pooling_updates
(
task
)
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
_head
(
self
,
pooled_output
:
torch
.
Tensor
):
def
head
(
pooled_output
=
self
.
dense
(
pooled_output
)
pooled_output
=
self
.
activation
(
pooled_output
)
return
pooled_output
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
pooled_data
:
TokenPoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
torch
.
Tensor
|
list
[
torch
.
Tensor
]:
)
->
TokenPoolerHeadOutput
:
pooled_output
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
if
isinstance
(
pooled_output
,
list
):
pooled_data
=
self
.
dense
(
pooled_data
)
pooled_output
=
[
self
.
_head
(
output
)
for
output
in
pooled_output
]
pooled_data
=
self
.
activation
(
pooled_data
)
else
:
return
pooled_data
pooled_output
=
self
.
_head
(
pooled_output
)
return
pooled_output
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
class
BertEncoder
(
nn
.
Module
):
class
BertEncoder
(
nn
.
Module
):
...
...
vllm/model_executor/models/gritlm.py
View file @
b7036c87
...
@@ -4,21 +4,22 @@ from collections.abc import Set
...
@@ -4,21 +4,22 @@ from collections.abc import Set
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.pooler
import
(
from
vllm.model_executor.layers.pooler
import
(
DispatchPooler
,
DispatchPooler
,
Pooler
,
Pooler
,
PoolerHead
,
PoolerNormalize
,
PoolerNormalize
,
PoolingMethod
,
PoolingParamsUpdate
,
PoolingParamsUpdate
,
TokenPoolerHeadOutput
,
TokenPoolingMethodOutput
,
)
)
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
from
vllm.tokenizers
import
cached_tokenizer_from_config
from
vllm.tokenizers
import
cached_tokenizer_from_config
from
vllm.v1.outputs
import
PoolerOutput
from
vllm.v1.outputs
import
Token
PoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.interfaces_base
import
default_pooling_type
from
.interfaces_base
import
default_pooling_type
...
@@ -26,7 +27,7 @@ from .interfaces_base import default_pooling_type
...
@@ -26,7 +27,7 @@ from .interfaces_base import default_pooling_type
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
GritLMMeanPool
(
nn
.
Module
):
class
GritLMMeanPool
(
PoolingMethod
):
"""As `MeanPool`, but only includes non-instruction tokens."""
"""As `MeanPool`, but only includes non-instruction tokens."""
def
__init__
(
self
,
model_config
:
ModelConfig
):
def
__init__
(
self
,
model_config
:
ModelConfig
):
...
@@ -141,16 +142,16 @@ class GritLMMeanPool(nn.Module):
...
@@ -141,16 +142,16 @@ class GritLMMeanPool(nn.Module):
return
instruction_len
return
instruction_len
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"encode"
,
"embed"
}
return
{
"embed"
}
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
PoolingParamsUpdate
(
requires_token_ids
=
True
)
return
PoolingParamsUpdate
(
requires_token_ids
=
True
)
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
list
[
torch
.
Tensor
]
|
torch
.
Tensor
:
)
->
TokenPoolingMethodOutput
:
prompt_lens
=
pooling_metadata
.
prompt_lens
prompt_lens
=
pooling_metadata
.
prompt_lens
instr_lens
=
torch
.
tensor
(
instr_lens
=
torch
.
tensor
(
[
[
...
@@ -178,7 +179,7 @@ class GritLMPooler(Pooler):
...
@@ -178,7 +179,7 @@ class GritLMPooler(Pooler):
super
().
__init__
()
super
().
__init__
()
self
.
pooling
=
GritLMMeanPool
(
model_config
)
self
.
pooling
=
GritLMMeanPool
(
model_config
)
self
.
head
=
PoolerHead
(
PoolerNormalize
()
)
self
.
activation
=
PoolerNormalize
()
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
self
.
pooling
.
get_supported_tasks
()
return
self
.
pooling
.
get_supported_tasks
()
...
@@ -186,11 +187,18 @@ class GritLMPooler(Pooler):
...
@@ -186,11 +187,18 @@ class GritLMPooler(Pooler):
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
self
.
pooling
.
get_pooling_updates
(
task
)
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
head
(
self
,
pooled_data
:
TokenPoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerHeadOutput
:
return
self
.
activation
(
pooled_data
)
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
Token
PoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
return
pooled_data
...
...
vllm/model_executor/models/modernbert.py
View file @
b7036c87
...
@@ -19,12 +19,15 @@ from vllm.model_executor.layers.pooler import (
...
@@ -19,12 +19,15 @@ from vllm.model_executor.layers.pooler import (
PoolingMethod
,
PoolingMethod
,
PoolingParamsUpdate
,
PoolingParamsUpdate
,
PoolingType
,
PoolingType
,
TokenPoolerHeadOutput
,
TokenPoolingMethodOutput
,
)
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
from
vllm.v1.outputs
import
TokenPoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.interfaces
import
SupportsCrossEncoding
from
.interfaces
import
SupportsCrossEncoding
...
@@ -300,23 +303,25 @@ class ModernBertPooler(Pooler):
...
@@ -300,23 +303,25 @@ class ModernBertPooler(Pooler):
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
self
.
pooling
.
get_pooling_updates
(
task
)
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
_head
(
self
,
pooled_output
:
torch
.
Tensor
):
def
head
(
pooled_output
=
pooled_output
.
to
(
self
.
dense
.
weight
.
dtype
)
return
self
.
norm
(
self
.
act
(
self
.
dense
(
pooled_output
)))
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
pooled_data
:
TokenPoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
torch
.
Tensor
|
list
[
torch
.
Tensor
]:
)
->
TokenPoolerHeadOutput
:
pooled_output
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
if
isinstance
(
pooled_output
,
list
):
pooled_data
=
pooled_data
.
to
(
self
.
dense
.
weight
.
dtype
)
pooled_output
=
[
self
.
_head
(
output
)
for
output
in
pooled_output
]
return
self
.
norm
(
self
.
act
(
self
.
dense
(
pooled_data
)))
else
:
pooled_output
=
self
.
_head
(
pooled_output
)
return
pooled_output
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
@
default_pooling_type
(
"CLS"
)
@
default_pooling_type
(
"CLS"
)
...
...
vllm/v1/outputs.py
View file @
b7036c87
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
NamedTuple
from
typing
import
TYPE_CHECKING
,
NamedTuple
,
TypeAlias
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -91,7 +91,9 @@ class LogprobsTensors(NamedTuple):
...
@@ -91,7 +91,9 @@ class LogprobsTensors(NamedTuple):
# [num_reqs, <dynamic>]
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
# The shape of each element depends on the pooler used
PoolerOutput
=
list
[
torch
.
Tensor
|
None
]
|
torch
.
Tensor
|
None
TokenPoolerOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
TokensPoolerOutput
:
TypeAlias
=
list
[
torch
.
Tensor
]
|
list
[
torch
.
Tensor
|
None
]
PoolerOutput
:
TypeAlias
=
TokenPoolerOutput
|
TokensPoolerOutput
@
dataclass
@
dataclass
...
...
vllm/v1/pool/metadata.py
View file @
b7036c87
...
@@ -90,6 +90,12 @@ class PoolingMetadata:
...
@@ -90,6 +90,12 @@ class PoolingMetadata:
return
[
prompt_token_ids
[
i
,
:
num
]
for
i
,
num
in
enumerate
(
self
.
prompt_lens
)]
return
[
prompt_token_ids
[
i
,
:
num
]
for
i
,
num
in
enumerate
(
self
.
prompt_lens
)]
def
get_pooling_cursor
(
self
)
->
PoolingCursor
:
pooling_cursor
=
self
.
pooling_cursor
assert
pooling_cursor
is
not
None
,
"Should call `build_pooling_cursor` first"
return
pooling_cursor
def
build_pooling_cursor
(
def
build_pooling_cursor
(
self
,
self
,
num_scheduled_tokens_np
:
np
.
ndarray
,
num_scheduled_tokens_np
:
np
.
ndarray
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
b7036c87
...
@@ -4680,7 +4680,7 @@ class GPUModelRunner(
...
@@ -4680,7 +4680,7 @@ class GPUModelRunner(
for
task
in
supported_pooling_tasks
:
for
task
in
supported_pooling_tasks
:
# Run a full batch with each task to ensure none of them OOMs
# Run a full batch with each task to ensure none of them OOMs
output
=
self
.
_dummy_pooler_run_task
(
hidden_states
,
task
)
output
=
self
.
_dummy_pooler_run_task
(
hidden_states
,
task
)
output_size
[
task
]
=
sum
(
o
.
nbytes
for
o
in
output
)
output_size
[
task
]
=
sum
(
o
.
nbytes
for
o
in
output
if
o
is
not
None
)
del
output
# Allow GC
del
output
# Allow GC
max_task
=
max
(
output_size
.
items
(),
key
=
lambda
x
:
x
[
1
])[
0
]
max_task
=
max
(
output_size
.
items
(),
key
=
lambda
x
:
x
[
1
])[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment