Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b7036c87
Unverified
Commit
b7036c87
authored
Jan 08, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 08, 2026
Browse files
[Refactor] Clean up pooler modules (#31897)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
cc6dafae
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
167 additions
and
120 deletions
+167
-120
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+102
-81
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+19
-14
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+17
-9
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+18
-13
vllm/v1/outputs.py
vllm/v1/outputs.py
+4
-2
vllm/v1/pool/metadata.py
vllm/v1/pool/metadata.py
+6
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-1
No files found.
vllm/model_executor/layers/pooler.py
View file @
b7036c87
...
...
@@ -5,7 +5,7 @@ from collections.abc import Callable, Mapping, Set
from
dataclasses
import
dataclass
from
enum
import
IntEnum
from
itertools
import
groupby
from
typing
import
TypeVar
from
typing
import
TypeAlias
,
TypeVar
import
torch
import
torch.nn
as
nn
...
...
@@ -18,8 +18,8 @@ from vllm.model_executor.models.adapters import _load_st_projector
from
vllm.pooling_params
import
PoolingParams
from
vllm.tasks
import
PoolingTask
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.v1.outputs
import
PoolerOutput
from
vllm.v1.pool.metadata
import
PoolingCursor
,
PoolingMetadata
from
vllm.v1.outputs
import
PoolerOutput
,
TokenPoolerOutput
,
TokensPoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
logger
=
init_logger
(
__name__
)
...
...
@@ -30,6 +30,15 @@ PoolingFn = Callable[
ClassifierFn
=
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]
TokenPoolingMethodOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
TokensPoolingMethodOutput
:
TypeAlias
=
list
[
torch
.
Tensor
]
|
list
[
torch
.
Tensor
|
None
]
TokensPoolingMethodOutputItem
:
TypeAlias
=
torch
.
Tensor
|
None
PoolingMethodOutput
:
TypeAlias
=
TokenPoolingMethodOutput
|
TokensPoolingMethodOutput
TokenPoolerHeadOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
TokensPoolerHeadOutput
:
TypeAlias
=
torch
.
Tensor
|
None
class
PoolingType
(
IntEnum
):
"""Enumeration for different types of pooling methods."""
...
...
@@ -123,31 +132,24 @@ class PoolingMethod(nn.Module, ABC):
return
PoolingParamsUpdate
()
@
abstractmethod
def
forward_all
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_cursor
:
PoolingCursor
,
)
->
PoolerOutput
:
raise
NotImplementedError
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
pooling_cursor
=
pooling_metadata
.
pooling_cursor
return
self
.
forward_all
(
hidden_states
,
pooling_cursor
)
)
->
PoolingMethodOutput
:
raise
NotImplementedError
class
CLSPool
(
PoolingMethod
):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
,
"embed"
,
"classify"
,
"score"
}
def
forward
_all
(
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_cursor
:
PoolingCursor
,
)
->
PoolerOutput
:
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolingMethodOutput
:
pooling_cursor
=
pooling_metadata
.
get_pooling_cursor
()
assert
not
pooling_cursor
.
is_partial_prefill
(),
(
"partial prefill not supported with CLS pooling"
)
...
...
@@ -159,11 +161,12 @@ class LastPool(PoolingMethod):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
,
"embed"
,
"classify"
,
"score"
}
def
forward
_all
(
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_cursor
:
PoolingCursor
,
)
->
PoolerOutput
:
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolingMethodOutput
:
pooling_cursor
=
pooling_metadata
.
get_pooling_cursor
()
return
hidden_states
[
pooling_cursor
.
last_token_indices_gpu
]
...
...
@@ -179,19 +182,12 @@ class AllPool(PoolingMethod):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
}
def
forward_all
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_cursor
:
PoolingCursor
)
->
PoolerOutput
:
raise
NotImplementedError
(
"forward_all is not implemented for AllPool. Use forward instead."
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
Pooler
Output
:
pooling_cursor
=
pooling_metadata
.
pooling_cursor
)
->
TokensPoolingMethod
Output
:
pooling_cursor
=
pooling_metadata
.
get_
pooling_cursor
()
is_finished
=
pooling_cursor
.
is_finished
()
hidden_states_lst
=
list
(
hidden_states
.
split
(
pooling_cursor
.
num_scheduled_tokens_cpu
.
tolist
())
...
...
@@ -209,7 +205,7 @@ class AllPool(PoolingMethod):
p
.
hidden_states_cache
.
append
(
hs_chunk
)
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
output_list
:
PoolerOutput
=
[]
output_list
=
list
[
torch
.
Tensor
|
None
]()
for
p
,
finished
in
zip
(
pooling_states
,
is_finished
):
if
finished
:
hidden_states_cache
=
p
.
hidden_states_cache
...
...
@@ -228,11 +224,12 @@ class MeanPool(PoolingMethod):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
,
"embed"
,
"classify"
,
"score"
}
def
forward
_all
(
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_cursor
:
PoolingCursor
,
)
->
PoolerOutput
:
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolingMethodOutput
:
pooling_cursor
=
pooling_metadata
.
get_pooling_cursor
()
assert
not
pooling_cursor
.
is_partial_prefill
(),
(
"partial prefill not supported with MEAN pooling"
)
...
...
@@ -410,7 +407,7 @@ class Pooler(nn.Module, ABC):
@
abstractmethod
def
forward
(
self
,
hidden_states
:
list
[
torch
.
Tensor
]
|
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
raise
NotImplementedError
...
...
@@ -422,41 +419,42 @@ class DummyPooler(Pooler):
def
forward
(
self
,
hidden_states
:
list
[
torch
.
Tensor
]
|
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
return
hidden_states
class
PoolerHead
(
nn
.
Module
):
def
__init__
(
self
,
activation
:
PoolerActivation
)
->
None
:
super
().
__init__
()
self
.
activation
=
activation
class
TokenPoolerHead
(
nn
.
Module
,
ABC
):
"""Applicable to pooling strategies that output one token."""
@
abstractmethod
def
forward
(
self
,
pooled_data
:
list
[
torch
.
Tensor
]
|
torch
.
Tensor
,
pooled_data
:
TokenPoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
r
eturn
self
.
activation
(
pooled_data
)
)
->
Token
Pooler
Head
Output
:
r
aise
NotImplementedError
class
EmbeddingPoolerHead
(
PoolerHead
):
class
EmbeddingPoolerHead
(
Token
PoolerHead
):
def
__init__
(
self
)
->
None
:
super
().
__init__
(
activation
=
PoolerNormalize
()
)
super
().
__init__
()
# Load ST projector if available
vllm_config
=
get_current_vllm_config
()
self
.
projector
:
nn
.
Module
|
None
=
(
self
.
projector
=
(
_load_st_projector
(
vllm_config
.
model_config
)
if
vllm_config
else
None
)
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
self
.
activation
=
PoolerNormalize
()
def
forward
(
self
,
pooled_data
:
list
[
torch
.
Tensor
]
|
torch
.
Tensor
,
pooled_data
:
TokenPoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
Token
Pooler
Head
Output
:
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_dimension]
...
...
@@ -509,7 +507,7 @@ class SimplePooler(Pooler):
3. Returns structured results as `PoolerOutput`.
"""
def
__init__
(
self
,
pooling
:
PoolingMethod
,
head
:
PoolerHead
)
->
None
:
def
__init__
(
self
,
pooling
:
PoolingMethod
,
head
:
Token
PoolerHead
)
->
None
:
super
().
__init__
()
self
.
pooling
=
pooling
...
...
@@ -523,9 +521,9 @@ class SimplePooler(Pooler):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
Token
Pooler
Head
Output
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
...
...
@@ -591,9 +589,9 @@ class ClassifierPooler(Pooler):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
Token
PoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
...
...
@@ -622,10 +620,36 @@ class ClassifierPooler(Pooler):
return
scores
class
TokenEmbeddingPoolerHead
(
EmbeddingPoolerHead
):
class
TokensPoolerHead
(
nn
.
Module
,
ABC
):
"""Applicable to pooling strategies that output multiple tokens."""
@
abstractmethod
def
forward
(
self
,
pooled_data
:
torch
.
Tensor
|
None
,
pooling_param
:
PoolingParams
)
->
PoolerOutput
:
self
,
pooled_data
:
TokensPoolingMethodOutputItem
,
pooling_param
:
PoolingParams
,
)
->
TokensPoolerHeadOutput
:
raise
NotImplementedError
class
TokenEmbeddingPoolerHead
(
TokensPoolerHead
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
# Load ST projector if available
vllm_config
=
get_current_vllm_config
()
self
.
projector
=
(
_load_st_projector
(
vllm_config
.
model_config
)
if
vllm_config
else
None
)
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
self
.
activation
=
PoolerNormalize
()
def
forward
(
self
,
pooled_data
:
TokensPoolingMethodOutputItem
,
pooling_param
:
PoolingParams
,
)
->
TokensPoolerHeadOutput
:
# for unfinished chunked prefill
if
pooled_data
is
None
:
return
None
...
...
@@ -649,57 +673,56 @@ class TokenEmbeddingPoolerHead(EmbeddingPoolerHead):
return
pooled_data
class
TokenClassifierPoolerHead
(
nn
.
Module
):
class
TokenClassifierPoolerHead
(
TokensPoolerHead
):
def
__init__
(
self
,
classifier
:
ClassifierFn
|
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
)
->
None
:
super
().
__init__
()
vllm_config
=
get_current_vllm_config
()
self
.
classifier
=
classifier
self
.
act_fn
=
ClassifierPooler
.
resolve_act_fn
(
vllm_config
.
model_config
,
static_num_labels
=
False
,
act_fn
=
act_fn
)
self
.
logit_bias
:
float
|
None
=
(
vllm_config
.
model_config
.
pooler_config
.
logit_bias
)
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_classify"
}
self
.
activation
=
ClassifierPooler
.
resolve_act_fn
(
vllm_config
.
model_config
,
static_num_labels
=
False
,
act_fn
=
act_fn
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
|
None
,
pooled_data
:
TokensPoolingMethodOutputItem
,
pooling_param
:
PoolingParams
,
)
->
PoolerOutput
:
)
->
Tokens
Pooler
Head
Output
:
# for unfinished chunked prefill
if
hidden_states
is
None
:
if
pooled_data
is
None
:
return
None
hidden_states
=
hidden_states
.
to
(
self
.
head_dtype
)
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# hidden_states shape: [n_token, hidden_size]
if
self
.
classifier
is
not
None
:
scores
=
self
.
classifier
(
hidden_states
)
scores
=
self
.
classifier
(
pooled_data
)
else
:
scores
=
hidden_states
scores
=
pooled_data
# scores shape: [n_token, num_labels]
if
self
.
logit_bias
is
not
None
:
scores
-=
self
.
logit_bias
if
pooling_param
.
use_activation
:
scores
=
self
.
act
_f
n
(
scores
)
scores
=
self
.
act
ivatio
n
(
scores
)
# scores shape: [n_token, num_labels]
return
scores
class
AllPooler
(
Pooler
):
def
__init__
(
self
,
head
:
nn
.
Module
|
PoolerHead
)
->
None
:
def
__init__
(
self
,
head
:
Tokens
PoolerHead
)
->
None
:
super
().
__init__
()
self
.
pooling
=
AllPool
()
...
...
@@ -712,17 +735,16 @@ class AllPooler(Pooler):
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
Tokens
PoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooling_params
=
pooling_metadata
.
pooling_params
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
pooled_data
=
[
self
.
head
(
d
,
p
)
for
d
,
p
in
zip
(
pooled_data
,
pooling_params
)]
return
pooled_data
return
[
self
.
head
(
d
,
p
)
for
d
,
p
in
zip
(
pooled_data
,
pooling_params
)]
class
StepPooler
(
Pooler
):
def
__init__
(
self
,
head
:
nn
.
Module
|
PoolerHead
)
->
None
:
def
__init__
(
self
,
head
:
Tokens
PoolerHead
)
->
None
:
super
().
__init__
()
self
.
pooling
=
AllPool
()
...
...
@@ -730,14 +752,14 @@ class StepPooler(Pooler):
def
extract_states
(
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
list
[
torch
.
Tensor
|
None
]
:
pooled_data_lst
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
prompt_token_ids
=
pooling_metadata
.
get_prompt_token_ids
()
pooling_params
=
pooling_metadata
.
pooling_params
pooled_data
:
PoolerOutput
=
[]
pooled_data
=
list
[
torch
.
Tensor
|
None
]()
for
data
,
token_id
,
pooling_param
in
zip
(
pooled_data_lst
,
prompt_token_ids
,
pooling_params
):
...
...
@@ -766,15 +788,14 @@ class StepPooler(Pooler):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
Tokens
PoolerOutput
:
pooled_data
=
self
.
extract_states
(
hidden_states
,
pooling_metadata
)
pooling_params
=
pooling_metadata
.
pooling_params
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
pooled_data
=
[
self
.
head
(
d
,
p
)
for
d
,
p
in
zip
(
pooled_data
,
pooling_params
)]
return
pooled_data
return
[
self
.
head
(
d
,
p
)
for
d
,
p
in
zip
(
pooled_data
,
pooling_params
)]
class
DispatchPooler
(
Pooler
):
...
...
@@ -800,12 +821,12 @@ class DispatchPooler(Pooler):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
poolers_by_task
=
self
.
poolers_by_task
outputs
=
list
[
torch
.
Tensor
]()
outputs
=
list
[
torch
.
Tensor
|
None
]()
offset
=
0
for
task
,
group
in
groupby
(
pooling_metadata
.
tasks
):
if
not
(
pooler
:
=
poolers_by_task
.
get
(
task
)):
...
...
vllm/model_executor/models/bert.py
View file @
b7036c87
...
...
@@ -24,11 +24,14 @@ from vllm.model_executor.layers.pooler import (
PoolingMethod
,
PoolingParamsUpdate
,
PoolingType
,
TokenPoolerHeadOutput
,
TokenPoolingMethodOutput
,
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
PoolingTask
from
vllm.v1.outputs
import
TokenPoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
...
...
@@ -97,24 +100,26 @@ class BertPooler(Pooler):
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
_head
(
self
,
pooled_output
:
torch
.
Tensor
):
pooled_output
=
self
.
dense
(
pooled_output
)
pooled_output
=
self
.
activation
(
pooled_output
)
return
pooled_output
def
forward
(
def
head
(
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
pooled_data
:
TokenPoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
torch
.
Tensor
|
list
[
torch
.
Tensor
]:
pooled_output
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
)
->
TokenPoolerHeadOutput
:
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
if
isinstance
(
pooled_output
,
list
):
pooled_output
=
[
self
.
_head
(
output
)
for
output
in
pooled_output
]
else
:
pooled_output
=
self
.
_head
(
pooled_output
)
pooled_data
=
self
.
dense
(
pooled_data
)
pooled_data
=
self
.
activation
(
pooled_data
)
return
pooled_data
return
pooled_output
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
class
BertEncoder
(
nn
.
Module
):
...
...
vllm/model_executor/models/gritlm.py
View file @
b7036c87
...
...
@@ -4,21 +4,22 @@ from collections.abc import Set
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.pooler
import
(
DispatchPooler
,
Pooler
,
PoolerHead
,
PoolerNormalize
,
PoolingMethod
,
PoolingParamsUpdate
,
TokenPoolerHeadOutput
,
TokenPoolingMethodOutput
,
)
from
vllm.model_executor.models.llama
import
LlamaForCausalLM
from
vllm.tasks
import
PoolingTask
from
vllm.tokenizers
import
cached_tokenizer_from_config
from
vllm.v1.outputs
import
PoolerOutput
from
vllm.v1.outputs
import
Token
PoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.interfaces_base
import
default_pooling_type
...
...
@@ -26,7 +27,7 @@ from .interfaces_base import default_pooling_type
logger
=
init_logger
(
__name__
)
class
GritLMMeanPool
(
nn
.
Module
):
class
GritLMMeanPool
(
PoolingMethod
):
"""As `MeanPool`, but only includes non-instruction tokens."""
def
__init__
(
self
,
model_config
:
ModelConfig
):
...
...
@@ -141,16 +142,16 @@ class GritLMMeanPool(nn.Module):
return
instruction_len
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"encode"
,
"embed"
}
return
{
"embed"
}
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
PoolingParamsUpdate
(
requires_token_ids
=
True
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
list
[
torch
.
Tensor
]
|
torch
.
Tensor
:
)
->
TokenPoolingMethodOutput
:
prompt_lens
=
pooling_metadata
.
prompt_lens
instr_lens
=
torch
.
tensor
(
[
...
...
@@ -178,7 +179,7 @@ class GritLMPooler(Pooler):
super
().
__init__
()
self
.
pooling
=
GritLMMeanPool
(
model_config
)
self
.
head
=
PoolerHead
(
PoolerNormalize
()
)
self
.
activation
=
PoolerNormalize
()
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
self
.
pooling
.
get_supported_tasks
()
...
...
@@ -186,11 +187,18 @@ class GritLMPooler(Pooler):
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
head
(
self
,
pooled_data
:
TokenPoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerHeadOutput
:
return
self
.
activation
(
pooled_data
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
)
->
Token
PoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
...
...
vllm/model_executor/models/modernbert.py
View file @
b7036c87
...
...
@@ -19,12 +19,15 @@ from vllm.model_executor.layers.pooler import (
PoolingMethod
,
PoolingParamsUpdate
,
PoolingType
,
TokenPoolerHeadOutput
,
TokenPoolingMethodOutput
,
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
PoolingTask
from
vllm.v1.outputs
import
TokenPoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.interfaces
import
SupportsCrossEncoding
...
...
@@ -300,23 +303,25 @@ class ModernBertPooler(Pooler):
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
_head
(
self
,
pooled_output
:
torch
.
Tensor
):
pooled_output
=
pooled_output
.
to
(
self
.
dense
.
weight
.
dtype
)
return
self
.
norm
(
self
.
act
(
self
.
dense
(
pooled_output
)))
def
forward
(
def
head
(
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
]
,
pooled_data
:
TokenPoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
torch
.
Tensor
|
list
[
torch
.
Tensor
]:
pooled_output
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
)
->
TokenPoolerHeadOutput
:
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
if
isinstance
(
pooled_output
,
list
):
pooled_output
=
[
self
.
_head
(
output
)
for
output
in
pooled_output
]
else
:
pooled_output
=
self
.
_head
(
pooled_output
)
pooled_data
=
pooled_data
.
to
(
self
.
dense
.
weight
.
dtype
)
return
self
.
norm
(
self
.
act
(
self
.
dense
(
pooled_data
)))
return
pooled_output
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
@
default_pooling_type
(
"CLS"
)
...
...
vllm/v1/outputs.py
View file @
b7036c87
...
...
@@ -3,7 +3,7 @@
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
NamedTuple
from
typing
import
TYPE_CHECKING
,
NamedTuple
,
TypeAlias
import
numpy
as
np
import
torch
...
...
@@ -91,7 +91,9 @@ class LogprobsTensors(NamedTuple):
# [num_reqs, <dynamic>]
# The shape of each element depends on the pooler used
PoolerOutput
=
list
[
torch
.
Tensor
|
None
]
|
torch
.
Tensor
|
None
TokenPoolerOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
TokensPoolerOutput
:
TypeAlias
=
list
[
torch
.
Tensor
]
|
list
[
torch
.
Tensor
|
None
]
PoolerOutput
:
TypeAlias
=
TokenPoolerOutput
|
TokensPoolerOutput
@
dataclass
...
...
vllm/v1/pool/metadata.py
View file @
b7036c87
...
...
@@ -90,6 +90,12 @@ class PoolingMetadata:
return
[
prompt_token_ids
[
i
,
:
num
]
for
i
,
num
in
enumerate
(
self
.
prompt_lens
)]
def
get_pooling_cursor
(
self
)
->
PoolingCursor
:
pooling_cursor
=
self
.
pooling_cursor
assert
pooling_cursor
is
not
None
,
"Should call `build_pooling_cursor` first"
return
pooling_cursor
def
build_pooling_cursor
(
self
,
num_scheduled_tokens_np
:
np
.
ndarray
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
b7036c87
...
...
@@ -4680,7 +4680,7 @@ class GPUModelRunner(
for
task
in
supported_pooling_tasks
:
# Run a full batch with each task to ensure none of them OOMs
output
=
self
.
_dummy_pooler_run_task
(
hidden_states
,
task
)
output_size
[
task
]
=
sum
(
o
.
nbytes
for
o
in
output
)
output_size
[
task
]
=
sum
(
o
.
nbytes
for
o
in
output
if
o
is
not
None
)
del
output
# Allow GC
max_task
=
max
(
output_size
.
items
(),
key
=
lambda
x
:
x
[
1
])[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment