Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c8ed39b9
Unverified
Commit
c8ed39b9
authored
Jan 09, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 09, 2026
Browse files
[Model] Reorganize pooling layers (#31973)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
02073280
Changes
34
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1215 additions
and
934 deletions
+1215
-934
.github/CODEOWNERS
.github/CODEOWNERS
+1
-1
tests/model_executor/test_model_load_with_params.py
tests/model_executor/test_model_load_with_params.py
+2
-1
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
...dd_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
+2
-7
vllm/config/pooler.py
vllm/config/pooler.py
+4
-0
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+0
-845
vllm/model_executor/layers/pooler/__init__.py
vllm/model_executor/layers/pooler/__init__.py
+5
-0
vllm/model_executor/layers/pooler/abstract.py
vllm/model_executor/layers/pooler/abstract.py
+39
-0
vllm/model_executor/layers/pooler/activations.py
vllm/model_executor/layers/pooler/activations.py
+162
-0
vllm/model_executor/layers/pooler/common.py
vllm/model_executor/layers/pooler/common.py
+27
-0
vllm/model_executor/layers/pooler/seqwise/__init__.py
vllm/model_executor/layers/pooler/seqwise/__init__.py
+45
-0
vllm/model_executor/layers/pooler/seqwise/heads.py
vllm/model_executor/layers/pooler/seqwise/heads.py
+157
-0
vllm/model_executor/layers/pooler/seqwise/methods.py
vllm/model_executor/layers/pooler/seqwise/methods.py
+93
-0
vllm/model_executor/layers/pooler/seqwise/poolers.py
vllm/model_executor/layers/pooler/seqwise/poolers.py
+106
-0
vllm/model_executor/layers/pooler/special.py
vllm/model_executor/layers/pooler/special.py
+128
-0
vllm/model_executor/layers/pooler/tokwise/__init__.py
vllm/model_executor/layers/pooler/tokwise/__init__.py
+39
-0
vllm/model_executor/layers/pooler/tokwise/heads.py
vllm/model_executor/layers/pooler/tokwise/heads.py
+142
-0
vllm/model_executor/layers/pooler/tokwise/methods.py
vllm/model_executor/layers/pooler/tokwise/methods.py
+124
-0
vllm/model_executor/layers/pooler/tokwise/poolers.py
vllm/model_executor/layers/pooler/tokwise/poolers.py
+106
-0
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+5
-23
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+28
-57
No files found.
.github/CODEOWNERS
View file @
c8ed39b9
...
...
@@ -153,7 +153,7 @@ mkdocs.yaml @hmellor
/vllm/entrypoints/pooling @noooop
/vllm/config/pooler.py @noooop
/vllm/pooling_params.py @noooop
/vllm/model_executor/layers/pooler
.py
@noooop
/vllm/model_executor/layers/pooler @noooop
# Security guide and policies
/docs/usage/security.md @russellb
...
...
tests/model_executor/test_model_load_with_params.py
View file @
c8ed39b9
...
...
@@ -5,7 +5,8 @@ import os
import
pytest
from
vllm.model_executor.layers.pooler
import
CLSPool
,
DispatchPooler
,
MeanPool
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.pooler.seqwise
import
CLSPool
,
MeanPool
from
vllm.model_executor.models.bert
import
BertEmbeddingModel
from
vllm.model_executor.models.roberta
import
RobertaEmbeddingModel
from
vllm.platforms
import
current_platform
...
...
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
View file @
c8ed39b9
...
...
@@ -7,7 +7,7 @@ import torch
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.models.gemma2
import
Gemma2Model
from
vllm.model_executor.models.utils
import
WeightsMapper
,
maybe_prefix
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -28,12 +28,7 @@ class MyGemma2Embedding(nn.Module):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
{
"token_embed"
:
Pooler
.
for_token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
}
)
self
.
pooler
=
DispatchPooler
.
for_embedding
(
pooler_config
)
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
...
...
vllm/config/pooler.py
View file @
c8ed39b9
...
...
@@ -88,6 +88,10 @@ class PoolerConfig:
# raise deprecated warning for softmax and activation
self
.
use_activation
=
get_use_activation
(
self
)
def get_pooling_type(self) -> PoolingTypeStr:
    """Return ``self.pooling_type``, asserting that it is not ``None``."""
    resolved = self.pooling_type
    assert resolved is not None, "Should be resolved by ModelConfig"
    return resolved
def
compute_hash
(
self
)
->
str
:
"""
WARNING: Whenever a new field is added to this config,
...
...
vllm/model_executor/layers/pooler.py
deleted
100644 → 0
View file @
02073280
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Callable
,
Mapping
,
Set
from
dataclasses
import
dataclass
from
itertools
import
groupby
from
typing
import
TypeAlias
,
TypeVar
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
vllm.config
import
ModelConfig
,
get_current_vllm_config
from
vllm.config.pooler
import
PoolerConfig
,
PoolingTypeStr
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.pooling_params
import
PoolingParams
from
vllm.tasks
import
PoolingTask
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.v1.outputs
import
PoolerOutput
,
TokenPoolerOutput
,
TokenwisePoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
logger = init_logger(__name__)

# Signature of a pooling callable: (hidden states, pooling metadata) -> pooled data.
PoolingFn = Callable[
    [torch.Tensor | list[torch.Tensor], PoolingMetadata],
    torch.Tensor | list[torch.Tensor],
]
# Signature of a classifier callable: tensor in, tensor out.
ClassifierFn = Callable[[torch.Tensor], torch.Tensor]

# Outputs of pooling methods.
TokenPoolingMethodOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokenwisePoolingMethodOutput: TypeAlias = (
    list[torch.Tensor] | list[torch.Tensor | None]
)
TokenwisePoolingMethodOutputItem: TypeAlias = torch.Tensor | None
PoolingMethodOutput: TypeAlias = (
    TokenPoolingMethodOutput | TokenwisePoolingMethodOutput
)

# Outputs of pooler heads.
TokenPoolerHeadOutput: TypeAlias = torch.Tensor | list[torch.Tensor]
TokenwisePoolerHeadOutput: TypeAlias = torch.Tensor | None
@dataclass(frozen=True)
class ResolvedPoolingConfig:
    """Immutable pairing of a pooling task with the pooling type used for it."""

    pooling_type: PoolingTypeStr
    task: PoolingTask

    @classmethod
    def from_config(
        cls,
        task: PoolingTask,
        pooler_config: PoolerConfig,
    ) -> "ResolvedPoolingConfig":
        """Build an instance from ``pooler_config``.

        Asserts that ``pooler_config.pooling_type`` is not ``None``.
        """
        pooling_type = pooler_config.pooling_type
        assert pooling_type is not None
        return cls(pooling_type=pooling_type, task=task)
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Immutable set of flags that a pooler applies to pooling parameters."""

    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def apply(self, params: PoolingParams) -> None:
        """Copy ``requires_token_ids`` onto ``params`` (mutates ``params``)."""
        flag = self.requires_token_ids
        params.requires_token_ids = flag
def get_classification_activation_function(config: PretrainedConfig):
    """Choose the classification activation from ``config.problem_type``.

    ``"regression"`` maps to `PoolerIdentity`,
    ``"multi_label_classification"`` to `PoolerMultiLabelClassify`, and
    every other value (including a missing attribute) to `PoolerClassify`.
    """
    # Implement alignment with transformers ForSequenceClassificationLoss
    # https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
    problem_type = getattr(config, "problem_type", "")

    if problem_type == "regression":
        return PoolerIdentity()
    if problem_type == "multi_label_classification":
        return PoolerMultiLabelClassify()

    # "single_label_classification" shares the default activation.
    return PoolerClassify()
def get_cross_encoder_activation_function(config: PretrainedConfig):
    """Resolve the activation function configured for a cross-encoder model.

    The qualified name of the activation is read from
    ``config.sentence_transformers["activation_fn"]`` if present, otherwise
    from ``config.sbert_ce_default_activation_function`` if that is set.
    When neither is configured, `PoolerClassify` is returned.

    Args:
        config: The HF config of the model.

    Returns:
        The resolved activation wrapped via `PoolerActivation.wraps`, or a
        default `PoolerClassify`.

    Raises:
        AssertionError: If the configured name does not start with
            ``"torch.nn.modules."``.
    """
    function_name: str | None = None
    if (
        hasattr(config, "sentence_transformers")
        and "activation_fn" in config.sentence_transformers
    ):
        function_name = config.sentence_transformers["activation_fn"]
    elif (
        hasattr(config, "sbert_ce_default_activation_function")
        and config.sbert_ce_default_activation_function is not None
    ):
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerClassify()

    # The name comes from the model's config and is resolved to an arbitrary
    # object, so this allow-list check must survive `python -O`. A plain
    # `assert` statement would be stripped there; raise explicitly instead
    # (AssertionError is kept so the exception type seen by callers is
    # unchanged).
    if not function_name.startswith("torch.nn.modules."):
        raise AssertionError(
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons"
        )

    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
class
PoolingMethod
(
nn
.
Module
,
ABC
):
@
staticmethod
def
from_pooling_type
(
pooling_type
:
PoolingTypeStr
)
->
"PoolingMethod"
:
if
pooling_type
==
"LAST"
:
return
LastPool
()
if
pooling_type
==
"ALL"
:
return
AllPool
()
if
pooling_type
==
"CLS"
:
return
CLSPool
()
if
pooling_type
==
"MEAN"
:
return
MeanPool
()
if
pooling_type
==
"STEP"
:
raise
ValueError
(
"'STEP' pooling is handled by StepPooler "
"and is not a standalone PoolingMethod."
)
raise
NotImplementedError
(
f
"Unsupported method:
{
pooling_type
!
r
}
"
)
@
abstractmethod
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
raise
NotImplementedError
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
PoolingParamsUpdate
()
@
abstractmethod
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolingMethodOutput
:
raise
NotImplementedError
class CLSPool(PoolingMethod):
    """Pooling method that gathers rows at the cursor's first-token indices."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks supported by CLS pooling."""
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        """Index ``hidden_states`` with ``first_token_indices_gpu``.

        Asserts that the pooling cursor is not in a partial prefill.
        """
        cursor = pooling_metadata.get_pooling_cursor()
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )

        first_indices = cursor.first_token_indices_gpu
        return hidden_states[first_indices]
class LastPool(PoolingMethod):
    """Pooling method that gathers rows at the cursor's last-token indices."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks supported by LAST pooling."""
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        """Index ``hidden_states`` with ``last_token_indices_gpu``."""
        cursor = pooling_metadata.get_pooling_cursor()
        last_indices = cursor.last_token_indices_gpu
        return hidden_states[last_indices]
class AllPool(PoolingMethod):
    """Pooling method that returns every scheduled token's hidden state.

    When chunked prefill is enabled, chunks are accumulated in each request's
    pooling state and only emitted once the cursor reports the request as
    finished.
    """

    def __init__(self):
        super().__init__()

        scheduler_config = get_current_vllm_config().scheduler_config
        self.enable_chunked_prefill = scheduler_config.enable_chunked_prefill

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks supported by ALL pooling."""
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenwisePoolingMethodOutput:
        """Split ``hidden_states`` per request and return one entry per request.

        Without chunked prefill the split chunks are returned directly. With
        chunked prefill an entry is ``None`` until its request is finished.
        """
        cursor = pooling_metadata.get_pooling_cursor()

        split_sizes = cursor.num_scheduled_tokens_cpu.tolist()
        per_request = hidden_states.split(split_sizes)
        chunks = [per_request[i] for i in cursor.index]

        if not self.enable_chunked_prefill:
            return chunks

        states = pooling_metadata.pooling_states

        # If chunked_prefill is enabled
        # 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
        for state, chunk in zip(states, chunks):
            state.hidden_states_cache.append(chunk)

        # 2. Once prefill is finished, send hidden_states_cache to PoolerHead
        outputs: list[torch.Tensor | None] = []
        for state, finished in zip(states, cursor.is_finished()):
            if not finished:
                outputs.append(None)
                continue

            cache = state.hidden_states_cache
            # A single cached chunk is passed through without concatenation.
            merged = cache[0] if len(cache) == 1 else torch.concat(cache, dim=0)
            outputs.append(merged)
            state.clean()

        return outputs
class MeanPool(PoolingMethod):
    """Pooling method that averages hidden states between first and last indices."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks supported by MEAN pooling."""
        return {"token_embed", "token_classify", "embed", "classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolingMethodOutput:
        """Compute per-request means using a cumulative sum over dim 0.

        Asserts that the pooling cursor is not in a partial prefill.
        """
        cursor = pooling_metadata.get_pooling_cursor()
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        prompt_lens = cursor.prompt_lens_cpu.to(
            hidden_states.device, non_blocking=True
        )

        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        running_total = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        first = cursor.first_token_indices_gpu
        last = cursor.last_token_indices_gpu

        # Inclusive range sum: total up to `last`, minus total up to `first`,
        # plus the row at `first` that the subtraction removed.
        segment_sums = running_total[last] - running_total[first] + hidden_states[first]
        return segment_sums / prompt_lens.unsqueeze(1)
_T = TypeVar("_T", torch.Tensor, list[torch.Tensor])


class BasePoolerActivation(nn.Module, ABC):
    """Abstract activation module applied to pooled data.

    ``forward`` is annotated to return the same type (tensor or list of
    tensors) that it receives.
    """

    @abstractmethod
    def forward(self, pooled_data: _T) -> _T:
        """Apply the activation to ``pooled_data``."""
        # shape:
        # classify (& score) -> (batch_size, num_classes)
        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
        #          (batch_size, dimensions) or list(dimensions) if using MRL
        raise NotImplementedError
class PoolerActivation(BasePoolerActivation):
    """Activation that is defined per tensor and mapped over lists of tensors."""

    @staticmethod
    def wraps(module: nn.Module):
        """Convert an ``nn.Module`` activation into a pooler activation.

        ``nn.Identity`` becomes `PoolerIdentity`; ``nn.Sigmoid`` and
        ``nn.Softmax`` become `PoolerClassify`; anything else is wrapped in
        `LambdaPoolerActivation`.
        """
        if isinstance(module, nn.Identity):
            return PoolerIdentity()

        classify_like = (nn.Sigmoid, nn.Softmax)
        if isinstance(module, classify_like):
            return PoolerClassify()

        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        """Apply `forward_chunk` to a tensor, or to each tensor of a list."""
        if not isinstance(pooled_data, list):
            return self.forward_chunk(pooled_data)

        return [self.forward_chunk(chunk) for chunk in pooled_data]
class PoolerIdentity(PoolerActivation):
    """Activation that returns its input unchanged."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return ``pooled_data`` as-is (the same object, no copy)."""
        return pooled_data
class PoolerNormalize(PoolerActivation):
    """Activation that L2-normalizes along the last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return ``pooled_data`` normalized with ``p=2`` over ``dim=-1``."""
        normalized = F.normalize(pooled_data, p=2, dim=-1)
        return normalized
class PoolerMultiLabelClassify(PoolerActivation):
    """Activation that applies an element-wise sigmoid."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return the sigmoid of ``pooled_data``."""
        probabilities = F.sigmoid(pooled_data)
        return probabilities
class PoolerClassify(PoolerActivation):
    """Classification activation: sigmoid for fewer than two labels, else softmax.

    Args:
        static_num_labels: When true, the label count is read once from the
            current vLLM config's ``hf_config.num_labels`` (defaulting to 0
            if the attribute is missing). When false, the label count is
            taken from the last dimension of each input at call time.
    """

    def __init__(self, *, static_num_labels: bool = True) -> None:
        super().__init__()

        if static_num_labels:
            vllm_config = get_current_vllm_config()
            self.num_labels = getattr(
                vllm_config.model_config.hf_config, "num_labels", 0
            )
            if self.num_labels == 0:
                # The first literal previously lacked a trailing space, so
                # the message read "classificationmodels".
                # NOTE(review): the text says "softmax", but `forward_chunk`
                # takes the sigmoid branch when num_labels == 0 -- confirm
                # which one is intended.
                logger.warning(
                    "num_labels should be > 0 for classification "
                    "models, falling back to softmax. "
                    "Please check if the configuration is correct."
                )
        else:
            self.num_labels = None

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid or softmax (over ``dim=-1``) to ``pooled_data``."""
        num_labels = self.num_labels
        if num_labels is None:
            # Dynamic mode: infer the label count from the input itself.
            num_labels = pooled_data.shape[-1]

        if num_labels < 2:
            return F.sigmoid(pooled_data)

        return F.softmax(pooled_data, dim=-1)
class LambdaPoolerActivation(PoolerActivation):
    """Activation that delegates to an arbitrary tensor-to-tensor callable."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        super().__init__()

        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return ``self.fn(pooled_data)``."""
        wrapped = self.fn
        return wrapped(pooled_data)
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @staticmethod
    def for_token_embed(pooler_config: PoolerConfig):
        """Create a token-embedding pooler (`StepPooler` for "STEP", else `AllPooler`)."""
        head = TokenEmbeddingPoolerHead()

        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_token_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Create a token-classification pooler (`StepPooler` for "STEP", else `AllPooler`)."""
        head = TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn)

        pooler_cls = StepPooler if pooler_config.pooling_type == "STEP" else AllPooler
        return pooler_cls(head=head)

    @staticmethod
    def for_embed(pooler_config: PoolerConfig):
        """Create a `SimplePooler` with an `EmbeddingPoolerHead` for the "embed" task."""
        resolved = ResolvedPoolingConfig.from_config(
            task="embed",
            pooler_config=pooler_config,
        )
        pooling = PoolingMethod.from_pooling_type(resolved.pooling_type)

        return SimplePooler(pooling=pooling, head=EmbeddingPoolerHead())

    @staticmethod
    def for_classify(
        pooler_config: PoolerConfig,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Create a `ClassifierPooler` for the "classify" task."""
        resolved = ResolvedPoolingConfig.from_config(
            task="classify",
            pooler_config=pooler_config,
        )
        pooling = PoolingMethod.from_pooling_type(resolved.pooling_type)

        return ClassifierPooler(
            pooling=pooling,
            classifier=classifier,
            act_fn=act_fn,
        )

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.
        """
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` using ``pooling_metadata``."""
        raise NotImplementedError
class DummyPooler(Pooler):
    """Pooler that passes hidden states through untouched."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler accepts."""
        return {"plugin", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Return ``hidden_states`` as-is; ``pooling_metadata`` is unused."""
        return hidden_states
class TokenPoolerHead(nn.Module, ABC):
    """Applicable to pooling strategies that output one token."""

    @abstractmethod
    def forward(
        self,
        pooled_data: TokenPoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        """Post-process ``pooled_data`` for the whole batch."""
        raise NotImplementedError
class EmbeddingPoolerHead(TokenPoolerHead):
    """Head for sequence embeddings: cast, project, truncate, then normalize."""

    def __init__(self) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()

        # Load ST projector if available
        self.projector = (
            _load_st_projector(vllm_config.model_config) if vllm_config else None
        )
        self.head_dtype = vllm_config.model_config.head_dtype

        self.activation = PoolerNormalize()

    def _truncate_dimensions(self, pooled_data, pooling_params):
        """Slice the last dimension to each request's ``dimensions``, if any are set."""
        dimensions_list = [param.dimensions for param in pooling_params]
        if all(d is None for d in dimensions_list):
            return pooled_data

        # change the output dimension
        assert len(pooled_data) == len(dimensions_list)

        all_same = len(set(dimensions_list)) == 1
        if all_same and not isinstance(pooled_data, list):
            # if all dimensions are the same
            d = dimensions_list[0]
            return pooled_data[..., :d]

        return [
            vecs if d is None else vecs[..., :d]
            for vecs, d in zip(pooled_data, dimensions_list)
        ]

    def _normalize(self, pooled_data, pooling_params):
        """Apply ``self.activation`` to the entries whose ``normalize`` flag is set."""
        flags = [param.normalize for param in pooling_params]

        if len(set(flags)) != 1:
            # Mixed (or empty) flags: decide entry by entry.
            return [
                self.activation(vecs) if f else vecs
                for vecs, f in zip(pooled_data, flags)
            ]

        if flags[0]:
            return self.activation(pooled_data)
        return pooled_data

    def forward(
        self,
        pooled_data: TokenPoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        """Turn pooled hidden states into embeddings.

        Returns a tensor when all requests share the same settings, otherwise
        a list with one entry per request.
        """
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_dimension]

        pooled_data = pooled_data.to(self.head_dtype)

        # Apply ST projector
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [batchsize, embedding_dimension]

        pooling_params = pooling_metadata.pooling_params

        # for matryoshka representation
        pooled_data = self._truncate_dimensions(pooled_data, pooling_params)

        # for normalize
        pooled_data = self._normalize(pooled_data, pooling_params)

        # pooled_data shape: [batchsize, embedding_dimension]
        return pooled_data
class SimplePooler(Pooler):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Normalizes output if specified.
    3. Returns structured results as `PoolerOutput`.
    """

    def __init__(self, pooling: PoolingMethod, head: TokenPoolerHead) -> None:
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks supported by the wrapped pooling method."""
        pooling = self.pooling
        return pooling.get_supported_tasks()

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the parameter update requested by the wrapped pooling method."""
        pooling = self.pooling
        return pooling.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerHeadOutput:
        """Run the pooling method, then the head, on ``hidden_states``."""
        extracted = self.pooling(hidden_states, pooling_metadata)
        return self.head(extracted, pooling_metadata)
class ClassifierPooler(Pooler):
    """A pooling layer for classification tasks.

    This layer does the following:
    1. Applies a classification layer to the hidden states.
    2. Optionally applies a pooler layer.
    3. Applies an activation function to the output.
    """

    @staticmethod
    def act_fn_for_seq_cls(model_config: ModelConfig):
        """Return the sequence-classification activation for ``model_config``."""
        hf_config = model_config.hf_config
        return get_classification_activation_function(hf_config)

    @staticmethod
    def act_fn_for_cross_encoder(model_config: ModelConfig):
        """Return the cross-encoder activation for ``model_config``."""
        hf_config = model_config.hf_config
        return get_cross_encoder_activation_function(hf_config)

    @staticmethod
    def resolve_act_fn(
        model_config: ModelConfig,
        static_num_labels: bool = True,
        act_fn: PoolerActivation | str | None = None,
    ):
        """Turn ``act_fn`` into a callable activation.

        ``None`` yields a `PoolerClassify`; the strings ``"classify"`` and
        ``"score"`` select the config-derived activations; any other string
        raises ``ValueError``; anything else must be callable and is returned
        unchanged.
        """
        if act_fn is None:
            return PoolerClassify(static_num_labels=static_num_labels)

        if isinstance(act_fn, str):
            if act_fn == "classify":
                return ClassifierPooler.act_fn_for_seq_cls(model_config)
            if act_fn == "score":
                return ClassifierPooler.act_fn_for_cross_encoder(model_config)
            raise ValueError(f"act_fn [{act_fn=}] not supported.")

        assert callable(act_fn)
        return act_fn

    def __init__(
        self,
        pooling: PoolingFn,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config

        self.pooling = pooling
        self.classifier = classifier
        self.act_fn = self.resolve_act_fn(
            model_config, static_num_labels=True, act_fn=act_fn
        )
        self.logit_bias: float | None = model_config.pooler_config.logit_bias
        self.head_dtype = model_config.head_dtype

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler accepts."""
        return {"classify", "score"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerOutput:
        """Pool, classify, subtract the logit bias and optionally activate.

        Returns a tensor when every request shares the same ``use_activation``
        flag, otherwise a list with one entry per request.
        """
        pooled_data = self.pooling(hidden_states, pooling_metadata)
        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            pooled_data -= self.logit_bias

        flags = [param.use_activation for param in pooling_metadata.pooling_params]

        if len(set(flags)) != 1:
            # Mixed flags: decide per request.
            return [
                self.act_fn(vecs) if f else vecs for vecs, f in zip(pooled_data, flags)
            ]

        # scores shape: [batchsize, num_labels]
        if flags[0]:
            return self.act_fn(pooled_data)
        return pooled_data
class TokenwisePoolerHead(nn.Module, ABC):
    """Applicable to pooling strategies that output multiple tokens."""

    @abstractmethod
    def forward(
        self,
        pooled_data: TokenwisePoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenwisePoolerHeadOutput:
        """Post-process the pooled data of a single request."""
        raise NotImplementedError
class TokenEmbeddingPoolerHead(TokenwisePoolerHead):
    """Per-request head for token embeddings: cast, project, truncate, normalize."""

    def __init__(self) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()

        # Load ST projector if available
        self.projector = (
            _load_st_projector(vllm_config.model_config) if vllm_config else None
        )
        self.head_dtype = vllm_config.model_config.head_dtype

        self.activation = PoolerNormalize()

    def forward(
        self,
        pooled_data: TokenwisePoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenwisePoolerHeadOutput:
        """Return the token embeddings for one request, or ``None`` if not ready."""
        # for unfinished chunked prefill
        if pooled_data is None:
            return None

        embeddings = pooled_data.to(self.head_dtype)
        # embeddings shape: [n_tokens, hidden_dimension]

        # Apply ST projector
        projector = self.projector
        if projector is not None:
            embeddings = projector(embeddings)
        # embeddings shape: [n_tokens, embedding_dimension]

        # for matryoshka representation
        # (a `dimensions` of None makes this a full slice)
        embeddings = embeddings[..., : pooling_param.dimensions]

        # for normalize
        if pooling_param.normalize:
            embeddings = self.activation(embeddings)

        # embeddings shape: [n_tokens, embedding_dimension]
        return embeddings
class TokenClassifierPoolerHead(TokenwisePoolerHead):
    """Per-request head for token classification.

    Casts the pooled data to the head dtype, applies the optional classifier,
    subtracts the configured logit bias and optionally applies the activation.

    Args:
        classifier: Optional callable mapping hidden states to scores; when
            ``None`` the (cast) hidden states are used as scores directly.
        act_fn: Activation, or the name of one, resolved through
            `ClassifierPooler.resolve_act_fn` with ``static_num_labels=False``.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config

        self.classifier = classifier
        self.logit_bias: float | None = model_config.pooler_config.logit_bias
        self.head_dtype = model_config.head_dtype

        self.activation = ClassifierPooler.resolve_act_fn(
            model_config, static_num_labels=False, act_fn=act_fn
        )

    def forward(
        self,
        pooled_data: TokenwisePoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenwisePoolerHeadOutput:
        """Return the per-token scores for one request.

        Returns ``None`` when ``pooled_data`` is ``None`` (the request is not
        ready yet). The input tensor is never modified.
        """
        # for unfinished chunked prefill
        if pooled_data is None:
            return None

        pooled_data = pooled_data.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]

        if self.classifier is not None:
            scores = self.classifier(pooled_data)
        else:
            scores = pooled_data
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            # Subtract out-of-place. Without a classifier, and when `.to()`
            # is a no-op, `scores` is still the caller's tensor, which
            # `AllPool` produces with `Tensor.split` (a view of the shared
            # hidden states); an in-place `-=` would write into it.
            scores = scores - self.logit_bias

        if pooling_param.use_activation:
            scores = self.activation(scores)

        # scores shape: [n_token, num_labels]
        return scores
class AllPooler(Pooler):
    """Token-wise pooler that feeds every `AllPool` output through a head."""

    def __init__(self, head: TokenwisePoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler accepts."""
        return {"token_embed", "token_classify"}

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenwisePoolerOutput:
        """Pool per request and apply the head to each (data, params) pair."""
        per_request = self.pooling(hidden_states, pooling_metadata)
        params = pooling_metadata.pooling_params
        assert len(per_request) == len(params)

        outputs = []
        for data, param in zip(per_request, params):
            outputs.append(self.head(data, param))
        return outputs
class StepPooler(Pooler):
    """Token-wise pooler that filters `AllPool` outputs by step tag and token ids."""

    def __init__(self, head: TokenwisePoolerHead) -> None:
        super().__init__()

        self.pooling = AllPool()
        self.head = head

    def extract_states(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[torch.Tensor | None]:
        """Pool per request, then select columns and rows per request.

        Columns are restricted to ``returned_token_ids`` when that is
        non-empty; rows are restricted to positions whose prompt token equals
        ``step_tag_id`` when that is set. ``None`` entries pass through.
        """
        per_request = self.pooling(hidden_states, pooling_metadata)
        prompt_token_ids = pooling_metadata.get_prompt_token_ids()
        params = pooling_metadata.pooling_params

        selected: list[torch.Tensor | None] = []
        for data, token_ids, param in zip(per_request, prompt_token_ids, params):
            # for unfinished chunked prefill
            if data is not None:
                step_tag_id = param.step_tag_id
                returned_token_ids = param.returned_token_ids

                if returned_token_ids is not None and len(returned_token_ids) > 0:
                    data = data[:, returned_token_ids]

                if step_tag_id is not None:
                    data = data[token_ids == step_tag_id]

            selected.append(data)

        return selected

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks this pooler accepts."""
        return {"token_embed", "token_classify"}

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request prompt token ids, which `extract_states` reads."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenwisePoolerOutput:
        """Extract per-request states and apply the head to each of them."""
        per_request = self.extract_states(hidden_states, pooling_metadata)
        params = pooling_metadata.pooling_params
        assert len(per_request) == len(params)

        outputs = []
        for data, param in zip(per_request, params):
            outputs.append(self.head(data, param))
        return outputs
class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        super().__init__()

        # Reject mappings that pair a task with a pooler that cannot serve it.
        for task, pooler in poolers_by_task.items():
            supported = pooler.get_supported_tasks()
            if task in supported:
                continue
            raise ValueError(
                f"{pooler=} does not support {task=}. "
                f"Supported tasks: {pooler.get_supported_tasks()}"
            )

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the set of tasks that have a registered pooler."""
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the pooler registered for ``task``."""
        pooler = self.poolers_by_task[task]
        return pooler.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Run each run of consecutive same-task requests through its pooler.

        Raises:
            ValueError: If a task has no registered pooler.
        """
        poolers_by_task = self.poolers_by_task

        outputs: list[torch.Tensor | None] = []
        start = 0
        for task, run in groupby(pooling_metadata.tasks):
            pooler = poolers_by_task.get(task)
            if not pooler:
                raise ValueError(
                    f"Unsupported task: {task} "
                    f"Supported tasks: {self.get_supported_tasks()}"
                )

            # Count the requests in this run to slice the metadata.
            stop = start + sum(1 for _ in run)
            group_output: PoolerOutput = pooler(
                hidden_states,
                pooling_metadata[start:stop],
            )

            outputs.extend(group_output)
            start = stop

        return outputs

    def extra_repr(self) -> str:
        """Return the supported tasks for the module's repr."""
        return f"supported_task={self.get_supported_tasks()}"
vllm/model_executor/layers/pooler/__init__.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
.abstract
import
*
from
.common
import
*
from
.special
import
*
vllm/model_executor/layers/pooler/abstract.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Set
import
torch
import
torch.nn
as
nn
from
vllm.tasks
import
PoolingTask
from
vllm.v1.outputs
import
PoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.common
import
PoolingParamsUpdate
class Pooler(nn.Module, ABC):
    """The interface required for all poolers used in pooling models in vLLM."""

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Determine which pooling tasks are supported."""
        raise NotImplementedError

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """
        Construct the updated pooling parameters to use for a supported task.
        """
        # Default: an update with no flags changed; `task` is unused here.
        default_update = PoolingParamsUpdate()
        return default_update

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool ``hidden_states`` using ``pooling_metadata``."""
        raise NotImplementedError
__all__
=
[
"Pooler"
]
vllm/model_executor/layers/pooler/activations.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Callable
from
typing
import
TypeVar
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
vllm.config
import
ModelConfig
,
get_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
logger
=
init_logger
(
__name__
)
def
get_classification_act_fn
(
config
:
PretrainedConfig
,
)
->
"PoolerActivation"
:
# Implement alignment with transformers ForSequenceClassificationLoss
# https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
problem_type
=
getattr
(
config
,
"problem_type"
,
""
)
if
problem_type
==
"regression"
:
return
PoolerIdentity
()
if
problem_type
==
"single_label_classification"
:
return
PoolerClassify
()
if
problem_type
==
"multi_label_classification"
:
return
PoolerMultiLabelClassify
()
return
PoolerClassify
()
def get_cross_encoder_act_fn(
    config: PretrainedConfig,
) -> "PoolerActivation":
    """Resolve the activation used for cross-encoder ("score") outputs.

    The activation's qualified name is read from
    ``config.sentence_transformers["activation_fn"]`` if present, otherwise
    from ``config.sbert_ce_default_activation_function`` when that is set.
    With neither configured, ``PoolerClassify`` is returned.

    Raises:
        ValueError: If the configured name is not under
            ``torch.nn.modules.``.
    """
    function_name: str | None = None
    if (
        hasattr(config, "sentence_transformers")
        and "activation_fn" in config.sentence_transformers
    ):
        function_name = config.sentence_transformers["activation_fn"]
    elif (
        hasattr(config, "sbert_ce_default_activation_function")
        and config.sbert_ce_default_activation_function is not None
    ):
        function_name = config.sbert_ce_default_activation_function

    if function_name is None:
        return PoolerClassify()

    # The name comes from the model's config and is imported and called
    # below, so this allow-list is a security check. It must be an explicit
    # raise: `assert` statements are removed when Python runs with -O.
    if not function_name.startswith("torch.nn.modules."):
        raise ValueError(
            "Loading of activation functions is restricted to "
            "torch.nn.modules for security reasons"
        )

    fn = resolve_obj_by_qualname(function_name)()
    return PoolerActivation.wraps(fn)
def resolve_classifier_act_fn(
    model_config: ModelConfig,
    static_num_labels: bool = True,
    act_fn: "PoolerActivation | str | None" = None,
):
    """Turn an ``act_fn`` specification into an activation callable.

    * ``None`` gives ``PoolerClassify(static_num_labels=static_num_labels)``.
    * ``"classify"`` / ``"score"`` are resolved from ``model_config.hf_config``.
    * Any other object must be callable and is returned unchanged.

    Raises:
        ValueError: If ``act_fn`` is a string other than the two above.
    """
    if act_fn is None:
        return PoolerClassify(static_num_labels=static_num_labels)

    if not isinstance(act_fn, str):
        assert callable(act_fn)
        return act_fn

    if act_fn == "classify":
        return get_classification_act_fn(model_config.hf_config)
    if act_fn == "score":
        return get_cross_encoder_act_fn(model_config.hf_config)

    raise ValueError(f"act_fn [{act_fn=}] not supported.")
_T
=
TypeVar
(
"_T"
,
torch
.
Tensor
,
list
[
torch
.
Tensor
])
class PoolerActivation(nn.Module, ABC):
    """Base class for activations applied to pooled outputs.

    Subclasses implement ``forward_chunk`` for a single tensor; ``forward``
    applies it to one tensor or to each tensor in a list.
    """

    @staticmethod
    def wraps(module: nn.Module):
        """Adapt a torch module to a ``PoolerActivation``.

        ``nn.Identity`` becomes ``PoolerIdentity``; ``nn.Sigmoid`` and
        ``nn.Softmax`` become ``PoolerClassify``; anything else is wrapped in
        ``LambdaPoolerActivation``.
        """
        if isinstance(module, nn.Identity):
            return PoolerIdentity()
        if isinstance(module, (nn.Sigmoid, nn.Softmax)):
            return PoolerClassify()

        return LambdaPoolerActivation(module)

    @abstractmethod
    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply the activation to a single tensor."""
        raise NotImplementedError

    def forward(self, pooled_data: _T) -> _T:
        """Apply ``forward_chunk`` to a tensor, or to each tensor in a list."""
        # shape:
        # classify (& score) -> (batch_size, num_classes)
        # embed -> (batch_size, embedding_dim) or list(embedding_dim)
        #          (batch_size, dimensions) or list(dimensions) if using MRL
        if not isinstance(pooled_data, list):
            return self.forward_chunk(pooled_data)

        return [self.forward_chunk(chunk) for chunk in pooled_data]
class PoolerIdentity(PoolerActivation):
    """Activation that returns its input unchanged."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return ``pooled_data`` as-is (the same tensor object)."""
        return pooled_data
class PoolerNormalize(PoolerActivation):
    """Activation that L2-normalizes along the last dimension."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return ``pooled_data`` scaled to unit L2 norm over ``dim=-1``."""
        normalized = F.normalize(pooled_data, p=2, dim=-1)
        return normalized
class PoolerMultiLabelClassify(PoolerActivation):
    """Activation for multi-label classification: element-wise sigmoid."""

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid to every element of ``pooled_data``."""
        probabilities = F.sigmoid(pooled_data)
        return probabilities
class PoolerClassify(PoolerActivation):
    """Classification activation: sigmoid for fewer than two labels, else softmax.

    Args:
        static_num_labels: When true, the label count is read once from the
            current vLLM config's ``hf_config.num_labels`` (0 if absent).
            When false, it is taken from the last dimension of each input.
    """

    def __init__(self, *, static_num_labels: bool = True) -> None:
        super().__init__()

        if static_num_labels:
            vllm_config = get_current_vllm_config()
            model_config = vllm_config.model_config
            num_labels = getattr(model_config.hf_config, "num_labels", 0)
        else:
            num_labels = None

        if num_labels == 0:
            # forward_chunk applies sigmoid whenever num_labels < 2, which
            # includes 0, so the message names sigmoid (it previously said
            # softmax, contradicting the code).
            # NOTE(review): if softmax was the intended fallback, the branch
            # in forward_chunk needs changing instead - confirm with owners.
            logger.warning(
                "num_labels should be > 0 for classification "
                "models, falling back to sigmoid. "
                "Please check if the configuration is correct."
            )

        self.num_labels = num_labels

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Apply sigmoid (fewer than 2 labels) or softmax over ``dim=-1``."""
        num_labels = self.num_labels
        if num_labels is None:
            # Dynamic mode: infer the label count from the input itself.
            num_labels = pooled_data.shape[-1]

        if num_labels < 2:
            return F.sigmoid(pooled_data)

        return F.softmax(pooled_data, dim=-1)
class LambdaPoolerActivation(PoolerActivation):
    """Activation that delegates to an arbitrary tensor-to-tensor callable."""

    def __init__(self, fn: Callable[[torch.Tensor], torch.Tensor]):
        """Store ``fn`` as the callable applied to each chunk."""
        super().__init__()

        self.fn = fn

    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
        """Return ``self.fn(pooled_data)``."""
        activated = self.fn(pooled_data)
        return activated
vllm/model_executor/layers/pooler/common.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
import
torch
from
vllm.pooling_params
import
PoolingParams
ClassifierFn
=
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]
@dataclass(frozen=True)
class PoolingParamsUpdate:
    """Immutable set of flags a pooler wants applied to ``PoolingParams``."""

    requires_token_ids: bool = False
    """Set this flag to enable `get_prompt_token_ids` for your pooler."""

    def __or__(self, other: "PoolingParamsUpdate") -> "PoolingParamsUpdate":
        """Combine two updates; a flag is set if either side sets it."""
        needs_token_ids = self.requires_token_ids or other.requires_token_ids
        return PoolingParamsUpdate(requires_token_ids=needs_token_ids)

    def apply(self, params: PoolingParams) -> None:
        """Copy this update's flags onto ``params`` in place."""
        params.requires_token_ids = self.requires_token_ids
__all__
=
[
"ClassifierFn"
,
"PoolingParamsUpdate"
]
vllm/model_executor/layers/pooler/seqwise/__init__.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Poolers that produce an output aggregating all tokens in the sequence."""
from
.heads
import
(
ClassifierPoolerHead
,
EmbeddingPoolerHead
,
SequencePoolerHead
,
SequencePoolerHeadOutput
,
)
from
.methods
import
(
CLSPool
,
LastPool
,
MeanPool
,
SequencePoolingMethod
,
SequencePoolingMethodOutput
,
get_seq_pooling_method
,
)
from
.poolers
import
(
SequencePooler
,
SequencePoolerOutput
,
SequencePoolingFn
,
SequencePoolingHeadFn
,
pooler_for_classify
,
pooler_for_embed
,
)
__all__
=
[
"SequencePoolerHead"
,
"SequencePoolerHeadOutput"
,
"ClassifierPoolerHead"
,
"EmbeddingPoolerHead"
,
"SequencePoolingMethod"
,
"SequencePoolingMethodOutput"
,
"CLSPool"
,
"LastPool"
,
"MeanPool"
,
"get_seq_pooling_method"
,
"SequencePooler"
,
"SequencePoolingFn"
,
"SequencePoolingHeadFn"
,
"SequencePoolerOutput"
,
"pooler_for_classify"
,
"pooler_for_embed"
,
]
vllm/model_executor/layers/pooler/seqwise/heads.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Set
from
typing
import
TypeAlias
import
torch
import
torch.nn
as
nn
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.layers.pooler
import
ClassifierFn
from
vllm.model_executor.layers.pooler.activations
import
(
PoolerActivation
,
PoolerNormalize
,
resolve_classifier_act_fn
,
)
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.methods
import
SequencePoolingMethodOutput
SequencePoolerHeadOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
class SequencePoolerHead(nn.Module, ABC):
    """Abstract post-processing stage applied to sequence-pooled data."""

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this head can serve."""
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        """Post-process ``pooled_data`` for the requests in ``pooling_metadata``."""
        raise NotImplementedError
class EmbeddingPoolerHead(SequencePoolerHead):
    """Head for the ``"embed"`` task.

    Casts to the configured head dtype, applies the optional projector,
    truncates to per-request ``dimensions`` and applies per-request
    normalization.
    """

    def __init__(self) -> None:
        super().__init__()

        model_config = get_current_vllm_config().model_config

        # Load ST projector if available
        self.projector = _load_st_projector(model_config)
        self.head_dtype = model_config.head_dtype

        self.activation = PoolerNormalize()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """This head only serves the ``"embed"`` task."""
        return {"embed"}

    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        """Turn pooled hidden states into embeddings.

        Returns a single tensor when every request shares the same
        truncation and normalization settings, otherwise a list with one
        tensor per request.
        """
        params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(params)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_dimension]

        pooled_data = pooled_data.to(self.head_dtype)

        # Apply ST projector
        if self.projector is not None:
            pooled_data = self.projector(pooled_data)
        # pooled_data shape: [batchsize, embedding_dimension]

        # for matryoshka representation
        dims = [param.dimensions for param in params]
        if any(d is not None for d in dims):
            # change the output dimension
            assert len(pooled_data) == len(dims)
            same_dim = len(set(dims)) == 1
            if same_dim and not isinstance(pooled_data, list):
                # if all dimensions are the same
                pooled_data = pooled_data[..., : dims[0]]
            else:
                pooled_data = [
                    row if d is None else row[..., :d]
                    for row, d in zip(pooled_data, dims)
                ]

        # for normalize
        normalize_flags = [param.normalize for param in params]
        if len(set(normalize_flags)) != 1:
            # Mixed settings: normalize only the requests that asked for it.
            pooled_data = [
                self.activation(row) if flag else row
                for row, flag in zip(pooled_data, normalize_flags)
            ]
        elif normalize_flags[0]:
            pooled_data = self.activation(pooled_data)

        # pooled_data shape: [batchsize, embedding_dimension]
        return pooled_data
class ClassifierPoolerHead(SequencePoolerHead):
    """Head for the ``"classify"`` and ``"score"`` tasks.

    Applies an optional classifier, subtracts the configured logit bias and
    applies the activation for requests whose ``use_activation`` is set.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        model_config = get_current_vllm_config().model_config

        self.classifier = classifier
        self.logit_bias: float | None = model_config.pooler_config.logit_bias
        self.head_dtype = model_config.head_dtype

        self.act_fn = resolve_classifier_act_fn(
            model_config, static_num_labels=True, act_fn=act_fn
        )

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """This head serves the ``"classify"`` and ``"score"`` tasks."""
        return {"classify", "score"}

    def forward(
        self,
        pooled_data: SequencePoolingMethodOutput,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerHeadOutput:
        """Compute classification scores for each pooled sequence.

        Returns a single tensor when all requests agree on
        ``use_activation``, otherwise a list with one tensor per request.
        """
        params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(params)

        if isinstance(pooled_data, list):
            pooled_data = torch.stack(pooled_data)
        # pooled_data shape: [batchsize, hidden_size]

        pooled_data = pooled_data.to(self.head_dtype)

        if self.classifier is not None:
            pooled_data = self.classifier(pooled_data)
        # pooled_data shape: [batchsize, num_labels]

        if self.logit_bias is not None:
            pooled_data -= self.logit_bias

        activation_flags = [param.use_activation for param in params]
        if len(set(activation_flags)) != 1:
            # Mixed settings: activate only the requests that asked for it.
            return [
                self.act_fn(row) if flag else row
                for row, flag in zip(pooled_data, activation_flags)
            ]

        # scores shape: [batchsize, num_labels]
        if activation_flags[0]:
            return self.act_fn(pooled_data)
        return pooled_data
vllm/model_executor/layers/pooler/seqwise/methods.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Set
from
typing
import
TypeAlias
import
torch
import
torch.nn
as
nn
from
vllm.config.pooler
import
PoolingTypeStr
from
vllm.model_executor.layers.pooler
import
PoolingParamsUpdate
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
SequencePoolingMethodOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
class SequencePoolingMethod(nn.Module, ABC):
    """Abstract strategy that reduces hidden states to per-sequence data."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks a sequence pooling method may be used for."""
        supported = {"token_embed", "token_classify", "embed", "classify", "score"}
        return supported

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return a default ``PoolingParamsUpdate`` for ``task``."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Pool ``hidden_states`` for the requests in ``pooling_metadata``."""
        raise NotImplementedError
class CLSPool(SequencePoolingMethod):
    """Pooling that selects the hidden state at each first-token index."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Index ``hidden_states`` with the cursor's ``first_token_indices_gpu``.

        Asserts that the cursor does not report a partial prefill.
        """
        cursor = pooling_metadata.get_pooling_cursor()
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with CLS pooling"
        )

        first_token_states = hidden_states[cursor.first_token_indices_gpu]
        return first_token_states
class LastPool(SequencePoolingMethod):
    """Pooling that selects the hidden state at each last-token index."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Index ``hidden_states`` with the cursor's ``last_token_indices_gpu``."""
        cursor = pooling_metadata.get_pooling_cursor()
        last_token_states = hidden_states[cursor.last_token_indices_gpu]
        return last_token_states
class MeanPool(SequencePoolingMethod):
    """Pooling that averages hidden states between first and last token indices."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolingMethodOutput:
        """Return one mean vector per request.

        Each segment sum is obtained from a running sum over dim 0 as
        ``cumsum[last] - cumsum[first] + hidden_states[first]`` and divided
        by the cursor's prompt lengths. Asserts that the cursor does not
        report a partial prefill.
        """
        cursor = pooling_metadata.get_pooling_cursor()
        assert not cursor.is_partial_prefill(), (
            "partial prefill not supported with MEAN pooling"
        )

        prompt_lens = cursor.prompt_lens_cpu.to(
            hidden_states.device, non_blocking=True
        )

        # Use float32 for torch.cumsum in MeanPool,
        # otherwise precision will be lost significantly.
        running_sum = torch.cumsum(hidden_states, dim=0, dtype=torch.float32)

        first = cursor.first_token_indices_gpu
        last = cursor.last_token_indices_gpu

        # running_sum[last] - running_sum[first] omits the first token's own
        # state, so it is added back.
        segment_sums = running_sum[last] - running_sum[first] + hidden_states[first]
        return segment_sums / prompt_lens.unsqueeze(1)
def
get_seq_pooling_method
(
pooling_type
:
PoolingTypeStr
|
str
):
if
pooling_type
==
"LAST"
:
return
LastPool
()
if
pooling_type
==
"CLS"
:
return
CLSPool
()
if
pooling_type
==
"MEAN"
:
return
MeanPool
()
raise
NotImplementedError
(
f
"Unknown sequence pooling type:
{
pooling_type
!
r
}
"
)
vllm/model_executor/layers/pooler/seqwise/poolers.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
,
Set
from
typing
import
TypeAlias
import
torch
from
vllm.config
import
PoolerConfig
from
vllm.model_executor.layers.pooler
import
ClassifierFn
,
PoolingParamsUpdate
from
vllm.model_executor.layers.pooler.abstract
import
Pooler
from
vllm.model_executor.layers.pooler.activations
import
PoolerActivation
from
vllm.tasks
import
POOLING_TASKS
,
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.heads
import
(
ClassifierPoolerHead
,
EmbeddingPoolerHead
,
SequencePoolerHead
,
SequencePoolerHeadOutput
,
)
from
.methods
import
(
SequencePoolingMethod
,
SequencePoolingMethodOutput
,
get_seq_pooling_method
,
)
SequencePoolingFn
:
TypeAlias
=
Callable
[
[
torch
.
Tensor
,
PoolingMetadata
],
SequencePoolingMethodOutput
,
]
SequencePoolingHeadFn
:
TypeAlias
=
Callable
[
[
SequencePoolingMethodOutput
,
PoolingMetadata
],
SequencePoolerHeadOutput
,
]
SequencePoolerOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
class SequencePooler(Pooler):
    """
    A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Postprocesses the output based on pooling head.
    3. Returns structured results as `PoolerOutput`.
    """

    def __init__(
        self,
        pooling: SequencePoolingMethod | SequencePoolingFn,
        head: SequencePoolerHead | SequencePoolingHeadFn,
    ) -> None:
        """Store the pooling stage and the head stage (module or callable)."""
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return all pooling tasks, narrowed by each stage that declares its own.

        Plain callables declare nothing and therefore do not restrict the set.
        """
        supported = set(POOLING_TASKS)

        stages = (
            (self.pooling, SequencePoolingMethod),
            (self.head, SequencePoolerHead),
        )
        for stage, declaring_type in stages:
            if isinstance(stage, declaring_type):
                supported &= stage.get_supported_tasks()

        return supported

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling method's updates merged onto a default update."""
        if isinstance(self.pooling, SequencePoolingMethod):
            return PoolingParamsUpdate() | self.pooling.get_pooling_updates(task)

        return PoolingParamsUpdate()

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> SequencePoolerOutput:
        """Run the pooling stage, then the head stage, and return the result."""
        pooled = self.pooling(hidden_states, pooling_metadata)
        return self.head(pooled, pooling_metadata)
def pooler_for_embed(pooler_config: PoolerConfig):
    """Build a ``SequencePooler`` for the embed task.

    The pooling method is chosen from ``pooler_config.get_pooling_type()``
    and paired with an ``EmbeddingPoolerHead``.
    """
    return SequencePooler(
        pooling=get_seq_pooling_method(pooler_config.get_pooling_type()),
        head=EmbeddingPoolerHead(),
    )
def pooler_for_classify(
    pooler_config: PoolerConfig,
    *,
    pooling: SequencePoolingMethod | SequencePoolingFn | None = None,
    classifier: ClassifierFn | None = None,
    act_fn: PoolerActivation | str | None = None,
):
    """Build a ``SequencePooler`` with a ``ClassifierPoolerHead``.

    Args:
        pooler_config: Supplies the pooling type when ``pooling`` is omitted.
        pooling: Optional pooling stage overriding the configured one.
        classifier: Optional classifier passed to the head.
        act_fn: Activation specification passed to the head.
    """
    pooling_stage = (
        pooling
        if pooling is not None
        else get_seq_pooling_method(pooler_config.get_pooling_type())
    )

    return SequencePooler(
        pooling=pooling_stage,
        head=ClassifierPoolerHead(classifier=classifier, act_fn=act_fn),
    )
vllm/model_executor/layers/pooler/special.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Mapping
,
Set
from
itertools
import
groupby
import
torch
from
vllm.config
import
PoolerConfig
from
vllm.model_executor.layers.pooler
import
PoolingParamsUpdate
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.abstract
import
Pooler
,
PoolerOutput
from
.common
import
ClassifierFn
from
.seqwise
import
(
SequencePoolingFn
,
SequencePoolingMethod
,
pooler_for_classify
,
pooler_for_embed
,
)
from
.tokwise
import
AllPool
,
pooler_for_token_classify
,
pooler_for_token_embed
class DispatchPooler(Pooler):
    """Dispatches calls to a sub-pooler based on the pooling task."""

    @classmethod
    def for_embedding(cls, pooler_config: PoolerConfig):
        """Build a dispatcher serving ``"token_embed"`` and ``"embed"``."""
        poolers = {
            "token_embed": pooler_for_token_embed(pooler_config),
            "embed": pooler_for_embed(pooler_config),
        }
        return cls(poolers)

    @classmethod
    def for_seq_cls(
        cls,
        pooler_config: PoolerConfig,
        *,
        pooling: SequencePoolingMethod | SequencePoolingFn | None = None,
        classifier: ClassifierFn | None = None,
    ):
        """Build a dispatcher for ``"token_classify"``, ``"classify"``, ``"score"``.

        ``pooling`` is forwarded only to the two sequence-level poolers; the
        token-level pooler always uses ``AllPool``.
        """
        poolers = {
            "token_classify": pooler_for_token_classify(
                pooler_config,
                pooling=AllPool(),
                classifier=classifier,
            ),
            "classify": pooler_for_classify(
                pooler_config,
                pooling=pooling,
                classifier=classifier,
                act_fn="classify",
            ),
            "score": pooler_for_classify(
                pooler_config,
                pooling=pooling,
                classifier=classifier,
                act_fn="score",
            ),
        }
        return cls(poolers)

    def __init__(self, poolers_by_task: Mapping[PoolingTask, Pooler]) -> None:
        """Validate and store the task-to-pooler mapping.

        Raises:
            ValueError: If a pooler does not list the task it is mapped to
                among its supported tasks.
        """
        super().__init__()

        for task, pooler in poolers_by_task.items():
            if task in pooler.get_supported_tasks():
                continue
            raise ValueError(
                f"{pooler=} does not support {task=}. "
                f"Supported tasks: {pooler.get_supported_tasks()}"
            )

        self.poolers_by_task = poolers_by_task

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks that have a registered pooler."""
        return set(self.poolers_by_task)

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Delegate to the pooler registered for ``task``."""
        pooler = self.poolers_by_task[task]
        return pooler.get_pooling_updates(task)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Pool each run of consecutive same-task requests with its sub-pooler.

        ``pooling_metadata.tasks`` is walked with ``itertools.groupby``, so
        only adjacent requests sharing a task are batched together. The
        per-run outputs are concatenated in request order.

        Raises:
            ValueError: If a task has no registered pooler.
        """
        results: list[torch.Tensor | None] = []
        start = 0

        for task, run in groupby(pooling_metadata.tasks):
            pooler = self.poolers_by_task.get(task)
            if not pooler:
                raise ValueError(
                    f"Unsupported task: {task!r} "
                    f"Supported tasks: {self.get_supported_tasks()}"
                )

            # The group iterator has no length; count its items to find the
            # end of this run inside the metadata.
            stop = start + sum(1 for _ in run)
            results.extend(pooler(hidden_states, pooling_metadata[start:stop]))
            start = stop

        return results

    def extra_repr(self) -> str:
        """Return the supported-task summary shown in this module's repr."""
        return f"supported_task={self.get_supported_tasks()}"
class IdentityPooler(Pooler):
    """Pooler that passes hidden states through untouched."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """This pooler serves the ``"plugin"`` and ``"score"`` tasks."""
        supported = {"plugin", "score"}
        return supported

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Return ``hidden_states`` unchanged; ``pooling_metadata`` is ignored."""
        return hidden_states
__all__
=
[
"DispatchPooler"
,
"IdentityPooler"
]
vllm/model_executor/layers/pooler/tokwise/__init__.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Poolers that produce an output for each token in the sequence."""
from
.heads
import
(
TokenClassifierPoolerHead
,
TokenEmbeddingPoolerHead
,
TokenPoolerHead
,
TokenPoolerHeadOutputItem
,
)
from
.methods
import
(
AllPool
,
StepPool
,
TokenPoolingMethod
,
TokenPoolingMethodOutputItem
,
get_tok_pooling_method
,
)
from
.poolers
import
(
TokenPooler
,
TokenPoolerOutput
,
pooler_for_token_classify
,
pooler_for_token_embed
,
)
__all__
=
[
"TokenPoolerHead"
,
"TokenPoolerHeadOutputItem"
,
"TokenClassifierPoolerHead"
,
"TokenEmbeddingPoolerHead"
,
"TokenPoolingMethod"
,
"TokenPoolingMethodOutputItem"
,
"AllPool"
,
"StepPool"
,
"get_tok_pooling_method"
,
"TokenPooler"
,
"TokenPoolerOutput"
,
"pooler_for_token_classify"
,
"pooler_for_token_embed"
,
]
vllm/model_executor/layers/pooler/tokwise/heads.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Set
from
typing
import
TypeAlias
import
torch
import
torch.nn
as
nn
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.layers.pooler
import
ClassifierFn
from
vllm.model_executor.layers.pooler.activations
import
(
PoolerActivation
,
PoolerNormalize
,
resolve_classifier_act_fn
,
)
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.pooling_params
import
PoolingParams
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.methods
import
TokenPoolingMethodOutputItem
TokenPoolerHeadOutputItem
:
TypeAlias
=
torch
.
Tensor
|
None
class TokenPoolerHead(nn.Module, ABC):
    """Abstract post-processing stage applied per request to token-level data."""

    @abstractmethod
    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the pooling tasks this head can serve."""
        raise NotImplementedError

    @abstractmethod
    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        """Post-process one request's pooled data with its pooling params."""
        raise NotImplementedError

    def forward(
        self,
        pooled_data: list[TokenPoolingMethodOutputItem],
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolerHeadOutputItem]:
        """Apply ``forward_chunk`` to each (pooled item, pooling params) pair.

        Asserts that there is exactly one pooled item per pooling params entry.
        """
        params = pooling_metadata.pooling_params
        assert len(pooled_data) == len(params)

        outputs = []
        for item, param in zip(pooled_data, params):
            outputs.append(self.forward_chunk(item, param))
        return outputs
class TokenEmbeddingPoolerHead(TokenPoolerHead):
    """Head for the ``"token_embed"`` task.

    Casts to the configured head dtype, applies the optional projector,
    truncates to the requested ``dimensions`` and optionally normalizes.
    """

    def __init__(self) -> None:
        super().__init__()

        model_config = get_current_vllm_config().model_config

        # Load ST projector if available
        self.projector = _load_st_projector(model_config)
        self.head_dtype = model_config.head_dtype

        self.activation = PoolerNormalize()

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """This head only serves the ``"token_embed"`` task."""
        return {"token_embed"}

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        """Produce token embeddings for one request, or ``None`` if not ready."""
        # for unfinished chunked prefill
        if pooled_data is None:
            return None

        embeddings = pooled_data.to(self.head_dtype)
        # embeddings shape: [n_tokens, hidden_dimension]

        # Apply ST projector
        if self.projector is not None:
            embeddings = self.projector(embeddings)
        # embeddings shape: [n_tokens, embedding_dimension]

        # for matryoshka representation
        # (a ``dimensions`` of None makes this slice keep every column)
        embeddings = embeddings[..., : pooling_param.dimensions]

        # for normalize
        if not pooling_param.normalize:
            return embeddings

        # embeddings shape: [n_tokens, embedding_dimension]
        return self.activation(embeddings)
class TokenClassifierPoolerHead(TokenPoolerHead):
    """Head for the ``"token_classify"`` task.

    Applies an optional classifier to each token's hidden state, subtracts
    the configured logit bias and optionally applies the activation.

    Args:
        classifier: Optional callable mapping hidden states to logits.
        act_fn: Activation specification resolved through
            ``resolve_classifier_act_fn`` with ``static_num_labels=False``.
    """

    def __init__(
        self,
        classifier: ClassifierFn | None = None,
        act_fn: PoolerActivation | str | None = None,
    ) -> None:
        super().__init__()

        vllm_config = get_current_vllm_config()
        model_config = vllm_config.model_config

        self.classifier = classifier
        self.logit_bias: float | None = model_config.pooler_config.logit_bias
        self.head_dtype = model_config.head_dtype

        self.act_fn = resolve_classifier_act_fn(
            model_config, static_num_labels=False, act_fn=act_fn
        )

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """This head only serves the ``"token_classify"`` task."""
        return {"token_classify"}

    def forward_chunk(
        self,
        pooled_data: TokenPoolingMethodOutputItem,
        pooling_param: PoolingParams,
    ) -> TokenPoolerHeadOutputItem:
        """Compute per-token scores for one request.

        Returns ``None`` when ``pooled_data`` is ``None`` (unfinished chunked
        prefill). The input tensor is never modified in place.
        """
        # for unfinished chunked prefill
        if pooled_data is None:
            return None

        pooled_data = pooled_data.to(self.head_dtype)
        # hidden_states shape: [n_token, hidden_size]

        if self.classifier is not None:
            scores = self.classifier(pooled_data)
        else:
            scores = pooled_data
        # scores shape: [n_token, num_labels]

        if self.logit_bias is not None:
            # Subtract out of place. With no classifier and a no-op dtype
            # cast, `scores` is the very tensor handed in by the pooling
            # method (which may be a view of the model's hidden states), and
            # an in-place `-=` would silently modify the caller's data.
            scores = scores - self.logit_bias

        if pooling_param.use_activation:
            scores = self.act_fn(scores)

        # scores shape: [n_token, num_labels]
        return scores
vllm/model_executor/layers/pooler/tokwise/methods.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Set
from
typing
import
TypeAlias
import
torch
import
torch.nn
as
nn
from
vllm.config
import
get_current_vllm_config
from
vllm.config.pooler
import
PoolingTypeStr
from
vllm.model_executor.layers.pooler
import
PoolingParamsUpdate
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
TokenPoolingMethodOutputItem
:
TypeAlias
=
torch
.
Tensor
|
None
class TokenPoolingMethod(nn.Module, ABC):
    """Abstract strategy that yields one pooled item per request for token tasks."""

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return the tasks a token pooling method may be used for."""
        supported = {"token_embed", "token_classify"}
        return supported

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return a default ``PoolingParamsUpdate`` for ``task``."""
        return PoolingParamsUpdate()

    @abstractmethod
    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        """Pool ``hidden_states`` for the requests in ``pooling_metadata``."""
        raise NotImplementedError
class AllPool(TokenPoolingMethod):
    """Pooling that returns every scheduled token's hidden state per request.

    With chunked prefill enabled, chunks are accumulated in each request's
    pooling state and only emitted once the cursor reports the request as
    finished.
    """

    def __init__(self):
        super().__init__()

        scheduler_config = get_current_vllm_config().scheduler_config
        self.enable_chunked_prefill = scheduler_config.enable_chunked_prefill

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        """Split ``hidden_states`` per request and return the pooled items.

        Returns a list with one entry per request: a tensor of hidden states,
        or ``None`` for a request that is not finished while chunked prefill
        is enabled.
        """
        cursor = pooling_metadata.get_pooling_cursor()

        per_request = hidden_states.split(cursor.num_scheduled_tokens_cpu.tolist())
        chunks = [per_request[i] for i in cursor.index]

        if not self.enable_chunked_prefill:
            return chunks

        states = pooling_metadata.pooling_states

        # If chunked_prefill is enabled
        # 1. first store the chunked hidden_states in
        #    pooling_states.hidden_states_cache
        for state, chunk in zip(states, chunks):
            state.hidden_states_cache.append(chunk)

        # 2. Once prefill is finished, send hidden_states_cache to PoolerHead
        outputs: list[TokenPoolingMethodOutputItem] = []
        for state, finished in zip(states, cursor.is_finished()):
            if not finished:
                outputs.append(None)
                continue

            cached = state.hidden_states_cache
            if len(cached) == 1:
                merged = cached[0]
            else:
                merged = torch.concat(cached, dim=0)
            outputs.append(merged)
            state.clean()

        return outputs
class StepPool(AllPool):
    """``AllPool`` variant that filters outputs by returned token ids and step tag."""

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Request prompt token ids, which ``forward`` reads."""
        return PoolingParamsUpdate(requires_token_ids=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> list[TokenPoolingMethodOutputItem]:
        """Run ``AllPool.forward`` and filter each request's output.

        For each non-``None`` item: if ``returned_token_ids`` is non-empty,
        only those columns are kept; if ``step_tag_id`` is set, only rows
        whose prompt token id equals it are kept. ``None`` items are passed
        through.
        """
        all_pooled = super().forward(hidden_states, pooling_metadata)
        prompt_token_ids = pooling_metadata.get_prompt_token_ids()
        params = pooling_metadata.pooling_params

        filtered: list[torch.Tensor | None] = []
        for data, token_id, pooling_param in zip(
            all_pooled, prompt_token_ids, params
        ):
            # None marks an unfinished chunked prefill and is kept as-is.
            if data is not None:
                step_tag_id = pooling_param.step_tag_id
                returned_token_ids = pooling_param.returned_token_ids

                if returned_token_ids is not None and len(returned_token_ids) > 0:
                    data = data[:, returned_token_ids]

                if step_tag_id is not None:
                    data = data[token_id == step_tag_id]

            filtered.append(data)

        return filtered
def get_tok_pooling_method(pooling_type: PoolingTypeStr | str):
    """Instantiate the token-wise pooling method named by ``pooling_type``.

    ``"STEP"`` yields ``StepPool``; ``"ALL"`` and every other value yield
    ``AllPool``. No error is raised for unknown names: the previous
    ``raise NotImplementedError`` sat after an unconditional ``return`` and
    could never run, so it has been removed.
    """
    if pooling_type == "ALL":
        return AllPool()
    if pooling_type == "STEP":
        return StepPool()

    # TODO: Separate seq and tok pooling types so we don't need this fallback
    return AllPool()
vllm/model_executor/layers/pooler/tokwise/poolers.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
,
Set
from
typing
import
TypeAlias
import
torch
from
vllm.config
import
PoolerConfig
from
vllm.model_executor.layers.pooler
import
ClassifierFn
,
PoolingParamsUpdate
from
vllm.model_executor.layers.pooler.abstract
import
Pooler
from
vllm.model_executor.layers.pooler.activations
import
PoolerActivation
from
vllm.tasks
import
POOLING_TASKS
,
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.heads
import
(
TokenClassifierPoolerHead
,
TokenEmbeddingPoolerHead
,
TokenPoolerHead
,
TokenPoolerHeadOutputItem
,
)
from
.methods
import
(
TokenPoolingMethod
,
TokenPoolingMethodOutputItem
,
get_tok_pooling_method
,
)
TokenPoolingFn
:
TypeAlias
=
Callable
[
[
torch
.
Tensor
,
PoolingMetadata
],
list
[
TokenPoolingMethodOutputItem
],
]
TokenPoolingHeadFn
:
TypeAlias
=
Callable
[
[
list
[
TokenPoolingMethodOutputItem
],
PoolingMetadata
],
list
[
TokenPoolerHeadOutputItem
],
]
TokenPoolerOutput
:
TypeAlias
=
list
[
torch
.
Tensor
|
None
]
class TokenPooler(Pooler):
    """
    A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on pooling method.
    2. Postprocesses the output based on pooling head.
    3. Returns structured results as `PoolerOutput`.
    """

    def __init__(
        self,
        pooling: TokenPoolingMethod | TokenPoolingFn,
        head: TokenPoolerHead | TokenPoolingHeadFn,
    ) -> None:
        """Store the pooling stage and the head stage (module or callable)."""
        super().__init__()

        self.pooling = pooling
        self.head = head

    def get_supported_tasks(self) -> Set[PoolingTask]:
        """Return all pooling tasks, narrowed by each stage that declares its own.

        Plain callables declare nothing and therefore do not restrict the set.
        """
        supported = set(POOLING_TASKS)

        stages = (
            (self.pooling, TokenPoolingMethod),
            (self.head, TokenPoolerHead),
        )
        for stage, declaring_type in stages:
            if isinstance(stage, declaring_type):
                supported &= stage.get_supported_tasks()

        return supported

    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
        """Return the pooling method's updates merged onto a default update."""
        if isinstance(self.pooling, TokenPoolingMethod):
            return PoolingParamsUpdate() | self.pooling.get_pooling_updates(task)

        return PoolingParamsUpdate()

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> TokenPoolerOutput:
        """Run the pooling stage, then the head stage, and return the result."""
        pooled = self.pooling(hidden_states, pooling_metadata)
        return self.head(pooled, pooling_metadata)
def pooler_for_token_embed(pooler_config: PoolerConfig):
    """Build a ``TokenPooler`` for the token_embed task.

    The pooling method is chosen from ``pooler_config.get_pooling_type()``
    and paired with a ``TokenEmbeddingPoolerHead``.
    """
    return TokenPooler(
        pooling=get_tok_pooling_method(pooler_config.get_pooling_type()),
        head=TokenEmbeddingPoolerHead(),
    )
def pooler_for_token_classify(
    pooler_config: PoolerConfig,
    *,
    pooling: TokenPoolingMethod | TokenPoolingFn | None = None,
    classifier: ClassifierFn | None = None,
    act_fn: PoolerActivation | str | None = None,
):
    """Build a ``TokenPooler`` with a ``TokenClassifierPoolerHead``.

    Args:
        pooler_config: Supplies the pooling type when ``pooling`` is omitted.
        pooling: Optional pooling stage overriding the configured one.
        classifier: Optional classifier passed to the head.
        act_fn: Activation specification passed to the head.
    """
    pooling_stage = (
        pooling
        if pooling is not None
        else get_tok_pooling_method(pooler_config.get_pooling_type())
    )

    return TokenPooler(
        pooling=pooling_stage,
        head=TokenClassifierPoolerHead(classifier=classifier, act_fn=act_fn),
    )
vllm/model_executor/models/adapters.py
View file @
c8ed39b9
...
...
@@ -252,19 +252,14 @@ def as_embedding_model(cls: _T) -> _T:
return
cls
# Lazy import
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
import
DispatchPooler
class
ModelForEmbedding
(
_create_pooling_model_cls
(
cls
)):
def
_init_pooler
(
self
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
{
"token_embed"
:
Pooler
.
for_token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
},
)
self
.
pooler
=
DispatchPooler
.
for_embedding
(
pooler_config
)
ModelForEmbedding
.
__name__
=
_get_pooling_model_name
(
cls
.
__name__
,
"ForEmbedding"
)
...
...
@@ -289,10 +284,7 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.pooler
import
(
DispatchPooler
,
Pooler
,
)
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.models.interfaces
import
SupportsCrossEncoding
from
.utils
import
maybe_prefix
...
...
@@ -318,18 +310,8 @@ def as_seq_cls_model(cls: _T) -> _T:
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
score
),
"classify"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"classify"
),
"score"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"score"
),
}
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
classifier
=
self
.
score
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
...
...
vllm/model_executor/models/bert.py
View file @
c8ed39b9
...
...
@@ -18,19 +18,25 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear
,
)
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
DispatchPooler
,
Pooler
,
PoolingMethod
,
PoolingParamsUpdate
,
TokenPoolerHeadOutput
,
TokenPoolingMethodOutput
,
)
from
vllm.model_executor.layers.pooler.seqwise
import
(
CLSPool
,
SequencePooler
,
SequencePoolerHeadOutput
,
SequencePoolerOutput
,
SequencePoolingMethodOutput
,
)
from
vllm.model_executor.layers.pooler.tokwise
import
(
pooler_for_token_classify
,
pooler_for_token_embed
,
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
PoolingTask
from
vllm.v1.outputs
import
TokenPoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
...
...
@@ -85,25 +91,21 @@ class BertEmbedding(nn.Module):
return
embeddings
class
BertPooler
(
Pooler
):
class
BertPooler
(
Sequence
Pooler
):
def
__init__
(
self
,
config
:
BertConfig
):
super
().
__init__
()
super
().
__init__
(
pooling
=
CLSPool
(),
head
=
self
.
head
,
)
self
.
pooling
=
PoolingMethod
.
from_pooling_type
(
"CLS"
)
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
activation
=
nn
.
Tanh
()
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
self
.
pooling
.
get_supported_tasks
()
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
head
(
self
,
pooled_data
:
Token
PoolingMethodOutput
,
pooled_data
:
Sequence
PoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
Token
PoolerHeadOutput
:
)
->
Sequence
PoolerHeadOutput
:
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
...
...
@@ -111,15 +113,6 @@ class BertPooler(Pooler):
pooled_data
=
self
.
activation
(
pooled_data
)
return
pooled_data
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
class
BertEncoder
(
nn
.
Module
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
...
@@ -524,12 +517,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
)
def
_build_pooler
(
self
,
pooler_config
:
PoolerConfig
)
->
Pooler
:
return
DispatchPooler
(
{
"token_embed"
:
Pooler
.
for_token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
}
)
return
DispatchPooler
.
for_embedding
(
pooler_config
)
# Here we encode the token type ids together with the input ids.
...
...
@@ -620,6 +608,7 @@ class SPLADESparsePooler(Pooler):
remove_cls_sep
:
bool
=
True
,
):
super
().
__init__
()
assert
pooling
in
(
"max"
,
"sum"
)
self
.
mlm_head
=
mlm_head
self
.
cls_token_id
=
cls_token_id
...
...
@@ -637,10 +626,8 @@ class SPLADESparsePooler(Pooler):
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
torch
.
Tensor
:
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
and
hidden_states
.
dim
()
==
2
lens_tensor
:
torch
.
Tensor
=
pooling_metadata
.
prompt_lens
)
->
SequencePoolerOutput
:
lens_tensor
=
pooling_metadata
.
prompt_lens
lens
:
list
[
int
]
=
lens_tensor
.
tolist
()
B
:
int
=
len
(
lens
)
...
...
@@ -729,7 +716,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
return
DispatchPooler
(
{
"token_embed"
:
P
ooler
.
for_token_embed
(
pooler_config
),
"token_embed"
:
p
ooler
_
for_token_embed
(
pooler_config
),
"embed"
:
SPLADESparsePooler
(
mlm_head
=
self
.
mlm_head
,
cls_token_id
=
cls_id
,
...
...
@@ -824,20 +811,10 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
classifier
),
"classify"
:
ClassifierPooler
(
pooling
=
self
.
bert
.
pooler
,
classifier
=
self
.
classifier
,
act_fn
=
"classify"
,
),
"score"
:
ClassifierPooler
(
pooling
=
self
.
bert
.
pooler
,
classifier
=
self
.
classifier
,
act_fn
=
"score"
),
}
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
pooler_config
,
pooling
=
self
.
bert
.
pooler
,
classifier
=
self
.
classifier
,
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -891,13 +868,7 @@ class BertForTokenClassification(nn.Module):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
=
pooler_config
),
}
)
self
.
pooler
=
pooler_for_token_classify
(
pooler_config
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
bert
.
embed_input_ids
(
input_ids
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment