Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c8ed39b9
Unverified
Commit
c8ed39b9
authored
Jan 09, 2026
by
Cyrus Leung
Committed by
GitHub
Jan 09, 2026
Browse files
[Model] Reorganize pooling layers (#31973)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
02073280
Changes
34
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1215 additions
and
934 deletions
+1215
-934
.github/CODEOWNERS
.github/CODEOWNERS
+1
-1
tests/model_executor/test_model_load_with_params.py
tests/model_executor/test_model_load_with_params.py
+2
-1
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
...dd_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
+2
-7
vllm/config/pooler.py
vllm/config/pooler.py
+4
-0
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+0
-845
vllm/model_executor/layers/pooler/__init__.py
vllm/model_executor/layers/pooler/__init__.py
+5
-0
vllm/model_executor/layers/pooler/abstract.py
vllm/model_executor/layers/pooler/abstract.py
+39
-0
vllm/model_executor/layers/pooler/activations.py
vllm/model_executor/layers/pooler/activations.py
+162
-0
vllm/model_executor/layers/pooler/common.py
vllm/model_executor/layers/pooler/common.py
+27
-0
vllm/model_executor/layers/pooler/seqwise/__init__.py
vllm/model_executor/layers/pooler/seqwise/__init__.py
+45
-0
vllm/model_executor/layers/pooler/seqwise/heads.py
vllm/model_executor/layers/pooler/seqwise/heads.py
+157
-0
vllm/model_executor/layers/pooler/seqwise/methods.py
vllm/model_executor/layers/pooler/seqwise/methods.py
+93
-0
vllm/model_executor/layers/pooler/seqwise/poolers.py
vllm/model_executor/layers/pooler/seqwise/poolers.py
+106
-0
vllm/model_executor/layers/pooler/special.py
vllm/model_executor/layers/pooler/special.py
+128
-0
vllm/model_executor/layers/pooler/tokwise/__init__.py
vllm/model_executor/layers/pooler/tokwise/__init__.py
+39
-0
vllm/model_executor/layers/pooler/tokwise/heads.py
vllm/model_executor/layers/pooler/tokwise/heads.py
+142
-0
vllm/model_executor/layers/pooler/tokwise/methods.py
vllm/model_executor/layers/pooler/tokwise/methods.py
+124
-0
vllm/model_executor/layers/pooler/tokwise/poolers.py
vllm/model_executor/layers/pooler/tokwise/poolers.py
+106
-0
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+5
-23
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+28
-57
No files found.
.github/CODEOWNERS
View file @
c8ed39b9
...
@@ -153,7 +153,7 @@ mkdocs.yaml @hmellor
...
@@ -153,7 +153,7 @@ mkdocs.yaml @hmellor
/vllm/entrypoints/pooling @noooop
/vllm/entrypoints/pooling @noooop
/vllm/config/pooler.py @noooop
/vllm/config/pooler.py @noooop
/vllm/pooling_params.py @noooop
/vllm/pooling_params.py @noooop
/vllm/model_executor/layers/pooler
.py
@noooop
/vllm/model_executor/layers/pooler @noooop
# Security guide and policies
# Security guide and policies
/docs/usage/security.md @russellb
/docs/usage/security.md @russellb
...
...
tests/model_executor/test_model_load_with_params.py
View file @
c8ed39b9
...
@@ -5,7 +5,8 @@ import os
...
@@ -5,7 +5,8 @@ import os
import
pytest
import
pytest
from
vllm.model_executor.layers.pooler
import
CLSPool
,
DispatchPooler
,
MeanPool
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.layers.pooler.seqwise
import
CLSPool
,
MeanPool
from
vllm.model_executor.models.bert
import
BertEmbeddingModel
from
vllm.model_executor.models.bert
import
BertEmbeddingModel
from
vllm.model_executor.models.roberta
import
RobertaEmbeddingModel
from
vllm.model_executor.models.roberta
import
RobertaEmbeddingModel
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
...
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py
View file @
c8ed39b9
...
@@ -7,7 +7,7 @@ import torch
...
@@ -7,7 +7,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
import
DispatchPooler
from
vllm.model_executor.models.gemma2
import
Gemma2Model
from
vllm.model_executor.models.gemma2
import
Gemma2Model
from
vllm.model_executor.models.utils
import
WeightsMapper
,
maybe_prefix
from
vllm.model_executor.models.utils
import
WeightsMapper
,
maybe_prefix
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -28,12 +28,7 @@ class MyGemma2Embedding(nn.Module):
...
@@ -28,12 +28,7 @@ class MyGemma2Embedding(nn.Module):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
.
for_embedding
(
pooler_config
)
{
"token_embed"
:
Pooler
.
for_token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
}
)
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
self
.
model
.
make_empty_intermediate_tensors
...
...
vllm/config/pooler.py
View file @
c8ed39b9
...
@@ -88,6 +88,10 @@ class PoolerConfig:
...
@@ -88,6 +88,10 @@ class PoolerConfig:
# raise deprecated warning for softmax and activation
# raise deprecated warning for softmax and activation
self
.
use_activation
=
get_use_activation
(
self
)
self
.
use_activation
=
get_use_activation
(
self
)
def
get_pooling_type
(
self
)
->
PoolingTypeStr
:
assert
self
.
pooling_type
is
not
None
,
"Should be resolved by ModelConfig"
return
self
.
pooling_type
def
compute_hash
(
self
)
->
str
:
def
compute_hash
(
self
)
->
str
:
"""
"""
WARNING: Whenever a new field is added to this config,
WARNING: Whenever a new field is added to this config,
...
...
vllm/model_executor/layers/pooler.py
deleted
100644 → 0
View file @
02073280
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Callable
,
Mapping
,
Set
from
dataclasses
import
dataclass
from
itertools
import
groupby
from
typing
import
TypeAlias
,
TypeVar
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
vllm.config
import
ModelConfig
,
get_current_vllm_config
from
vllm.config.pooler
import
PoolerConfig
,
PoolingTypeStr
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.pooling_params
import
PoolingParams
from
vllm.tasks
import
PoolingTask
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
from
vllm.v1.outputs
import
PoolerOutput
,
TokenPoolerOutput
,
TokenwisePoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
logger
=
init_logger
(
__name__
)
PoolingFn
=
Callable
[
[
torch
.
Tensor
|
list
[
torch
.
Tensor
],
PoolingMetadata
],
torch
.
Tensor
|
list
[
torch
.
Tensor
],
]
ClassifierFn
=
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]
TokenPoolingMethodOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
TokenwisePoolingMethodOutput
:
TypeAlias
=
list
[
torch
.
Tensor
]
|
list
[
torch
.
Tensor
|
None
]
TokenwisePoolingMethodOutputItem
:
TypeAlias
=
torch
.
Tensor
|
None
PoolingMethodOutput
:
TypeAlias
=
TokenPoolingMethodOutput
|
TokenwisePoolingMethodOutput
TokenPoolerHeadOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
TokenwisePoolerHeadOutput
:
TypeAlias
=
torch
.
Tensor
|
None
@
dataclass
(
frozen
=
True
)
class
ResolvedPoolingConfig
:
pooling_type
:
PoolingTypeStr
task
:
PoolingTask
@
classmethod
def
from_config
(
cls
,
task
:
PoolingTask
,
pooler_config
:
PoolerConfig
,
)
->
"ResolvedPoolingConfig"
:
assert
pooler_config
.
pooling_type
is
not
None
return
cls
(
task
=
task
,
pooling_type
=
pooler_config
.
pooling_type
)
@
dataclass
(
frozen
=
True
)
class
PoolingParamsUpdate
:
requires_token_ids
:
bool
=
False
"""Set this flag to enable `get_prompt_token_ids` for your pooler."""
def
apply
(
self
,
params
:
PoolingParams
)
->
None
:
params
.
requires_token_ids
=
self
.
requires_token_ids
def
get_classification_activation_function
(
config
:
PretrainedConfig
):
# Implement alignment with transformers ForSequenceClassificationLoss
# https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
problem_type
=
getattr
(
config
,
"problem_type"
,
""
)
if
problem_type
==
"regression"
:
return
PoolerIdentity
()
if
problem_type
==
"single_label_classification"
:
return
PoolerClassify
()
if
problem_type
==
"multi_label_classification"
:
return
PoolerMultiLabelClassify
()
return
PoolerClassify
()
def
get_cross_encoder_activation_function
(
config
:
PretrainedConfig
):
function_name
:
str
|
None
=
None
if
(
hasattr
(
config
,
"sentence_transformers"
)
and
"activation_fn"
in
config
.
sentence_transformers
):
function_name
=
config
.
sentence_transformers
[
"activation_fn"
]
elif
(
hasattr
(
config
,
"sbert_ce_default_activation_function"
)
and
config
.
sbert_ce_default_activation_function
is
not
None
):
function_name
=
config
.
sbert_ce_default_activation_function
if
function_name
is
not
None
:
assert
function_name
.
startswith
(
"torch.nn.modules."
),
(
"Loading of activation functions is restricted to "
"torch.nn.modules for security reasons"
)
fn
=
resolve_obj_by_qualname
(
function_name
)()
return
PoolerActivation
.
wraps
(
fn
)
return
PoolerClassify
()
class
PoolingMethod
(
nn
.
Module
,
ABC
):
@
staticmethod
def
from_pooling_type
(
pooling_type
:
PoolingTypeStr
)
->
"PoolingMethod"
:
if
pooling_type
==
"LAST"
:
return
LastPool
()
if
pooling_type
==
"ALL"
:
return
AllPool
()
if
pooling_type
==
"CLS"
:
return
CLSPool
()
if
pooling_type
==
"MEAN"
:
return
MeanPool
()
if
pooling_type
==
"STEP"
:
raise
ValueError
(
"'STEP' pooling is handled by StepPooler "
"and is not a standalone PoolingMethod."
)
raise
NotImplementedError
(
f
"Unsupported method:
{
pooling_type
!
r
}
"
)
@
abstractmethod
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
raise
NotImplementedError
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
PoolingParamsUpdate
()
@
abstractmethod
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolingMethodOutput
:
raise
NotImplementedError
class
CLSPool
(
PoolingMethod
):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
,
"embed"
,
"classify"
,
"score"
}
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolingMethodOutput
:
pooling_cursor
=
pooling_metadata
.
get_pooling_cursor
()
assert
not
pooling_cursor
.
is_partial_prefill
(),
(
"partial prefill not supported with CLS pooling"
)
return
hidden_states
[
pooling_cursor
.
first_token_indices_gpu
]
class
LastPool
(
PoolingMethod
):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
,
"embed"
,
"classify"
,
"score"
}
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolingMethodOutput
:
pooling_cursor
=
pooling_metadata
.
get_pooling_cursor
()
return
hidden_states
[
pooling_cursor
.
last_token_indices_gpu
]
class
AllPool
(
PoolingMethod
):
def
__init__
(
self
):
super
().
__init__
()
vllm_config
=
get_current_vllm_config
()
self
.
enable_chunked_prefill
=
(
vllm_config
.
scheduler_config
.
enable_chunked_prefill
)
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
}
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenwisePoolingMethodOutput
:
pooling_cursor
=
pooling_metadata
.
get_pooling_cursor
()
hidden_states_all
=
hidden_states
.
split
(
pooling_cursor
.
num_scheduled_tokens_cpu
.
tolist
()
)
hidden_states_lst
=
[
hidden_states_all
[
i
]
for
i
in
pooling_cursor
.
index
]
if
not
self
.
enable_chunked_prefill
:
return
hidden_states_lst
pooling_states
=
pooling_metadata
.
pooling_states
# If chunked_prefill is enabled
# 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
for
p
,
hs_chunk
in
zip
(
pooling_states
,
hidden_states_lst
):
p
.
hidden_states_cache
.
append
(
hs_chunk
)
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
output_list
=
list
[
torch
.
Tensor
|
None
]()
for
p
,
finished
in
zip
(
pooling_states
,
pooling_cursor
.
is_finished
()):
if
finished
:
hidden_states_cache
=
p
.
hidden_states_cache
if
len
(
hidden_states_cache
)
==
1
:
output_list
.
append
(
hidden_states_cache
[
0
])
else
:
output_list
.
append
(
torch
.
concat
(
hidden_states_cache
,
dim
=
0
))
p
.
clean
()
else
:
output_list
.
append
(
None
)
return
output_list
class
MeanPool
(
PoolingMethod
):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
,
"embed"
,
"classify"
,
"score"
}
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolingMethodOutput
:
pooling_cursor
=
pooling_metadata
.
get_pooling_cursor
()
assert
not
pooling_cursor
.
is_partial_prefill
(),
(
"partial prefill not supported with MEAN pooling"
)
prompt_lens
=
pooling_cursor
.
prompt_lens_cpu
.
to
(
hidden_states
.
device
,
non_blocking
=
True
)
# Use float32 for torch.cumsum in MeanPool,
# otherwise precision will be lost significantly.
cumsum
=
torch
.
cumsum
(
hidden_states
,
dim
=
0
,
dtype
=
torch
.
float32
)
start_indices
=
pooling_cursor
.
first_token_indices_gpu
end_indices
=
pooling_cursor
.
last_token_indices_gpu
return
(
cumsum
[
end_indices
]
-
cumsum
[
start_indices
]
+
hidden_states
[
start_indices
]
)
/
prompt_lens
.
unsqueeze
(
1
)
_T
=
TypeVar
(
"_T"
,
torch
.
Tensor
,
list
[
torch
.
Tensor
])
class
BasePoolerActivation
(
nn
.
Module
,
ABC
):
@
abstractmethod
def
forward
(
self
,
pooled_data
:
_T
)
->
_T
:
# shape:
# classify (& score) -> (batch_size, num_classes)
# embed -> (batch_size, embedding_dim) or list(embedding_dim)
# (batch_size, dimensions) or list(dimensions) if using MRL
raise
NotImplementedError
class
PoolerActivation
(
BasePoolerActivation
):
@
staticmethod
def
wraps
(
module
:
nn
.
Module
):
if
isinstance
(
module
,
nn
.
Identity
):
return
PoolerIdentity
()
if
isinstance
(
module
,
(
nn
.
Sigmoid
,
nn
.
Softmax
)):
return
PoolerClassify
()
return
LambdaPoolerActivation
(
module
)
@
abstractmethod
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
forward
(
self
,
pooled_data
:
_T
)
->
_T
:
if
isinstance
(
pooled_data
,
list
):
return
[
self
.
forward_chunk
(
data
)
for
data
in
pooled_data
]
return
self
.
forward_chunk
(
pooled_data
)
class
PoolerIdentity
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
pooled_data
class
PoolerNormalize
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
F
.
normalize
(
pooled_data
,
p
=
2
,
dim
=-
1
)
class
PoolerMultiLabelClassify
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
F
.
sigmoid
(
pooled_data
)
class
PoolerClassify
(
PoolerActivation
):
def
__init__
(
self
,
*
,
static_num_labels
:
bool
=
True
)
->
None
:
super
().
__init__
()
if
static_num_labels
:
vllm_config
=
get_current_vllm_config
()
self
.
num_labels
=
getattr
(
vllm_config
.
model_config
.
hf_config
,
"num_labels"
,
0
)
if
self
.
num_labels
==
0
:
logger
.
warning
(
"num_labels should be > 0 for classification"
"models, falling back to softmax. "
"Please check if the configuration is correct."
)
else
:
self
.
num_labels
=
None
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_labels
=
(
self
.
num_labels
if
self
.
num_labels
is
not
None
else
pooled_data
.
shape
[
-
1
]
)
if
num_labels
<
2
:
return
F
.
sigmoid
(
pooled_data
)
return
F
.
softmax
(
pooled_data
,
dim
=-
1
)
class
LambdaPoolerActivation
(
PoolerActivation
):
def
__init__
(
self
,
fn
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]):
super
().
__init__
()
self
.
fn
=
fn
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
fn
(
pooled_data
)
class
Pooler
(
nn
.
Module
,
ABC
):
"""The interface required for all poolers used in pooling models in vLLM."""
@
staticmethod
def
for_token_embed
(
pooler_config
:
PoolerConfig
):
head
=
TokenEmbeddingPoolerHead
()
if
pooler_config
.
pooling_type
==
"STEP"
:
return
StepPooler
(
head
=
head
)
return
AllPooler
(
head
=
head
)
@
staticmethod
def
for_token_classify
(
pooler_config
:
PoolerConfig
,
classifier
:
ClassifierFn
|
None
=
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
):
head
=
TokenClassifierPoolerHead
(
classifier
=
classifier
,
act_fn
=
act_fn
)
if
pooler_config
.
pooling_type
==
"STEP"
:
return
StepPooler
(
head
=
head
)
return
AllPooler
(
head
=
head
)
@
staticmethod
def
for_embed
(
pooler_config
:
PoolerConfig
):
resolved_config
=
ResolvedPoolingConfig
.
from_config
(
task
=
"embed"
,
pooler_config
=
pooler_config
,
)
pooling
=
PoolingMethod
.
from_pooling_type
(
resolved_config
.
pooling_type
)
head
=
EmbeddingPoolerHead
()
return
SimplePooler
(
pooling
=
pooling
,
head
=
head
)
@
staticmethod
def
for_classify
(
pooler_config
:
PoolerConfig
,
classifier
:
ClassifierFn
|
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
):
resolved_config
=
ResolvedPoolingConfig
.
from_config
(
task
=
"classify"
,
pooler_config
=
pooler_config
,
)
pooling
=
PoolingMethod
.
from_pooling_type
(
resolved_config
.
pooling_type
)
return
ClassifierPooler
(
pooling
=
pooling
,
classifier
=
classifier
,
act_fn
=
act_fn
,
)
@
abstractmethod
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
"""Determine which pooling tasks are supported."""
raise
NotImplementedError
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
"""
Construct the updated pooling parameters to use for a supported task.
"""
return
PoolingParamsUpdate
()
@
abstractmethod
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
raise
NotImplementedError
class
DummyPooler
(
Pooler
):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"plugin"
,
"score"
}
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
return
hidden_states
class
TokenPoolerHead
(
nn
.
Module
,
ABC
):
"""Applicable to pooling strategies that output one token."""
@
abstractmethod
def
forward
(
self
,
pooled_data
:
TokenPoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerHeadOutput
:
raise
NotImplementedError
class
EmbeddingPoolerHead
(
TokenPoolerHead
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
# Load ST projector if available
vllm_config
=
get_current_vllm_config
()
self
.
projector
=
(
_load_st_projector
(
vllm_config
.
model_config
)
if
vllm_config
else
None
)
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
self
.
activation
=
PoolerNormalize
()
def
forward
(
self
,
pooled_data
:
TokenPoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerHeadOutput
:
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_dimension]
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# Apply ST projector
if
self
.
projector
is
not
None
:
pooled_data
=
self
.
projector
(
pooled_data
)
# pooled_data shape: [batchsize, embedding_dimension]
pooling_params
=
pooling_metadata
.
pooling_params
# for matryoshka representation
dimensions_list
=
[
pooling_param
.
dimensions
for
pooling_param
in
pooling_params
]
if
any
(
d
is
not
None
for
d
in
dimensions_list
):
# change the output dimension
assert
len
(
pooled_data
)
==
len
(
dimensions_list
)
if
len
(
set
(
dimensions_list
))
==
1
and
not
isinstance
(
pooled_data
,
list
):
# if all dimensions are the same
d
=
dimensions_list
[
0
]
pooled_data
=
pooled_data
[...,
:
d
]
else
:
pooled_data
=
[
vecs
if
d
is
None
else
vecs
[...,
:
d
]
for
vecs
,
d
in
zip
(
pooled_data
,
dimensions_list
)
]
# for normalize
flags
=
[
p
.
normalize
for
p
in
pooling_params
]
if
len
(
set
(
flags
))
==
1
:
if
flags
[
0
]:
pooled_data
=
self
.
activation
(
pooled_data
)
else
:
pooled_data
=
[
self
.
activation
(
vecs
)
if
f
else
vecs
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
]
# pooled_data shape: [batchsize, embedding_dimension]
return
pooled_data
class
SimplePooler
(
Pooler
):
"""A layer that pools specific information from hidden states.
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.
"""
def
__init__
(
self
,
pooling
:
PoolingMethod
,
head
:
TokenPoolerHead
)
->
None
:
super
().
__init__
()
self
.
pooling
=
pooling
self
.
head
=
head
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
self
.
pooling
.
get_supported_tasks
()
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerHeadOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
class
ClassifierPooler
(
Pooler
):
"""A pooling layer for classification tasks.
This layer does the following:
1. Applies a classification layer to the hidden states.
2. Optionally applies a pooler layer.
3. Applies an activation function to the output.
"""
@
staticmethod
def
act_fn_for_seq_cls
(
model_config
:
ModelConfig
):
return
get_classification_activation_function
(
model_config
.
hf_config
)
@
staticmethod
def
act_fn_for_cross_encoder
(
model_config
:
ModelConfig
):
return
get_cross_encoder_activation_function
(
model_config
.
hf_config
)
@
staticmethod
def
resolve_act_fn
(
model_config
:
ModelConfig
,
static_num_labels
:
bool
=
True
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
):
if
isinstance
(
act_fn
,
str
):
if
act_fn
==
"classify"
:
return
ClassifierPooler
.
act_fn_for_seq_cls
(
model_config
)
elif
act_fn
==
"score"
:
return
ClassifierPooler
.
act_fn_for_cross_encoder
(
model_config
)
else
:
raise
ValueError
(
f
"act_fn [
{
act_fn
=
}
] not supported."
)
elif
act_fn
is
None
:
return
PoolerClassify
(
static_num_labels
=
static_num_labels
)
else
:
assert
callable
(
act_fn
)
return
act_fn
def
__init__
(
self
,
pooling
:
PoolingFn
,
classifier
:
ClassifierFn
|
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
)
->
None
:
super
().
__init__
()
vllm_config
=
get_current_vllm_config
()
self
.
pooling
=
pooling
self
.
classifier
=
classifier
self
.
act_fn
=
self
.
resolve_act_fn
(
vllm_config
.
model_config
,
static_num_labels
=
True
,
act_fn
=
act_fn
)
self
.
logit_bias
:
float
|
None
=
(
vllm_config
.
model_config
.
pooler_config
.
logit_bias
)
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"classify"
,
"score"
}
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_size]
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
if
self
.
classifier
is
not
None
:
pooled_data
=
self
.
classifier
(
pooled_data
)
# pooled_data shape: [batchsize, num_labels]
if
self
.
logit_bias
is
not
None
:
pooled_data
-=
self
.
logit_bias
pooling_params
=
pooling_metadata
.
pooling_params
flags
=
[
p
.
use_activation
for
p
in
pooling_params
]
if
len
(
set
(
flags
))
==
1
:
scores
=
self
.
act_fn
(
pooled_data
)
if
flags
[
0
]
else
pooled_data
else
:
scores
=
[
self
.
act_fn
(
vecs
)
if
f
else
vecs
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
]
# scores shape: [batchsize, num_labels]
return
scores
class
TokenwisePoolerHead
(
nn
.
Module
,
ABC
):
"""Applicable to pooling strategies that output multiple tokens."""
@
abstractmethod
def
forward
(
self
,
pooled_data
:
TokenwisePoolingMethodOutputItem
,
pooling_param
:
PoolingParams
,
)
->
TokenwisePoolerHeadOutput
:
raise
NotImplementedError
class
TokenEmbeddingPoolerHead
(
TokenwisePoolerHead
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
# Load ST projector if available
vllm_config
=
get_current_vllm_config
()
self
.
projector
=
(
_load_st_projector
(
vllm_config
.
model_config
)
if
vllm_config
else
None
)
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
self
.
activation
=
PoolerNormalize
()
def
forward
(
self
,
pooled_data
:
TokenwisePoolingMethodOutputItem
,
pooling_param
:
PoolingParams
,
)
->
TokenwisePoolerHeadOutput
:
# for unfinished chunked prefill
if
pooled_data
is
None
:
return
None
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# pooled_data shape: [n_tokens, hidden_dimension]
# Apply ST projector
if
self
.
projector
is
not
None
:
pooled_data
=
self
.
projector
(
pooled_data
)
# pooled_data shape: [n_tokens, embedding_dimension]
# for matryoshka representation
pooled_data
=
pooled_data
[...,
:
pooling_param
.
dimensions
]
# for normalize
if
pooling_param
.
normalize
:
pooled_data
=
self
.
activation
(
pooled_data
)
# pooled_data shape: [n_tokens, embedding_dimension]
return
pooled_data
class
TokenClassifierPoolerHead
(
TokenwisePoolerHead
):
def
__init__
(
self
,
classifier
:
ClassifierFn
|
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
)
->
None
:
super
().
__init__
()
vllm_config
=
get_current_vllm_config
()
self
.
classifier
=
classifier
self
.
logit_bias
:
float
|
None
=
(
vllm_config
.
model_config
.
pooler_config
.
logit_bias
)
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
self
.
activation
=
ClassifierPooler
.
resolve_act_fn
(
vllm_config
.
model_config
,
static_num_labels
=
False
,
act_fn
=
act_fn
)
def
forward
(
self
,
pooled_data
:
TokenwisePoolingMethodOutputItem
,
pooling_param
:
PoolingParams
,
)
->
TokenwisePoolerHeadOutput
:
# for unfinished chunked prefill
if
pooled_data
is
None
:
return
None
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# hidden_states shape: [n_token, hidden_size]
if
self
.
classifier
is
not
None
:
scores
=
self
.
classifier
(
pooled_data
)
else
:
scores
=
pooled_data
# scores shape: [n_token, num_labels]
if
self
.
logit_bias
is
not
None
:
scores
-=
self
.
logit_bias
if
pooling_param
.
use_activation
:
scores
=
self
.
activation
(
scores
)
# scores shape: [n_token, num_labels]
return
scores
class
AllPooler
(
Pooler
):
def
__init__
(
self
,
head
:
TokenwisePoolerHead
)
->
None
:
super
().
__init__
()
self
.
pooling
=
AllPool
()
self
.
head
=
head
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
}
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenwisePoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooling_params
=
pooling_metadata
.
pooling_params
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
return
[
self
.
head
(
d
,
p
)
for
d
,
p
in
zip
(
pooled_data
,
pooling_params
)]
class
StepPooler
(
Pooler
):
def
__init__
(
self
,
head
:
TokenwisePoolerHead
)
->
None
:
super
().
__init__
()
self
.
pooling
=
AllPool
()
self
.
head
=
head
def
extract_states
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
list
[
torch
.
Tensor
|
None
]:
pooled_data_lst
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
prompt_token_ids
=
pooling_metadata
.
get_prompt_token_ids
()
pooling_params
=
pooling_metadata
.
pooling_params
pooled_data
=
list
[
torch
.
Tensor
|
None
]()
for
data
,
token_id
,
pooling_param
in
zip
(
pooled_data_lst
,
prompt_token_ids
,
pooling_params
):
# for unfinished chunked prefill
if
data
is
None
:
pooled_data
.
append
(
data
)
continue
step_tag_id
=
pooling_param
.
step_tag_id
returned_token_ids
=
pooling_param
.
returned_token_ids
if
returned_token_ids
is
not
None
and
len
(
returned_token_ids
)
>
0
:
data
=
data
[:,
returned_token_ids
]
if
step_tag_id
is
not
None
:
data
=
data
[
token_id
==
step_tag_id
]
pooled_data
.
append
(
data
)
return
pooled_data
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
}
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
PoolingParamsUpdate
(
requires_token_ids
=
True
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenwisePoolerOutput
:
pooled_data
=
self
.
extract_states
(
hidden_states
,
pooling_metadata
)
pooling_params
=
pooling_metadata
.
pooling_params
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
return
[
self
.
head
(
d
,
p
)
for
d
,
p
in
zip
(
pooled_data
,
pooling_params
)]
class
DispatchPooler
(
Pooler
):
"""Dispatches calls to a sub-pooler based on the pooling task."""
def
__init__
(
self
,
poolers_by_task
:
Mapping
[
PoolingTask
,
Pooler
])
->
None
:
super
().
__init__
()
for
task
,
pooler
in
poolers_by_task
.
items
():
if
task
not
in
pooler
.
get_supported_tasks
():
raise
ValueError
(
f
"
{
pooler
=
}
does not support
{
task
=
}
. "
f
"Supported tasks:
{
pooler
.
get_supported_tasks
()
}
"
)
self
.
poolers_by_task
=
poolers_by_task
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
set
(
self
.
poolers_by_task
)
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
self
.
poolers_by_task
[
task
].
get_pooling_updates
(
task
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
poolers_by_task
=
self
.
poolers_by_task
outputs
=
list
[
torch
.
Tensor
|
None
]()
offset
=
0
for
task
,
group
in
groupby
(
pooling_metadata
.
tasks
):
if
not
(
pooler
:
=
poolers_by_task
.
get
(
task
)):
raise
ValueError
(
f
"Unsupported task:
{
task
}
"
f
"Supported tasks:
{
self
.
get_supported_tasks
()
}
"
)
num_items
=
len
(
list
(
group
))
group_output
:
PoolerOutput
=
pooler
(
hidden_states
,
pooling_metadata
[
offset
:
offset
+
num_items
],
)
outputs
.
extend
(
group_output
)
offset
+=
num_items
return
outputs
def
extra_repr
(
self
)
->
str
:
s
=
f
"supported_task=
{
self
.
get_supported_tasks
()
}
"
return
s
vllm/model_executor/layers/pooler/__init__.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
.abstract
import
*
from
.common
import
*
from
.special
import
*
vllm/model_executor/layers/pooler/abstract.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Set
import
torch
import
torch.nn
as
nn
from
vllm.tasks
import
PoolingTask
from
vllm.v1.outputs
import
PoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.common
import
PoolingParamsUpdate
class
Pooler
(
nn
.
Module
,
ABC
):
"""The interface required for all poolers used in pooling models in vLLM."""
@
abstractmethod
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
"""Determine which pooling tasks are supported."""
raise
NotImplementedError
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
"""
Construct the updated pooling parameters to use for a supported task.
"""
return
PoolingParamsUpdate
()
@
abstractmethod
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
raise
NotImplementedError
__all__
=
[
"Pooler"
]
vllm/model_executor/layers/pooler/activations.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Callable
from
typing
import
TypeVar
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
vllm.config
import
ModelConfig
,
get_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.utils.import_utils
import
resolve_obj_by_qualname
logger
=
init_logger
(
__name__
)
def
get_classification_act_fn
(
config
:
PretrainedConfig
,
)
->
"PoolerActivation"
:
# Implement alignment with transformers ForSequenceClassificationLoss
# https://github.com/huggingface/transformers/blob/57bb6db6ee4cfaccc45b8d474dfad5a17811ca60/src/transformers/loss/loss_utils.py#L92
problem_type
=
getattr
(
config
,
"problem_type"
,
""
)
if
problem_type
==
"regression"
:
return
PoolerIdentity
()
if
problem_type
==
"single_label_classification"
:
return
PoolerClassify
()
if
problem_type
==
"multi_label_classification"
:
return
PoolerMultiLabelClassify
()
return
PoolerClassify
()
def
get_cross_encoder_act_fn
(
config
:
PretrainedConfig
,
)
->
"PoolerActivation"
:
function_name
:
str
|
None
=
None
if
(
hasattr
(
config
,
"sentence_transformers"
)
and
"activation_fn"
in
config
.
sentence_transformers
):
function_name
=
config
.
sentence_transformers
[
"activation_fn"
]
elif
(
hasattr
(
config
,
"sbert_ce_default_activation_function"
)
and
config
.
sbert_ce_default_activation_function
is
not
None
):
function_name
=
config
.
sbert_ce_default_activation_function
if
function_name
is
not
None
:
assert
function_name
.
startswith
(
"torch.nn.modules."
),
(
"Loading of activation functions is restricted to "
"torch.nn.modules for security reasons"
)
fn
=
resolve_obj_by_qualname
(
function_name
)()
return
PoolerActivation
.
wraps
(
fn
)
return
PoolerClassify
()
def
resolve_classifier_act_fn
(
model_config
:
ModelConfig
,
static_num_labels
:
bool
=
True
,
act_fn
:
"PoolerActivation | str | None"
=
None
,
):
if
isinstance
(
act_fn
,
str
):
if
act_fn
==
"classify"
:
return
get_classification_act_fn
(
model_config
.
hf_config
)
if
act_fn
==
"score"
:
return
get_cross_encoder_act_fn
(
model_config
.
hf_config
)
raise
ValueError
(
f
"act_fn [
{
act_fn
=
}
] not supported."
)
if
act_fn
is
None
:
return
PoolerClassify
(
static_num_labels
=
static_num_labels
)
assert
callable
(
act_fn
)
return
act_fn
_T
=
TypeVar
(
"_T"
,
torch
.
Tensor
,
list
[
torch
.
Tensor
])
class
PoolerActivation
(
nn
.
Module
,
ABC
):
@
staticmethod
def
wraps
(
module
:
nn
.
Module
):
if
isinstance
(
module
,
nn
.
Identity
):
return
PoolerIdentity
()
if
isinstance
(
module
,
(
nn
.
Sigmoid
,
nn
.
Softmax
)):
return
PoolerClassify
()
return
LambdaPoolerActivation
(
module
)
@
abstractmethod
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
forward
(
self
,
pooled_data
:
_T
)
->
_T
:
# shape:
# classify (& score) -> (batch_size, num_classes)
# embed -> (batch_size, embedding_dim) or list(embedding_dim)
# (batch_size, dimensions) or list(dimensions) if using MRL
if
isinstance
(
pooled_data
,
list
):
return
[
self
.
forward_chunk
(
data
)
for
data
in
pooled_data
]
return
self
.
forward_chunk
(
pooled_data
)
class
PoolerIdentity
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
pooled_data
class
PoolerNormalize
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
F
.
normalize
(
pooled_data
,
p
=
2
,
dim
=-
1
)
class
PoolerMultiLabelClassify
(
PoolerActivation
):
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
F
.
sigmoid
(
pooled_data
)
class
PoolerClassify
(
PoolerActivation
):
def
__init__
(
self
,
*
,
static_num_labels
:
bool
=
True
)
->
None
:
super
().
__init__
()
if
static_num_labels
:
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
num_labels
=
getattr
(
model_config
.
hf_config
,
"num_labels"
,
0
)
else
:
num_labels
=
None
if
num_labels
==
0
:
logger
.
warning
(
"num_labels should be > 0 for classification "
"models, falling back to softmax. "
"Please check if the configuration is correct."
)
self
.
num_labels
=
num_labels
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_labels
=
self
.
num_labels
if
num_labels
is
None
:
num_labels
=
pooled_data
.
shape
[
-
1
]
if
num_labels
<
2
:
return
F
.
sigmoid
(
pooled_data
)
return
F
.
softmax
(
pooled_data
,
dim
=-
1
)
class
LambdaPoolerActivation
(
PoolerActivation
):
def
__init__
(
self
,
fn
:
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]):
super
().
__init__
()
self
.
fn
=
fn
def
forward_chunk
(
self
,
pooled_data
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
fn
(
pooled_data
)
vllm/model_executor/layers/pooler/common.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
from
dataclasses
import
dataclass
import
torch
from
vllm.pooling_params
import
PoolingParams
ClassifierFn
=
Callable
[[
torch
.
Tensor
],
torch
.
Tensor
]
@
dataclass
(
frozen
=
True
)
class
PoolingParamsUpdate
:
requires_token_ids
:
bool
=
False
"""Set this flag to enable `get_prompt_token_ids` for your pooler."""
def
__or__
(
self
,
other
:
"PoolingParamsUpdate"
)
->
"PoolingParamsUpdate"
:
return
PoolingParamsUpdate
(
requires_token_ids
=
self
.
requires_token_ids
or
other
.
requires_token_ids
,
)
def
apply
(
self
,
params
:
PoolingParams
)
->
None
:
params
.
requires_token_ids
=
self
.
requires_token_ids
__all__
=
[
"ClassifierFn"
,
"PoolingParamsUpdate"
]
vllm/model_executor/layers/pooler/seqwise/__init__.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Poolers that produce an output aggregating all tokens in the sequence."""
from
.heads
import
(
ClassifierPoolerHead
,
EmbeddingPoolerHead
,
SequencePoolerHead
,
SequencePoolerHeadOutput
,
)
from
.methods
import
(
CLSPool
,
LastPool
,
MeanPool
,
SequencePoolingMethod
,
SequencePoolingMethodOutput
,
get_seq_pooling_method
,
)
from
.poolers
import
(
SequencePooler
,
SequencePoolerOutput
,
SequencePoolingFn
,
SequencePoolingHeadFn
,
pooler_for_classify
,
pooler_for_embed
,
)
__all__
=
[
"SequencePoolerHead"
,
"SequencePoolerHeadOutput"
,
"ClassifierPoolerHead"
,
"EmbeddingPoolerHead"
,
"SequencePoolingMethod"
,
"SequencePoolingMethodOutput"
,
"CLSPool"
,
"LastPool"
,
"MeanPool"
,
"get_seq_pooling_method"
,
"SequencePooler"
,
"SequencePoolingFn"
,
"SequencePoolingHeadFn"
,
"SequencePoolerOutput"
,
"pooler_for_classify"
,
"pooler_for_embed"
,
]
vllm/model_executor/layers/pooler/seqwise/heads.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Set
from
typing
import
TypeAlias
import
torch
import
torch.nn
as
nn
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.layers.pooler
import
ClassifierFn
from
vllm.model_executor.layers.pooler.activations
import
(
PoolerActivation
,
PoolerNormalize
,
resolve_classifier_act_fn
,
)
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.methods
import
SequencePoolingMethodOutput
SequencePoolerHeadOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
class
SequencePoolerHead
(
nn
.
Module
,
ABC
):
@
abstractmethod
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
raise
NotImplementedError
@
abstractmethod
def
forward
(
self
,
pooled_data
:
SequencePoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
SequencePoolerHeadOutput
:
raise
NotImplementedError
class
EmbeddingPoolerHead
(
SequencePoolerHead
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
# Load ST projector if available
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
self
.
projector
=
_load_st_projector
(
model_config
)
self
.
head_dtype
=
model_config
.
head_dtype
self
.
activation
=
PoolerNormalize
()
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"embed"
}
def
forward
(
self
,
pooled_data
:
SequencePoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
SequencePoolerHeadOutput
:
pooling_params
=
pooling_metadata
.
pooling_params
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_dimension]
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# Apply ST projector
if
self
.
projector
is
not
None
:
pooled_data
=
self
.
projector
(
pooled_data
)
# pooled_data shape: [batchsize, embedding_dimension]
# for matryoshka representation
dimensions_list
=
[
pooling_param
.
dimensions
for
pooling_param
in
pooling_params
]
if
any
(
d
is
not
None
for
d
in
dimensions_list
):
# change the output dimension
assert
len
(
pooled_data
)
==
len
(
dimensions_list
)
if
len
(
set
(
dimensions_list
))
==
1
and
not
isinstance
(
pooled_data
,
list
):
# if all dimensions are the same
d
=
dimensions_list
[
0
]
pooled_data
=
pooled_data
[...,
:
d
]
else
:
pooled_data
=
[
vecs
if
d
is
None
else
vecs
[...,
:
d
]
for
vecs
,
d
in
zip
(
pooled_data
,
dimensions_list
)
]
# for normalize
flags
=
[
p
.
normalize
for
p
in
pooling_params
]
if
len
(
set
(
flags
))
==
1
:
if
flags
[
0
]:
pooled_data
=
self
.
activation
(
pooled_data
)
else
:
pooled_data
=
[
self
.
activation
(
vecs
)
if
f
else
vecs
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
]
# pooled_data shape: [batchsize, embedding_dimension]
return
pooled_data
class
ClassifierPoolerHead
(
SequencePoolerHead
):
def
__init__
(
self
,
classifier
:
ClassifierFn
|
None
=
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
)
->
None
:
super
().
__init__
()
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
self
.
classifier
=
classifier
self
.
logit_bias
:
float
|
None
=
model_config
.
pooler_config
.
logit_bias
self
.
head_dtype
=
model_config
.
head_dtype
self
.
act_fn
=
resolve_classifier_act_fn
(
model_config
,
static_num_labels
=
True
,
act_fn
=
act_fn
)
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"classify"
,
"score"
}
def
forward
(
self
,
pooled_data
:
SequencePoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
)
->
SequencePoolerHeadOutput
:
pooling_params
=
pooling_metadata
.
pooling_params
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
# pooled_data shape: [batchsize, hidden_size]
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
if
self
.
classifier
is
not
None
:
pooled_data
=
self
.
classifier
(
pooled_data
)
# pooled_data shape: [batchsize, num_labels]
if
self
.
logit_bias
is
not
None
:
pooled_data
-=
self
.
logit_bias
flags
=
[
p
.
use_activation
for
p
in
pooling_params
]
if
len
(
set
(
flags
))
==
1
:
scores
=
self
.
act_fn
(
pooled_data
)
if
flags
[
0
]
else
pooled_data
else
:
scores
=
[
self
.
act_fn
(
vecs
)
if
f
else
vecs
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
]
# scores shape: [batchsize, num_labels]
return
scores
vllm/model_executor/layers/pooler/seqwise/methods.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Set
from
typing
import
TypeAlias
import
torch
import
torch.nn
as
nn
from
vllm.config.pooler
import
PoolingTypeStr
from
vllm.model_executor.layers.pooler
import
PoolingParamsUpdate
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
SequencePoolingMethodOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
class
SequencePoolingMethod
(
nn
.
Module
,
ABC
):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
,
"embed"
,
"classify"
,
"score"
}
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
PoolingParamsUpdate
()
@
abstractmethod
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
SequencePoolingMethodOutput
:
raise
NotImplementedError
class
CLSPool
(
SequencePoolingMethod
):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
SequencePoolingMethodOutput
:
pooling_cursor
=
pooling_metadata
.
get_pooling_cursor
()
assert
not
pooling_cursor
.
is_partial_prefill
(),
(
"partial prefill not supported with CLS pooling"
)
return
hidden_states
[
pooling_cursor
.
first_token_indices_gpu
]
class
LastPool
(
SequencePoolingMethod
):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
SequencePoolingMethodOutput
:
pooling_cursor
=
pooling_metadata
.
get_pooling_cursor
()
return
hidden_states
[
pooling_cursor
.
last_token_indices_gpu
]
class
MeanPool
(
SequencePoolingMethod
):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
SequencePoolingMethodOutput
:
pooling_cursor
=
pooling_metadata
.
get_pooling_cursor
()
assert
not
pooling_cursor
.
is_partial_prefill
(),
(
"partial prefill not supported with MEAN pooling"
)
prompt_lens
=
pooling_cursor
.
prompt_lens_cpu
.
to
(
hidden_states
.
device
,
non_blocking
=
True
)
# Use float32 for torch.cumsum in MeanPool,
# otherwise precision will be lost significantly.
cumsum
=
torch
.
cumsum
(
hidden_states
,
dim
=
0
,
dtype
=
torch
.
float32
)
start_indices
=
pooling_cursor
.
first_token_indices_gpu
end_indices
=
pooling_cursor
.
last_token_indices_gpu
return
(
cumsum
[
end_indices
]
-
cumsum
[
start_indices
]
+
hidden_states
[
start_indices
]
)
/
prompt_lens
.
unsqueeze
(
1
)
def
get_seq_pooling_method
(
pooling_type
:
PoolingTypeStr
|
str
):
if
pooling_type
==
"LAST"
:
return
LastPool
()
if
pooling_type
==
"CLS"
:
return
CLSPool
()
if
pooling_type
==
"MEAN"
:
return
MeanPool
()
raise
NotImplementedError
(
f
"Unknown sequence pooling type:
{
pooling_type
!
r
}
"
)
vllm/model_executor/layers/pooler/seqwise/poolers.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
,
Set
from
typing
import
TypeAlias
import
torch
from
vllm.config
import
PoolerConfig
from
vllm.model_executor.layers.pooler
import
ClassifierFn
,
PoolingParamsUpdate
from
vllm.model_executor.layers.pooler.abstract
import
Pooler
from
vllm.model_executor.layers.pooler.activations
import
PoolerActivation
from
vllm.tasks
import
POOLING_TASKS
,
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.heads
import
(
ClassifierPoolerHead
,
EmbeddingPoolerHead
,
SequencePoolerHead
,
SequencePoolerHeadOutput
,
)
from
.methods
import
(
SequencePoolingMethod
,
SequencePoolingMethodOutput
,
get_seq_pooling_method
,
)
SequencePoolingFn
:
TypeAlias
=
Callable
[
[
torch
.
Tensor
,
PoolingMetadata
],
SequencePoolingMethodOutput
,
]
SequencePoolingHeadFn
:
TypeAlias
=
Callable
[
[
SequencePoolingMethodOutput
,
PoolingMetadata
],
SequencePoolerHeadOutput
,
]
SequencePoolerOutput
:
TypeAlias
=
torch
.
Tensor
|
list
[
torch
.
Tensor
]
class
SequencePooler
(
Pooler
):
"""
A layer that pools specific information from hidden states.
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Postprocesses the output based on pooling head.
3. Returns structured results as `PoolerOutput`.
"""
def
__init__
(
self
,
pooling
:
SequencePoolingMethod
|
SequencePoolingFn
,
head
:
SequencePoolerHead
|
SequencePoolingHeadFn
,
)
->
None
:
super
().
__init__
()
self
.
pooling
=
pooling
self
.
head
=
head
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
tasks
=
set
(
POOLING_TASKS
)
if
isinstance
(
self
.
pooling
,
SequencePoolingMethod
):
tasks
&=
self
.
pooling
.
get_supported_tasks
()
if
isinstance
(
self
.
head
,
SequencePoolerHead
):
tasks
&=
self
.
head
.
get_supported_tasks
()
return
tasks
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
updates
=
PoolingParamsUpdate
()
if
isinstance
(
self
.
pooling
,
SequencePoolingMethod
):
updates
|=
self
.
pooling
.
get_pooling_updates
(
task
)
return
updates
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
SequencePoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
def
pooler_for_embed
(
pooler_config
:
PoolerConfig
):
pooling
=
get_seq_pooling_method
(
pooler_config
.
get_pooling_type
())
head
=
EmbeddingPoolerHead
()
return
SequencePooler
(
pooling
=
pooling
,
head
=
head
)
def
pooler_for_classify
(
pooler_config
:
PoolerConfig
,
*
,
pooling
:
SequencePoolingMethod
|
SequencePoolingFn
|
None
=
None
,
classifier
:
ClassifierFn
|
None
=
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
):
if
pooling
is
None
:
pooling
=
get_seq_pooling_method
(
pooler_config
.
get_pooling_type
())
head
=
ClassifierPoolerHead
(
classifier
=
classifier
,
act_fn
=
act_fn
)
return
SequencePooler
(
pooling
=
pooling
,
head
=
head
)
vllm/model_executor/layers/pooler/special.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Mapping
,
Set
from
itertools
import
groupby
import
torch
from
vllm.config
import
PoolerConfig
from
vllm.model_executor.layers.pooler
import
PoolingParamsUpdate
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.abstract
import
Pooler
,
PoolerOutput
from
.common
import
ClassifierFn
from
.seqwise
import
(
SequencePoolingFn
,
SequencePoolingMethod
,
pooler_for_classify
,
pooler_for_embed
,
)
from
.tokwise
import
AllPool
,
pooler_for_token_classify
,
pooler_for_token_embed
class
DispatchPooler
(
Pooler
):
"""Dispatches calls to a sub-pooler based on the pooling task."""
@
classmethod
def
for_embedding
(
cls
,
pooler_config
:
PoolerConfig
):
return
cls
(
{
"token_embed"
:
pooler_for_token_embed
(
pooler_config
),
"embed"
:
pooler_for_embed
(
pooler_config
),
},
)
@
classmethod
def
for_seq_cls
(
cls
,
pooler_config
:
PoolerConfig
,
*
,
pooling
:
SequencePoolingMethod
|
SequencePoolingFn
|
None
=
None
,
classifier
:
ClassifierFn
|
None
=
None
,
):
return
cls
(
{
"token_classify"
:
pooler_for_token_classify
(
pooler_config
,
pooling
=
AllPool
(),
classifier
=
classifier
,
),
"classify"
:
pooler_for_classify
(
pooler_config
,
pooling
=
pooling
,
classifier
=
classifier
,
act_fn
=
"classify"
,
),
"score"
:
pooler_for_classify
(
pooler_config
,
pooling
=
pooling
,
classifier
=
classifier
,
act_fn
=
"score"
,
),
}
)
def
__init__
(
self
,
poolers_by_task
:
Mapping
[
PoolingTask
,
Pooler
])
->
None
:
super
().
__init__
()
for
task
,
pooler
in
poolers_by_task
.
items
():
if
task
not
in
pooler
.
get_supported_tasks
():
raise
ValueError
(
f
"
{
pooler
=
}
does not support
{
task
=
}
. "
f
"Supported tasks:
{
pooler
.
get_supported_tasks
()
}
"
)
self
.
poolers_by_task
=
poolers_by_task
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
set
(
self
.
poolers_by_task
)
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
self
.
poolers_by_task
[
task
].
get_pooling_updates
(
task
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
poolers_by_task
=
self
.
poolers_by_task
outputs
=
list
[
torch
.
Tensor
|
None
]()
offset
=
0
for
task
,
group
in
groupby
(
pooling_metadata
.
tasks
):
if
not
(
pooler
:
=
poolers_by_task
.
get
(
task
)):
raise
ValueError
(
f
"Unsupported task:
{
task
!
r
}
"
f
"Supported tasks:
{
self
.
get_supported_tasks
()
}
"
)
num_items
=
len
(
list
(
group
))
group_output
:
PoolerOutput
=
pooler
(
hidden_states
,
pooling_metadata
[
offset
:
offset
+
num_items
],
)
outputs
.
extend
(
group_output
)
offset
+=
num_items
return
outputs
def
extra_repr
(
self
)
->
str
:
s
=
f
"supported_task=
{
self
.
get_supported_tasks
()
}
"
return
s
class
IdentityPooler
(
Pooler
):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"plugin"
,
"score"
}
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
return
hidden_states
__all__
=
[
"DispatchPooler"
,
"IdentityPooler"
]
vllm/model_executor/layers/pooler/tokwise/__init__.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Poolers that produce an output for each token in the sequence."""
from
.heads
import
(
TokenClassifierPoolerHead
,
TokenEmbeddingPoolerHead
,
TokenPoolerHead
,
TokenPoolerHeadOutputItem
,
)
from
.methods
import
(
AllPool
,
StepPool
,
TokenPoolingMethod
,
TokenPoolingMethodOutputItem
,
get_tok_pooling_method
,
)
from
.poolers
import
(
TokenPooler
,
TokenPoolerOutput
,
pooler_for_token_classify
,
pooler_for_token_embed
,
)
__all__
=
[
"TokenPoolerHead"
,
"TokenPoolerHeadOutputItem"
,
"TokenClassifierPoolerHead"
,
"TokenEmbeddingPoolerHead"
,
"TokenPoolingMethod"
,
"TokenPoolingMethodOutputItem"
,
"AllPool"
,
"StepPool"
,
"get_tok_pooling_method"
,
"TokenPooler"
,
"TokenPoolerOutput"
,
"pooler_for_token_classify"
,
"pooler_for_token_embed"
,
]
vllm/model_executor/layers/pooler/tokwise/heads.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Set
from
typing
import
TypeAlias
import
torch
import
torch.nn
as
nn
from
vllm.config
import
get_current_vllm_config
from
vllm.model_executor.layers.pooler
import
ClassifierFn
from
vllm.model_executor.layers.pooler.activations
import
(
PoolerActivation
,
PoolerNormalize
,
resolve_classifier_act_fn
,
)
from
vllm.model_executor.models.adapters
import
_load_st_projector
from
vllm.pooling_params
import
PoolingParams
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.methods
import
TokenPoolingMethodOutputItem
TokenPoolerHeadOutputItem
:
TypeAlias
=
torch
.
Tensor
|
None
class
TokenPoolerHead
(
nn
.
Module
,
ABC
):
@
abstractmethod
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
raise
NotImplementedError
@
abstractmethod
def
forward_chunk
(
self
,
pooled_data
:
TokenPoolingMethodOutputItem
,
pooling_param
:
PoolingParams
,
)
->
TokenPoolerHeadOutputItem
:
raise
NotImplementedError
def
forward
(
self
,
pooled_data
:
list
[
TokenPoolingMethodOutputItem
],
pooling_metadata
:
PoolingMetadata
,
)
->
list
[
TokenPoolerHeadOutputItem
]:
pooling_params
=
pooling_metadata
.
pooling_params
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
return
[
self
.
forward_chunk
(
d
,
p
)
for
d
,
p
in
zip
(
pooled_data
,
pooling_params
)]
class
TokenEmbeddingPoolerHead
(
TokenPoolerHead
):
def
__init__
(
self
)
->
None
:
super
().
__init__
()
# Load ST projector if available
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
self
.
projector
=
_load_st_projector
(
model_config
)
self
.
head_dtype
=
model_config
.
head_dtype
self
.
activation
=
PoolerNormalize
()
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
}
def
forward_chunk
(
self
,
pooled_data
:
TokenPoolingMethodOutputItem
,
pooling_param
:
PoolingParams
,
)
->
TokenPoolerHeadOutputItem
:
# for unfinished chunked prefill
if
pooled_data
is
None
:
return
None
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# pooled_data shape: [n_tokens, hidden_dimension]
# Apply ST projector
if
self
.
projector
is
not
None
:
pooled_data
=
self
.
projector
(
pooled_data
)
# pooled_data shape: [n_tokens, embedding_dimension]
# for matryoshka representation
pooled_data
=
pooled_data
[...,
:
pooling_param
.
dimensions
]
# for normalize
if
pooling_param
.
normalize
:
pooled_data
=
self
.
activation
(
pooled_data
)
# pooled_data shape: [n_tokens, embedding_dimension]
return
pooled_data
class
TokenClassifierPoolerHead
(
TokenPoolerHead
):
def
__init__
(
self
,
classifier
:
ClassifierFn
|
None
=
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
)
->
None
:
super
().
__init__
()
vllm_config
=
get_current_vllm_config
()
model_config
=
vllm_config
.
model_config
self
.
classifier
=
classifier
self
.
logit_bias
:
float
|
None
=
model_config
.
pooler_config
.
logit_bias
self
.
head_dtype
=
model_config
.
head_dtype
self
.
act_fn
=
resolve_classifier_act_fn
(
model_config
,
static_num_labels
=
False
,
act_fn
=
act_fn
)
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_classify"
}
def
forward_chunk
(
self
,
pooled_data
:
TokenPoolingMethodOutputItem
,
pooling_param
:
PoolingParams
,
)
->
TokenPoolerHeadOutputItem
:
# for unfinished chunked prefill
if
pooled_data
is
None
:
return
None
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# hidden_states shape: [n_token, hidden_size]
if
self
.
classifier
is
not
None
:
scores
=
self
.
classifier
(
pooled_data
)
else
:
scores
=
pooled_data
# scores shape: [n_token, num_labels]
if
self
.
logit_bias
is
not
None
:
scores
-=
self
.
logit_bias
if
pooling_param
.
use_activation
:
scores
=
self
.
act_fn
(
scores
)
# scores shape: [n_token, num_labels]
return
scores
vllm/model_executor/layers/pooler/tokwise/methods.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Set
from
typing
import
TypeAlias
import
torch
import
torch.nn
as
nn
from
vllm.config
import
get_current_vllm_config
from
vllm.config.pooler
import
PoolingTypeStr
from
vllm.model_executor.layers.pooler
import
PoolingParamsUpdate
from
vllm.tasks
import
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
TokenPoolingMethodOutputItem
:
TypeAlias
=
torch
.
Tensor
|
None
class
TokenPoolingMethod
(
nn
.
Module
,
ABC
):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
}
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
PoolingParamsUpdate
()
@
abstractmethod
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
list
[
TokenPoolingMethodOutputItem
]:
raise
NotImplementedError
class
AllPool
(
TokenPoolingMethod
):
def
__init__
(
self
):
super
().
__init__
()
vllm_config
=
get_current_vllm_config
()
scheduler_config
=
vllm_config
.
scheduler_config
self
.
enable_chunked_prefill
=
scheduler_config
.
enable_chunked_prefill
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
list
[
TokenPoolingMethodOutputItem
]:
pooling_cursor
=
pooling_metadata
.
get_pooling_cursor
()
hidden_states_all
=
hidden_states
.
split
(
pooling_cursor
.
num_scheduled_tokens_cpu
.
tolist
()
)
hidden_states_lst
=
[
hidden_states_all
[
i
]
for
i
in
pooling_cursor
.
index
]
if
not
self
.
enable_chunked_prefill
:
return
hidden_states_lst
pooling_states
=
pooling_metadata
.
pooling_states
# If chunked_prefill is enabled
# 1. first store the chunked hidden_states in pooling_states.hidden_states_cache
for
p
,
hs_chunk
in
zip
(
pooling_states
,
hidden_states_lst
):
p
.
hidden_states_cache
.
append
(
hs_chunk
)
# 2. Once prefill is finished, send hidden_states_cache to PoolerHead
output_list
=
list
[
TokenPoolingMethodOutputItem
]()
for
p
,
finished
in
zip
(
pooling_states
,
pooling_cursor
.
is_finished
()):
if
finished
:
hidden_states_cache
=
p
.
hidden_states_cache
if
len
(
hidden_states_cache
)
==
1
:
output_list
.
append
(
hidden_states_cache
[
0
])
else
:
output_list
.
append
(
torch
.
concat
(
hidden_states_cache
,
dim
=
0
))
p
.
clean
()
else
:
output_list
.
append
(
None
)
return
output_list
class
StepPool
(
AllPool
):
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
PoolingParamsUpdate
(
requires_token_ids
=
True
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
list
[
TokenPoolingMethodOutputItem
]:
pooled_data_lst
=
super
().
forward
(
hidden_states
,
pooling_metadata
)
prompt_token_ids
=
pooling_metadata
.
get_prompt_token_ids
()
pooling_params
=
pooling_metadata
.
pooling_params
pooled_data
=
list
[
torch
.
Tensor
|
None
]()
for
data
,
token_id
,
pooling_param
in
zip
(
pooled_data_lst
,
prompt_token_ids
,
pooling_params
):
# for unfinished chunked prefill
if
data
is
None
:
pass
else
:
step_tag_id
=
pooling_param
.
step_tag_id
returned_token_ids
=
pooling_param
.
returned_token_ids
if
returned_token_ids
is
not
None
and
len
(
returned_token_ids
)
>
0
:
data
=
data
[:,
returned_token_ids
]
if
step_tag_id
is
not
None
:
data
=
data
[
token_id
==
step_tag_id
]
pooled_data
.
append
(
data
)
return
pooled_data
def
get_tok_pooling_method
(
pooling_type
:
PoolingTypeStr
|
str
):
if
pooling_type
==
"ALL"
:
return
AllPool
()
if
pooling_type
==
"STEP"
:
return
StepPool
()
# TODO: Separate seq and tok pooling types so we don't need this fallback
return
AllPool
()
raise
NotImplementedError
(
f
"Unknown tokenwise pooling type:
{
pooling_type
!
r
}
"
)
vllm/model_executor/layers/pooler/tokwise/poolers.py
0 → 100644
View file @
c8ed39b9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Callable
,
Set
from
typing
import
TypeAlias
import
torch
from
vllm.config
import
PoolerConfig
from
vllm.model_executor.layers.pooler
import
ClassifierFn
,
PoolingParamsUpdate
from
vllm.model_executor.layers.pooler.abstract
import
Pooler
from
vllm.model_executor.layers.pooler.activations
import
PoolerActivation
from
vllm.tasks
import
POOLING_TASKS
,
PoolingTask
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.heads
import
(
TokenClassifierPoolerHead
,
TokenEmbeddingPoolerHead
,
TokenPoolerHead
,
TokenPoolerHeadOutputItem
,
)
from
.methods
import
(
TokenPoolingMethod
,
TokenPoolingMethodOutputItem
,
get_tok_pooling_method
,
)
TokenPoolingFn
:
TypeAlias
=
Callable
[
[
torch
.
Tensor
,
PoolingMetadata
],
list
[
TokenPoolingMethodOutputItem
],
]
TokenPoolingHeadFn
:
TypeAlias
=
Callable
[
[
list
[
TokenPoolingMethodOutputItem
],
PoolingMetadata
],
list
[
TokenPoolerHeadOutputItem
],
]
TokenPoolerOutput
:
TypeAlias
=
list
[
torch
.
Tensor
|
None
]
class
TokenPooler
(
Pooler
):
"""
A layer that pools specific information from hidden states.
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Postprocesses the output based on pooling head.
3. Returns structured results as `PoolerOutput`.
"""
def
__init__
(
self
,
pooling
:
TokenPoolingMethod
|
TokenPoolingFn
,
head
:
TokenPoolerHead
|
TokenPoolingHeadFn
,
)
->
None
:
super
().
__init__
()
self
.
pooling
=
pooling
self
.
head
=
head
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
tasks
=
set
(
POOLING_TASKS
)
if
isinstance
(
self
.
pooling
,
TokenPoolingMethod
):
tasks
&=
self
.
pooling
.
get_supported_tasks
()
if
isinstance
(
self
.
head
,
TokenPoolerHead
):
tasks
&=
self
.
head
.
get_supported_tasks
()
return
tasks
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
updates
=
PoolingParamsUpdate
()
if
isinstance
(
self
.
pooling
,
TokenPoolingMethod
):
updates
|=
self
.
pooling
.
get_pooling_updates
(
task
)
return
updates
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
def
pooler_for_token_embed
(
pooler_config
:
PoolerConfig
):
pooling
=
get_tok_pooling_method
(
pooler_config
.
get_pooling_type
())
head
=
TokenEmbeddingPoolerHead
()
return
TokenPooler
(
pooling
=
pooling
,
head
=
head
)
def
pooler_for_token_classify
(
pooler_config
:
PoolerConfig
,
*
,
pooling
:
TokenPoolingMethod
|
TokenPoolingFn
|
None
=
None
,
classifier
:
ClassifierFn
|
None
=
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
):
if
pooling
is
None
:
pooling
=
get_tok_pooling_method
(
pooler_config
.
get_pooling_type
())
head
=
TokenClassifierPoolerHead
(
classifier
=
classifier
,
act_fn
=
act_fn
)
return
TokenPooler
(
pooling
=
pooling
,
head
=
head
)
vllm/model_executor/models/adapters.py
View file @
c8ed39b9
...
@@ -252,19 +252,14 @@ def as_embedding_model(cls: _T) -> _T:
...
@@ -252,19 +252,14 @@ def as_embedding_model(cls: _T) -> _T:
return
cls
return
cls
# Lazy import
# Lazy import
from
vllm.model_executor.layers.pooler
import
DispatchPooler
,
Pooler
from
vllm.model_executor.layers.pooler
import
DispatchPooler
class
ModelForEmbedding
(
_create_pooling_model_cls
(
cls
)):
class
ModelForEmbedding
(
_create_pooling_model_cls
(
cls
)):
def
_init_pooler
(
self
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
):
def
_init_pooler
(
self
,
vllm_config
:
"VllmConfig"
,
prefix
:
str
=
""
):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
.
for_embedding
(
pooler_config
)
{
"token_embed"
:
Pooler
.
for_token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
},
)
ModelForEmbedding
.
__name__
=
_get_pooling_model_name
(
cls
.
__name__
,
"ForEmbedding"
)
ModelForEmbedding
.
__name__
=
_get_pooling_model_name
(
cls
.
__name__
,
"ForEmbedding"
)
...
@@ -289,10 +284,7 @@ def as_seq_cls_model(cls: _T) -> _T:
...
@@ -289,10 +284,7 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
# Lazy import
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.pooler
import
(
from
vllm.model_executor.layers.pooler
import
DispatchPooler
DispatchPooler
,
Pooler
,
)
from
vllm.model_executor.models.interfaces
import
SupportsCrossEncoding
from
vllm.model_executor.models.interfaces
import
SupportsCrossEncoding
from
.utils
import
maybe_prefix
from
.utils
import
maybe_prefix
...
@@ -318,18 +310,8 @@ def as_seq_cls_model(cls: _T) -> _T:
...
@@ -318,18 +310,8 @@ def as_seq_cls_model(cls: _T) -> _T:
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
score
pooler_config
,
classifier
=
self
.
score
),
"classify"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"classify"
),
"score"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"score"
),
}
)
)
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]):
...
...
vllm/model_executor/models/bert.py
View file @
c8ed39b9
...
@@ -18,19 +18,25 @@ from vllm.model_executor.layers.linear import (
...
@@ -18,19 +18,25 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear
,
RowParallelLinear
,
)
)
from
vllm.model_executor.layers.pooler
import
(
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
DispatchPooler
,
DispatchPooler
,
Pooler
,
Pooler
,
PoolingMethod
,
PoolingParamsUpdate
,
PoolingParamsUpdate
,
TokenPoolerHeadOutput
,
)
TokenPoolingMethodOutput
,
from
vllm.model_executor.layers.pooler.seqwise
import
(
CLSPool
,
SequencePooler
,
SequencePoolerHeadOutput
,
SequencePoolerOutput
,
SequencePoolingMethodOutput
,
)
from
vllm.model_executor.layers.pooler.tokwise
import
(
pooler_for_token_classify
,
pooler_for_token_embed
,
)
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
from
vllm.v1.outputs
import
TokenPoolerOutput
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
vllm.v1.pool.metadata
import
PoolingMetadata
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
from
.interfaces
import
SupportsCrossEncoding
,
SupportsQuant
...
@@ -85,25 +91,21 @@ class BertEmbedding(nn.Module):
...
@@ -85,25 +91,21 @@ class BertEmbedding(nn.Module):
return
embeddings
return
embeddings
class
BertPooler
(
Pooler
):
class
BertPooler
(
Sequence
Pooler
):
def
__init__
(
self
,
config
:
BertConfig
):
def
__init__
(
self
,
config
:
BertConfig
):
super
().
__init__
()
super
().
__init__
(
pooling
=
CLSPool
(),
head
=
self
.
head
,
)
self
.
pooling
=
PoolingMethod
.
from_pooling_type
(
"CLS"
)
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
dense
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
activation
=
nn
.
Tanh
()
self
.
activation
=
nn
.
Tanh
()
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
self
.
pooling
.
get_supported_tasks
()
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
self
.
pooling
.
get_pooling_updates
(
task
)
def
head
(
def
head
(
self
,
self
,
pooled_data
:
Token
PoolingMethodOutput
,
pooled_data
:
Sequence
PoolingMethodOutput
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
Token
PoolerHeadOutput
:
)
->
Sequence
PoolerHeadOutput
:
if
isinstance
(
pooled_data
,
list
):
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
torch
.
stack
(
pooled_data
)
pooled_data
=
torch
.
stack
(
pooled_data
)
...
@@ -111,15 +113,6 @@ class BertPooler(Pooler):
...
@@ -111,15 +113,6 @@ class BertPooler(Pooler):
pooled_data
=
self
.
activation
(
pooled_data
)
pooled_data
=
self
.
activation
(
pooled_data
)
return
pooled_data
return
pooled_data
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
TokenPoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
class
BertEncoder
(
nn
.
Module
):
class
BertEncoder
(
nn
.
Module
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
...
@@ -524,12 +517,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
...
@@ -524,12 +517,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
)
)
def
_build_pooler
(
self
,
pooler_config
:
PoolerConfig
)
->
Pooler
:
def
_build_pooler
(
self
,
pooler_config
:
PoolerConfig
)
->
Pooler
:
return
DispatchPooler
(
return
DispatchPooler
.
for_embedding
(
pooler_config
)
{
"token_embed"
:
Pooler
.
for_token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
}
)
# Here we encode the token type ids together with the input ids.
# Here we encode the token type ids together with the input ids.
...
@@ -620,6 +608,7 @@ class SPLADESparsePooler(Pooler):
...
@@ -620,6 +608,7 @@ class SPLADESparsePooler(Pooler):
remove_cls_sep
:
bool
=
True
,
remove_cls_sep
:
bool
=
True
,
):
):
super
().
__init__
()
super
().
__init__
()
assert
pooling
in
(
"max"
,
"sum"
)
assert
pooling
in
(
"max"
,
"sum"
)
self
.
mlm_head
=
mlm_head
self
.
mlm_head
=
mlm_head
self
.
cls_token_id
=
cls_token_id
self
.
cls_token_id
=
cls_token_id
...
@@ -637,10 +626,8 @@ class SPLADESparsePooler(Pooler):
...
@@ -637,10 +626,8 @@ class SPLADESparsePooler(Pooler):
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
)
->
torch
.
Tensor
:
)
->
SequencePoolerOutput
:
assert
isinstance
(
hidden_states
,
torch
.
Tensor
)
and
hidden_states
.
dim
()
==
2
lens_tensor
=
pooling_metadata
.
prompt_lens
lens_tensor
:
torch
.
Tensor
=
pooling_metadata
.
prompt_lens
lens
:
list
[
int
]
=
lens_tensor
.
tolist
()
lens
:
list
[
int
]
=
lens_tensor
.
tolist
()
B
:
int
=
len
(
lens
)
B
:
int
=
len
(
lens
)
...
@@ -729,7 +716,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
...
@@ -729,7 +716,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
return
DispatchPooler
(
return
DispatchPooler
(
{
{
"token_embed"
:
P
ooler
.
for_token_embed
(
pooler_config
),
"token_embed"
:
p
ooler
_
for_token_embed
(
pooler_config
),
"embed"
:
SPLADESparsePooler
(
"embed"
:
SPLADESparsePooler
(
mlm_head
=
self
.
mlm_head
,
mlm_head
=
self
.
mlm_head
,
cls_token_id
=
cls_id
,
cls_token_id
=
cls_id
,
...
@@ -824,20 +811,10 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
...
@@ -824,20 +811,10 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
.
for_seq_cls
(
{
pooler_config
,
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
classifier
),
"classify"
:
ClassifierPooler
(
pooling
=
self
.
bert
.
pooler
,
pooling
=
self
.
bert
.
pooler
,
classifier
=
self
.
classifier
,
classifier
=
self
.
classifier
,
act_fn
=
"classify"
,
),
"score"
:
ClassifierPooler
(
pooling
=
self
.
bert
.
pooler
,
classifier
=
self
.
classifier
,
act_fn
=
"score"
),
}
)
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -891,13 +868,7 @@ class BertForTokenClassification(nn.Module):
...
@@ -891,13 +868,7 @@ class BertForTokenClassification(nn.Module):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
pooler_for_token_classify
(
pooler_config
)
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
=
pooler_config
),
}
)
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
bert
.
embed_input_ids
(
input_ids
)
return
self
.
bert
.
embed_input_ids
(
input_ids
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment