Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f54f8512
Unverified
Commit
f54f8512
authored
Oct 15, 2025
by
wang.yuqi
Committed by
GitHub
Oct 15, 2025
Browse files
[Model][2/N] Improve all pooling task | Support multi-vector retrieval (#25370)
Signed-off-by:
wang.yuqi
<
noooop@126.com
>
parent
d4d1a602
Changes
41
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
396 additions
and
317 deletions
+396
-317
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+12
-9
vllm/entrypoints/openai/protocol.py
vllm/entrypoints/openai/protocol.py
+2
-2
vllm/entrypoints/openai/serving_pooling.py
vllm/entrypoints/openai/serving_pooling.py
+13
-1
vllm/model_executor/layers/pooler.py
vllm/model_executor/layers/pooler.py
+251
-171
vllm/model_executor/models/adapters.py
vllm/model_executor/models/adapters.py
+15
-27
vllm/model_executor/models/bert.py
vllm/model_executor/models/bert.py
+10
-12
vllm/model_executor/models/bert_with_rope.py
vllm/model_executor/models/bert_with_rope.py
+5
-9
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+1
-1
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+9
-2
vllm/model_executor/models/gritlm.py
vllm/model_executor/models/gritlm.py
+1
-1
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+1
-1
vllm/model_executor/models/jamba.py
vllm/model_executor/models/jamba.py
+7
-3
vllm/model_executor/models/jina_vl.py
vllm/model_executor/models/jina_vl.py
+9
-3
vllm/model_executor/models/modernbert.py
vllm/model_executor/models/modernbert.py
+8
-12
vllm/model_executor/models/qwen2_rm.py
vllm/model_executor/models/qwen2_rm.py
+4
-2
vllm/model_executor/models/roberta.py
vllm/model_executor/models/roberta.py
+6
-20
vllm/model_executor/models/terratorch.py
vllm/model_executor/models/terratorch.py
+1
-1
vllm/model_executor/models/transformers_pooling.py
vllm/model_executor/models/transformers_pooling.py
+6
-12
vllm/pooling_params.py
vllm/pooling_params.py
+34
-27
vllm/tasks.py
vllm/tasks.py
+1
-1
No files found.
vllm/entrypoints/openai/api_server.py
View file @
f54f8512
...
@@ -1748,16 +1748,19 @@ async def init_app_state(
...
@@ -1748,16 +1748,19 @@ async def init_app_state(
else
None
else
None
)
)
state
.
openai_serving_pooling
=
(
state
.
openai_serving_pooling
=
(
(
OpenAIServingPooling
(
OpenAIServingPooling
(
engine_client
,
engine_client
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
supported_tasks
=
supported_tasks
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template
=
resolved_chat_template
,
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
chat_template_content_format
=
args
.
chat_template_content_format
,
trust_request_chat_template
=
args
.
trust_request_chat_template
,
trust_request_chat_template
=
args
.
trust_request_chat_template
,
log_error_stack
=
args
.
log_error_stack
,
log_error_stack
=
args
.
log_error_stack
,
)
)
if
"encode"
in
supported_tasks
)
if
(
"token_embed"
in
supported_tasks
or
"token_classify"
in
supported_tasks
)
else
None
else
None
)
)
state
.
openai_serving_embedding
=
(
state
.
openai_serving_embedding
=
(
...
...
vllm/entrypoints/openai/protocol.py
View file @
f54f8512
...
@@ -1682,7 +1682,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
...
@@ -1682,7 +1682,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
When using plugins IOProcessor plugins, the actual input is processed
When using plugins IOProcessor plugins, the actual input is processed
by the plugin itself. Hence, we use a generic type for the request data
by the plugin itself. Hence, we use a generic type for the request data
"""
"""
softmax
:
bool
=
Tru
e
activation
:
bool
=
Fals
e
embed_dtype
:
str
=
Field
(
embed_dtype
:
str
=
Field
(
default
=
"float32"
,
default
=
"float32"
,
...
@@ -1693,7 +1693,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
...
@@ -1693,7 +1693,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
)
)
def
to_pooling_params
(
self
):
def
to_pooling_params
(
self
):
return
PoolingParams
(
task
=
"
encode"
,
softmax
=
self
.
softmax
)
return
PoolingParams
(
task
=
"
token_classify"
,
activation
=
self
.
activation
)
class
IOProcessorResponse
(
OpenAIBaseModel
,
Generic
[
T
]):
class
IOProcessorResponse
(
OpenAIBaseModel
,
Generic
[
T
]):
...
...
vllm/entrypoints/openai/serving_pooling.py
View file @
f54f8512
...
@@ -35,6 +35,7 @@ from vllm.entrypoints.renderer import RenderConfig
...
@@ -35,6 +35,7 @@ from vllm.entrypoints.renderer import RenderConfig
from
vllm.entrypoints.utils
import
_validate_truncation_size
from
vllm.entrypoints.utils
import
_validate_truncation_size
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingOutput
,
PoolingRequestOutput
from
vllm.outputs
import
PoolingOutput
,
PoolingRequestOutput
from
vllm.tasks
import
SupportedTask
from
vllm.utils
import
merge_async_iterators
from
vllm.utils
import
merge_async_iterators
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -62,6 +63,7 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -62,6 +63,7 @@ class OpenAIServingPooling(OpenAIServing):
engine_client
:
EngineClient
,
engine_client
:
EngineClient
,
models
:
OpenAIServingModels
,
models
:
OpenAIServingModels
,
*
,
*
,
supported_tasks
:
tuple
[
SupportedTask
,
...],
request_logger
:
RequestLogger
|
None
,
request_logger
:
RequestLogger
|
None
,
chat_template
:
str
|
None
,
chat_template
:
str
|
None
,
chat_template_content_format
:
ChatTemplateContentFormatOption
,
chat_template_content_format
:
ChatTemplateContentFormatOption
,
...
@@ -75,6 +77,7 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -75,6 +77,7 @@ class OpenAIServingPooling(OpenAIServing):
log_error_stack
=
log_error_stack
,
log_error_stack
=
log_error_stack
,
)
)
self
.
supported_tasks
=
supported_tasks
self
.
chat_template
=
chat_template
self
.
chat_template
=
chat_template
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
chat_template_content_format
:
Final
=
chat_template_content_format
self
.
trust_request_chat_template
=
trust_request_chat_template
self
.
trust_request_chat_template
=
trust_request_chat_template
...
@@ -178,8 +181,17 @@ class OpenAIServingPooling(OpenAIServing):
...
@@ -178,8 +181,17 @@ class OpenAIServingPooling(OpenAIServing):
try
:
try
:
pooling_params
=
request
.
to_pooling_params
()
pooling_params
=
request
.
to_pooling_params
()
if
"token_embed"
in
self
.
supported_tasks
:
pooling_task
=
"token_embed"
elif
"token_classify"
in
self
.
supported_tasks
:
pooling_task
=
"token_classify"
else
:
return
self
.
create_error_response
(
f
"pooling_task must be one of
{
self
.
supported_tasks
}
."
)
try
:
try
:
pooling_params
.
verify
(
"encode"
,
self
.
model_config
)
pooling_params
.
verify
(
pooling_task
,
self
.
model_config
)
except
ValueError
as
e
:
except
ValueError
as
e
:
return
self
.
create_error_response
(
str
(
e
))
return
self
.
create_error_response
(
str
(
e
))
...
...
vllm/model_executor/layers/pooler.py
View file @
f54f8512
...
@@ -64,66 +64,6 @@ class PoolingParamsUpdate:
...
@@ -64,66 +64,6 @@ class PoolingParamsUpdate:
params
.
requires_token_ids
=
self
.
requires_token_ids
params
.
requires_token_ids
=
self
.
requires_token_ids
class
Pooler
(
nn
.
Module
,
ABC
):
"""The interface required for all poolers used in pooling models in vLLM."""
@
staticmethod
def
for_encode
(
pooler_config
:
PoolerConfig
):
if
pooler_config
.
pooling_type
==
"STEP"
:
return
StepPooler
()
resolved_config
=
ResolvedPoolingConfig
(
task
=
"encode"
,
pooling_type
=
PoolingType
.
ALL
)
return
SimplePooler
.
from_config
(
resolved_config
)
@
staticmethod
def
for_embed
(
pooler_config
:
PoolerConfig
):
resolved_config
=
ResolvedPoolingConfig
.
from_config
(
task
=
"embed"
,
pooler_config
=
pooler_config
,
)
return
SimplePooler
.
from_config
(
resolved_config
)
@
staticmethod
def
for_classify
(
pooler_config
:
PoolerConfig
,
classifier
:
ClassifierFn
|
None
,
):
resolved_config
=
ResolvedPoolingConfig
.
from_config
(
task
=
"classify"
,
pooler_config
=
pooler_config
,
)
pooling
=
PoolingMethod
.
from_pooling_type
(
resolved_config
.
pooling_type
)
return
ClassifierPooler
(
pooling
=
pooling
,
classifier
=
classifier
,
)
@
abstractmethod
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
"""Determine which pooling tasks are supported."""
raise
NotImplementedError
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
"""
Construct the updated pooling parameters to use for a supported task.
"""
return
PoolingParamsUpdate
()
@
abstractmethod
def
forward
(
self
,
hidden_states
:
list
[
torch
.
Tensor
]
|
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
raise
NotImplementedError
def
get_prompt_lens
(
def
get_prompt_lens
(
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
],
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
,
pooling_metadata
:
PoolingMetadata
,
...
@@ -237,7 +177,7 @@ class PoolingMethod(nn.Module, ABC):
...
@@ -237,7 +177,7 @@ class PoolingMethod(nn.Module, ABC):
class
CLSPool
(
PoolingMethod
):
class
CLSPool
(
PoolingMethod
):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"
encode
"
,
"embed"
,
"classify"
,
"score"
}
return
{
"
token_embed"
,
"token_classify
"
,
"embed"
,
"classify"
,
"score"
}
def
forward_all
(
def
forward_all
(
self
,
self
,
...
@@ -253,7 +193,7 @@ class CLSPool(PoolingMethod):
...
@@ -253,7 +193,7 @@ class CLSPool(PoolingMethod):
class
LastPool
(
PoolingMethod
):
class
LastPool
(
PoolingMethod
):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"
encode
"
,
"embed"
,
"classify"
,
"score"
}
return
{
"
token_embed"
,
"token_classify
"
,
"embed"
,
"classify"
,
"score"
}
def
forward_all
(
def
forward_all
(
self
,
self
,
...
@@ -265,7 +205,7 @@ class LastPool(PoolingMethod):
...
@@ -265,7 +205,7 @@ class LastPool(PoolingMethod):
class
AllPool
(
PoolingMethod
):
class
AllPool
(
PoolingMethod
):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"
encode
"
}
return
{
"
token_embed"
,
"token_classify
"
}
def
forward_all
(
def
forward_all
(
self
,
self
,
...
@@ -284,7 +224,7 @@ class AllPool(PoolingMethod):
...
@@ -284,7 +224,7 @@ class AllPool(PoolingMethod):
class
MeanPool
(
PoolingMethod
):
class
MeanPool
(
PoolingMethod
):
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"
encode
"
,
"embed"
,
"classify"
,
"score"
}
return
{
"
token_embed"
,
"token_classify
"
,
"embed"
,
"classify"
,
"score"
}
def
forward_all
(
def
forward_all
(
self
,
self
,
...
@@ -398,6 +338,82 @@ class LambdaPoolerActivation(PoolerActivation):
...
@@ -398,6 +338,82 @@ class LambdaPoolerActivation(PoolerActivation):
return
self
.
fn
(
pooled_data
)
return
self
.
fn
(
pooled_data
)
class
Pooler
(
nn
.
Module
,
ABC
):
"""The interface required for all poolers used in pooling models in vLLM."""
@
staticmethod
def
for_token_embed
(
pooler_config
:
PoolerConfig
):
head
=
TokenEmbeddingPoolerHead
()
if
pooler_config
.
pooling_type
==
"STEP"
:
return
StepPooler
(
head
=
head
)
return
AllPooler
(
head
=
head
)
@
staticmethod
def
for_token_classify
(
pooler_config
:
PoolerConfig
,
classifier
:
ClassifierFn
|
None
=
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
):
head
=
TokenClassifierPoolerHead
(
classifier
=
classifier
,
act_fn
=
act_fn
)
if
pooler_config
.
pooling_type
==
"STEP"
:
return
StepPooler
(
head
=
head
)
return
AllPooler
(
head
=
head
)
@
staticmethod
def
for_embed
(
pooler_config
:
PoolerConfig
):
resolved_config
=
ResolvedPoolingConfig
.
from_config
(
task
=
"embed"
,
pooler_config
=
pooler_config
,
)
pooling
=
PoolingMethod
.
from_pooling_type
(
resolved_config
.
pooling_type
)
head
=
EmbeddingPoolerHead
()
return
SimplePooler
(
pooling
=
pooling
,
head
=
head
)
@
staticmethod
def
for_classify
(
pooler_config
:
PoolerConfig
,
classifier
:
ClassifierFn
|
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
):
resolved_config
=
ResolvedPoolingConfig
.
from_config
(
task
=
"classify"
,
pooler_config
=
pooler_config
,
)
pooling
=
PoolingMethod
.
from_pooling_type
(
resolved_config
.
pooling_type
)
return
ClassifierPooler
(
pooling
=
pooling
,
classifier
=
classifier
,
act_fn
=
act_fn
,
)
@
abstractmethod
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
"""Determine which pooling tasks are supported."""
raise
NotImplementedError
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
"""
Construct the updated pooling parameters to use for a supported task.
"""
return
PoolingParamsUpdate
()
@
abstractmethod
def
forward
(
self
,
hidden_states
:
list
[
torch
.
Tensor
]
|
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
raise
NotImplementedError
class
PoolerHead
(
nn
.
Module
):
class
PoolerHead
(
nn
.
Module
):
def
__init__
(
self
,
activation
:
PoolerActivation
)
->
None
:
def
__init__
(
self
,
activation
:
PoolerActivation
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -416,7 +432,6 @@ class EmbeddingPoolerHead(PoolerHead):
...
@@ -416,7 +432,6 @@ class EmbeddingPoolerHead(PoolerHead):
super
().
__init__
(
activation
=
PoolerNormalize
())
super
().
__init__
(
activation
=
PoolerNormalize
())
# Load ST projector if available
# Load ST projector if available
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
self
.
projector
:
nn
.
Module
|
None
=
(
self
.
projector
:
nn
.
Module
|
None
=
(
_load_st_projector
(
vllm_config
.
model_config
)
if
vllm_config
else
None
_load_st_projector
(
vllm_config
.
model_config
)
if
vllm_config
else
None
...
@@ -471,39 +486,6 @@ class EmbeddingPoolerHead(PoolerHead):
...
@@ -471,39 +486,6 @@ class EmbeddingPoolerHead(PoolerHead):
return
pooled_data
return
pooled_data
class
RewardPoolerHead
(
PoolerHead
):
def
__init__
(
self
)
->
None
:
super
().
__init__
(
activation
=
PoolerClassify
(
static_num_labels
=
False
))
vllm_config
=
get_current_vllm_config
()
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
def
forward
(
self
,
pooled_data
:
list
[
torch
.
Tensor
]
|
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
):
if
isinstance
(
pooled_data
,
list
):
pooled_data
=
[
p
.
to
(
self
.
head_dtype
)
for
p
in
pooled_data
]
else
:
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
pooling_params
=
get_pooling_params
(
pooling_metadata
)
# for softmax
flags
=
[
p
.
softmax
for
p
in
pooling_params
]
if
len
(
set
(
flags
))
==
1
:
if
flags
[
0
]:
pooled_data
=
self
.
activation
(
pooled_data
)
else
:
pooled_data
=
[
self
.
activation
(
vecs
)
if
f
else
vecs
for
vecs
,
f
in
zip
(
pooled_data
,
flags
)
]
return
pooled_data
class
SimplePooler
(
Pooler
):
class
SimplePooler
(
Pooler
):
"""A layer that pools specific information from hidden states.
"""A layer that pools specific information from hidden states.
...
@@ -513,20 +495,6 @@ class SimplePooler(Pooler):
...
@@ -513,20 +495,6 @@ class SimplePooler(Pooler):
3. Returns structured results as `PoolerOutput`.
3. Returns structured results as `PoolerOutput`.
"""
"""
@
classmethod
def
from_config
(
cls
,
pooler_config
:
ResolvedPoolingConfig
,
)
->
"SimplePooler"
:
pooling
=
PoolingMethod
.
from_pooling_type
(
pooler_config
.
pooling_type
)
if
pooler_config
.
task
==
"embed"
:
head
=
EmbeddingPoolerHead
()
elif
pooler_config
.
task
==
"encode"
:
head
=
RewardPoolerHead
()
else
:
raise
NotImplementedError
(
f
"Unknown task:
{
pooler_config
.
task
}
"
)
return
cls
(
pooling
,
head
)
def
__init__
(
self
,
pooling
:
PoolingMethod
,
head
:
PoolerHead
)
->
None
:
def
__init__
(
self
,
pooling
:
PoolingMethod
,
head
:
PoolerHead
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -549,58 +517,6 @@ class SimplePooler(Pooler):
...
@@ -549,58 +517,6 @@ class SimplePooler(Pooler):
return
pooled_data
return
pooled_data
class
StepPooler
(
Pooler
):
def
__init__
(
self
,
)
->
None
:
super
().
__init__
()
self
.
pooling
=
AllPool
()
self
.
head
=
RewardPoolerHead
()
def
extract_states
(
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
,
)
->
list
[
torch
.
Tensor
]
|
torch
.
Tensor
:
pooled_data_lst
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
prompt_token_ids
=
get_prompt_token_ids
(
pooling_metadata
)
pooled_data
=
list
[
torch
.
Tensor
]()
pooling_params
=
get_pooling_params
(
pooling_metadata
)
for
data
,
token_id
,
pooling_param
in
zip
(
pooled_data_lst
,
prompt_token_ids
,
pooling_params
):
step_tag_id
=
pooling_param
.
step_tag_id
returned_token_ids
=
pooling_param
.
returned_token_ids
if
returned_token_ids
is
not
None
and
len
(
returned_token_ids
)
>
0
:
data
=
data
[:,
returned_token_ids
]
if
step_tag_id
is
not
None
:
data
=
data
[
token_id
==
step_tag_id
]
pooled_data
.
append
(
data
)
return
pooled_data
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"encode"
}
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
PoolingParamsUpdate
(
requires_token_ids
=
True
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
pooled_data
=
self
.
extract_states
(
hidden_states
,
pooling_metadata
)
pooled_data
=
self
.
head
(
pooled_data
,
pooling_metadata
)
return
pooled_data
class
ClassifierPooler
(
Pooler
):
class
ClassifierPooler
(
Pooler
):
"""A pooling layer for classification tasks.
"""A pooling layer for classification tasks.
...
@@ -611,26 +527,46 @@ class ClassifierPooler(Pooler):
...
@@ -611,26 +527,46 @@ class ClassifierPooler(Pooler):
"""
"""
@
staticmethod
@
staticmethod
def
act_fn_for_seq_cls
(
config
:
ModelConfig
):
def
act_fn_for_seq_cls
(
model_
config
:
ModelConfig
):
return
get_classification_activation_function
(
config
.
hf_config
)
return
get_classification_activation_function
(
model_
config
.
hf_config
)
@
staticmethod
@
staticmethod
def
act_fn_for_cross_encoder
(
config
:
ModelConfig
):
def
act_fn_for_cross_encoder
(
model_config
:
ModelConfig
):
return
get_cross_encoder_activation_function
(
config
.
hf_config
)
return
get_cross_encoder_activation_function
(
model_config
.
hf_config
)
@
staticmethod
def
resolve_act_fn
(
model_config
:
ModelConfig
,
static_num_labels
:
bool
=
True
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
):
if
isinstance
(
act_fn
,
str
):
if
act_fn
==
"classify"
:
return
ClassifierPooler
.
act_fn_for_seq_cls
(
model_config
)
elif
act_fn
==
"score"
:
return
ClassifierPooler
.
act_fn_for_cross_encoder
(
model_config
)
else
:
raise
ValueError
(
f
"act_fn [
{
act_fn
=
}
] not supported."
)
elif
act_fn
is
None
:
return
PoolerClassify
(
static_num_labels
=
static_num_labels
)
else
:
assert
callable
(
act_fn
)
return
act_fn
def
__init__
(
def
__init__
(
self
,
self
,
pooling
:
PoolingFn
,
pooling
:
PoolingFn
,
classifier
:
ClassifierFn
|
None
,
classifier
:
ClassifierFn
|
None
,
act_fn
:
PoolerActivation
|
None
=
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
self
.
pooling
=
pooling
self
.
pooling
=
pooling
self
.
classifier
=
classifier
self
.
classifier
=
classifier
self
.
act_fn
=
act_fn
or
PoolerClassify
()
self
.
act_fn
=
self
.
resolve_act_fn
(
vllm_config
.
model_config
,
static_num_labels
=
True
,
act_fn
=
act_fn
)
self
.
logit_bias
:
float
|
None
=
(
self
.
logit_bias
:
float
|
None
=
(
vllm_config
.
model_config
.
pooler_config
.
logit_bias
vllm_config
.
model_config
.
pooler_config
.
logit_bias
)
)
...
@@ -672,6 +608,150 @@ class ClassifierPooler(Pooler):
...
@@ -672,6 +608,150 @@ class ClassifierPooler(Pooler):
return
scores
return
scores
class
TokenEmbeddingPoolerHead
(
EmbeddingPoolerHead
):
def
forward
(
self
,
pooled_data
:
torch
.
Tensor
,
pooling_param
:
PoolingParams
)
->
torch
.
Tensor
:
pooled_data
=
pooled_data
.
to
(
self
.
head_dtype
)
# pooled_data shape: [n_tokens, hidden_dimension]
# Apply ST projector
if
self
.
projector
is
not
None
:
pooled_data
=
self
.
projector
(
pooled_data
)
# pooled_data shape: [n_tokens, embedding_dimension]
# for matryoshka representation
pooled_data
=
pooled_data
[...,
:
pooling_param
.
dimensions
]
# for normalize
if
pooling_param
.
normalize
:
pooled_data
=
self
.
activation
(
pooled_data
)
# pooled_data shape: [n_tokens, embedding_dimension]
return
pooled_data
class
TokenClassifierPoolerHead
(
nn
.
Module
):
def
__init__
(
self
,
classifier
:
ClassifierFn
|
None
,
act_fn
:
PoolerActivation
|
str
|
None
=
None
,
)
->
None
:
super
().
__init__
()
vllm_config
=
get_current_vllm_config
()
self
.
classifier
=
classifier
self
.
act_fn
=
ClassifierPooler
.
resolve_act_fn
(
vllm_config
.
model_config
,
static_num_labels
=
False
,
act_fn
=
act_fn
)
self
.
logit_bias
:
float
|
None
=
(
vllm_config
.
model_config
.
pooler_config
.
logit_bias
)
self
.
head_dtype
=
vllm_config
.
model_config
.
head_dtype
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_classify"
}
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_param
:
PoolingParams
,
)
->
torch
.
Tensor
:
hidden_states
=
hidden_states
.
to
(
self
.
head_dtype
)
# hidden_states shape: [n_token, hidden_size]
if
self
.
classifier
is
not
None
:
scores
=
self
.
classifier
(
hidden_states
)
else
:
scores
=
hidden_states
# scores shape: [n_token, num_labels]
if
self
.
logit_bias
is
not
None
:
scores
-=
self
.
logit_bias
if
pooling_param
.
activation
:
scores
=
self
.
act_fn
(
scores
)
# scores shape: [n_token, num_labels]
return
scores
class
AllPooler
(
Pooler
):
def
__init__
(
self
,
head
:
nn
.
Module
|
PoolerHead
)
->
None
:
super
().
__init__
()
self
.
pooling
=
AllPool
()
self
.
head
=
head
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
}
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
pooled_data
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
pooling_params
=
get_pooling_params
(
pooling_metadata
)
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
pooled_data
=
[
self
.
head
(
d
,
p
)
for
d
,
p
in
zip
(
pooled_data
,
pooling_params
)]
return
pooled_data
class
StepPooler
(
Pooler
):
def
__init__
(
self
,
head
:
nn
.
Module
|
PoolerHead
)
->
None
:
super
().
__init__
()
self
.
pooling
=
AllPool
()
self
.
head
=
head
def
extract_states
(
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
,
)
->
torch
.
Tensor
|
list
[
torch
.
Tensor
]:
pooled_data_lst
=
self
.
pooling
(
hidden_states
,
pooling_metadata
)
prompt_token_ids
=
get_prompt_token_ids
(
pooling_metadata
)
pooled_data
=
list
[
torch
.
Tensor
]()
pooling_params
=
get_pooling_params
(
pooling_metadata
)
for
data
,
token_id
,
pooling_param
in
zip
(
pooled_data_lst
,
prompt_token_ids
,
pooling_params
):
step_tag_id
=
pooling_param
.
step_tag_id
returned_token_ids
=
pooling_param
.
returned_token_ids
if
returned_token_ids
is
not
None
and
len
(
returned_token_ids
)
>
0
:
data
=
data
[:,
returned_token_ids
]
if
step_tag_id
is
not
None
:
data
=
data
[
token_id
==
step_tag_id
]
pooled_data
.
append
(
data
)
return
pooled_data
def
get_supported_tasks
(
self
)
->
Set
[
PoolingTask
]:
return
{
"token_embed"
,
"token_classify"
}
def
get_pooling_updates
(
self
,
task
:
PoolingTask
)
->
PoolingParamsUpdate
:
return
PoolingParamsUpdate
(
requires_token_ids
=
True
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
|
list
[
torch
.
Tensor
],
pooling_metadata
:
PoolingMetadata
,
)
->
PoolerOutput
:
pooled_data
=
self
.
extract_states
(
hidden_states
,
pooling_metadata
)
pooling_params
=
get_pooling_params
(
pooling_metadata
)
assert
len
(
pooled_data
)
==
len
(
pooling_params
)
pooled_data
=
[
self
.
head
(
d
,
p
)
for
d
,
p
in
zip
(
pooled_data
,
pooling_params
)]
return
pooled_data
class
DispatchPooler
(
Pooler
):
class
DispatchPooler
(
Pooler
):
"""Dispatches calls to a sub-pooler based on the pooling task."""
"""Dispatches calls to a sub-pooler based on the pooling task."""
...
...
vllm/model_executor/models/adapters.py
View file @
f54f8512
...
@@ -250,7 +250,7 @@ def as_embedding_model(cls: _T) -> _T:
...
@@ -250,7 +250,7 @@ def as_embedding_model(cls: _T) -> _T:
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"
encode
"
:
Pooler
.
for_
encode
(
pooler_config
),
"
token_embed
"
:
Pooler
.
for_
token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
},
},
)
)
...
@@ -279,11 +279,8 @@ def as_seq_cls_model(cls: _T) -> _T:
...
@@ -279,11 +279,8 @@ def as_seq_cls_model(cls: _T) -> _T:
# Lazy import
# Lazy import
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.linear
import
ReplicatedLinear
from
vllm.model_executor.layers.pooler
import
(
from
vllm.model_executor.layers.pooler
import
(
ClassifierPooler
,
DispatchPooler
,
DispatchPooler
,
Pooler
,
Pooler
,
PoolingMethod
,
PoolingType
,
)
)
from
vllm.model_executor.models.interfaces
import
SupportsCrossEncoding
from
vllm.model_executor.models.interfaces
import
SupportsCrossEncoding
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -302,42 +299,29 @@ def as_seq_cls_model(cls: _T) -> _T:
...
@@ -302,42 +299,29 @@ def as_seq_cls_model(cls: _T) -> _T:
model_config
.
hidden_size
,
model_config
.
hidden_size
,
config
.
num_labels
,
config
.
num_labels
,
bias
=
False
,
bias
=
False
,
params_dtype
=
torch
.
float32
,
params_dtype
=
vllm_config
.
model_config
.
head_dtype
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
return_bias
=
False
,
prefix
=
maybe_prefix
(
prefix
,
"score"
),
prefix
=
maybe_prefix
(
prefix
,
"score"
),
)
)
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
pooling_type_str
=
pooler_config
.
pooling_type
assert
pooling_type_str
is
not
None
pooling_type
=
PoolingType
[
pooling_type_str
]
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
"token_classify"
:
Pooler
.
for_token_classify
(
"classify"
:
ClassifierPooler
(
pooler_config
,
classifier
=
self
.
score
pooling
=
PoolingMethod
.
from_pooling_type
(
pooling_type
),
classifier
=
self
.
_classifier
,
act_fn
=
ClassifierPooler
.
act_fn_for_seq_cls
(
vllm_config
.
model_config
),
),
),
"score"
:
ClassifierPooler
(
"classify"
:
Pooler
.
for_classify
(
pooling
=
PoolingMethod
.
from_pooling_type
(
pooling_type
),
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"classify"
classifier
=
self
.
_classifier
,
act_fn
=
ClassifierPooler
.
act_fn_for_cross_encoder
(
vllm_config
.
model_config
),
),
"score"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"score"
),
),
}
}
)
)
def
_classifier
(
self
,
x
:
torch
.
Tensor
):
x
,
_
=
self
.
score
(
x
.
float
())
return
x
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
...
@@ -393,7 +377,11 @@ def as_reward_model(cls: _T) -> _T:
...
@@ -393,7 +377,11 @@ def as_reward_model(cls: _T) -> _T:
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
"encode"
:
Pooler
.
for_encode
(
pooler_config
)},
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
=
pooler_config
)
}
)
)
ModelForReward
.
__name__
=
_get_pooling_model_name
(
cls
.
__name__
,
"ForReward"
)
ModelForReward
.
__name__
=
_get_pooling_model_name
(
cls
.
__name__
,
"ForReward"
)
...
...
vllm/model_executor/models/bert.py
View file @
f54f8512
...
@@ -521,7 +521,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
...
@@ -521,7 +521,7 @@ class BertEmbeddingModel(nn.Module, SupportsQuant):
def
_build_pooler
(
self
,
pooler_config
:
PoolerConfig
)
->
Pooler
:
def
_build_pooler
(
self
,
pooler_config
:
PoolerConfig
)
->
Pooler
:
return
DispatchPooler
(
return
DispatchPooler
(
{
{
"
encode
"
:
Pooler
.
for_
encode
(
pooler_config
),
"
token_embed
"
:
Pooler
.
for_
token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
}
}
)
)
...
@@ -724,7 +724,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
...
@@ -724,7 +724,7 @@ class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
return
DispatchPooler
(
return
DispatchPooler
(
{
{
"
encode
"
:
Pooler
.
for_
encode
(
pooler_config
),
"
token_embed
"
:
Pooler
.
for_
token_embed
(
pooler_config
),
"embed"
:
SPLADESparsePooler
(
"embed"
:
SPLADESparsePooler
(
mlm_head
=
self
.
mlm_head
,
mlm_head
=
self
.
mlm_head
,
cls_token_id
=
cls_id
,
cls_token_id
=
cls_id
,
...
@@ -821,20 +821,16 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
...
@@ -821,20 +821,16 @@ class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQu
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
classifier
),
"classify"
:
ClassifierPooler
(
"classify"
:
ClassifierPooler
(
pooling
=
self
.
bert
.
pooler
,
pooling
=
self
.
bert
.
pooler
,
classifier
=
self
.
classifier
,
classifier
=
self
.
classifier
,
act_fn
=
ClassifierPooler
.
act_fn_for_seq_cls
(
act_fn
=
"classify"
,
vllm_config
.
model_config
),
),
),
"score"
:
ClassifierPooler
(
"score"
:
ClassifierPooler
(
pooling
=
self
.
bert
.
pooler
,
pooling
=
self
.
bert
.
pooler
,
classifier
=
self
.
classifier
,
act_fn
=
"score"
classifier
=
self
.
classifier
,
act_fn
=
ClassifierPooler
.
act_fn_for_cross_encoder
(
vllm_config
.
model_config
),
),
),
}
}
)
)
...
@@ -891,7 +887,9 @@ class BertForTokenClassification(nn.Module):
...
@@ -891,7 +887,9 @@ class BertForTokenClassification(nn.Module):
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
=
pooler_config
),
}
}
)
)
...
...
vllm/model_executor/models/bert_with_rope.py
View file @
f54f8512
...
@@ -695,20 +695,16 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -695,20 +695,16 @@ class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
classifier
),
"classify"
:
ClassifierPooler
(
"classify"
:
ClassifierPooler
(
pooling
=
self
.
new
.
pooler
,
pooling
=
self
.
new
.
pooler
,
classifier
=
self
.
classifier
,
classifier
=
self
.
classifier
,
act_fn
=
ClassifierPooler
.
act_fn_for_seq_cls
(
act_fn
=
"classify"
,
vllm_config
.
model_config
),
),
),
"score"
:
ClassifierPooler
(
"score"
:
ClassifierPooler
(
pooling
=
self
.
new
.
pooler
,
pooling
=
self
.
new
.
pooler
,
classifier
=
self
.
classifier
,
act_fn
=
"score"
classifier
=
self
.
classifier
,
act_fn
=
ClassifierPooler
.
act_fn_for_cross_encoder
(
vllm_config
.
model_config
),
),
),
}
}
)
)
...
...
vllm/model_executor/models/clip.py
View file @
f54f8512
...
@@ -837,7 +837,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
...
@@ -837,7 +837,7 @@ class CLIPEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"
encode
"
:
Pooler
.
for_
encode
(
pooler_config
),
"
token_embed
"
:
Pooler
.
for_
token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
}
}
)
)
...
...
vllm/model_executor/models/gpt2.py
View file @
f54f8512
...
@@ -353,8 +353,15 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -353,8 +353,15 @@ class GPT2ForSequenceClassification(nn.Module, SupportsCrossEncoding):
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
"token_classify"
:
Pooler
.
for_token_classify
(
"classify"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
),
pooler_config
,
classifier
=
self
.
score
),
"classify"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"classify"
),
"score"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"score"
),
}
}
)
)
...
...
vllm/model_executor/models/gritlm.py
View file @
f54f8512
...
@@ -239,7 +239,7 @@ class GritLM(LlamaForCausalLM):
...
@@ -239,7 +239,7 @@ class GritLM(LlamaForCausalLM):
if
pooler_config
is
not
None
:
if
pooler_config
is
not
None
:
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"
encode
"
:
Pooler
.
for_
encode
(
pooler_config
),
"
token_embed
"
:
Pooler
.
for_
token_embed
(
pooler_config
),
"embed"
:
GritLMPooler
(
vllm_config
.
model_config
),
"embed"
:
GritLMPooler
(
vllm_config
.
model_config
),
}
}
)
)
vllm/model_executor/models/internlm2.py
View file @
f54f8512
...
@@ -444,7 +444,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
...
@@ -444,7 +444,7 @@ class InternLM2ForRewardModel(InternLM2ForCausalLM):
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
"
encode"
:
Pooler
.
for_encode
(
pooler_config
)}
,
{
"
token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
)}
)
)
def
forward
(
def
forward
(
...
...
vllm/model_executor/models/jamba.py
View file @
f54f8512
...
@@ -604,10 +604,14 @@ class JambaForSequenceClassification(JambaForCausalLM):
...
@@ -604,10 +604,14 @@ class JambaForSequenceClassification(JambaForCausalLM):
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
,
classifier
=
self
.
score
),
"classify"
:
Pooler
.
for_classify
(
"classify"
:
Pooler
.
for_classify
(
pooler_config
,
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"classify"
classifier
=
self
.
score
,
),
"score"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"score"
),
),
}
}
)
)
vllm/model_executor/models/jina_vl.py
View file @
f54f8512
...
@@ -97,9 +97,15 @@ class JinaVLForSequenceClassification(
...
@@ -97,9 +97,15 @@ class JinaVLForSequenceClassification(
self
.
score
=
JinaVLScorer
(
vllm_config
.
model_config
)
self
.
score
=
JinaVLScorer
(
vllm_config
.
model_config
)
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
"token_classify"
:
Pooler
.
for_token_classify
(
"classify"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
),
pooler_config
,
classifier
=
self
.
score
"score"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
),
),
"classify"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"classify"
),
"score"
:
Pooler
.
for_classify
(
pooler_config
,
classifier
=
self
.
score
,
act_fn
=
"score"
),
}
}
)
)
...
...
vllm/model_executor/models/modernbert.py
View file @
f54f8512
...
@@ -322,20 +322,14 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -322,20 +322,14 @@ class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
"token_classify"
:
Pooler
.
for_token_classify
(
"classify"
:
ClassifierPooler
(
pooler_config
,
classifier
=
self
.
classifier
pooling
=
self
.
pooling
,
classifier
=
self
.
classifier
,
act_fn
=
ClassifierPooler
.
act_fn_for_seq_cls
(
vllm_config
.
model_config
),
),
"classify"
:
ClassifierPooler
(
pooling
=
self
.
pooling
,
classifier
=
self
.
classifier
,
act_fn
=
"classify"
),
),
"score"
:
ClassifierPooler
(
"score"
:
ClassifierPooler
(
pooling
=
self
.
pooling
,
pooling
=
self
.
pooling
,
classifier
=
self
.
classifier
,
act_fn
=
"score"
classifier
=
self
.
classifier
,
act_fn
=
ClassifierPooler
.
act_fn_for_cross_encoder
(
vllm_config
.
model_config
),
),
),
}
}
)
)
...
@@ -421,7 +415,9 @@ class ModernBertForTokenClassification(nn.Module):
...
@@ -421,7 +415,9 @@ class ModernBertForTokenClassification(nn.Module):
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
=
pooler_config
),
}
}
)
)
...
...
vllm/model_executor/models/qwen2_rm.py
View file @
f54f8512
...
@@ -107,7 +107,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
...
@@ -107,7 +107,7 @@ class Qwen2ForRewardModel(Qwen2RewardBaseModel):
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
"
encode"
:
Pooler
.
for_encode
(
pooler_config
)}
,
{
"
token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
)}
)
)
...
@@ -120,4 +120,6 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
...
@@ -120,4 +120,6 @@ class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
pooler_config
=
vllm_config
.
model_config
.
pooler_config
pooler_config
=
vllm_config
.
model_config
.
pooler_config
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
({
"encode"
:
Pooler
.
for_encode
(
pooler_config
)})
self
.
pooler
=
DispatchPooler
(
{
"token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
)}
)
vllm/model_executor/models/roberta.py
View file @
f54f8512
...
@@ -105,15 +105,7 @@ class RobertaClassificationHead(nn.Module):
...
@@ -105,15 +105,7 @@ class RobertaClassificationHead(nn.Module):
@
default_pooling_type
(
"CLS"
)
@
default_pooling_type
(
"CLS"
)
class
RobertaEmbeddingModel
(
BertEmbeddingModel
):
class
RobertaEmbeddingModel
(
BertEmbeddingModel
):
"""A model that uses Roberta to provide embedding functionalities.
"""A model that uses Roberta to provide embedding functionalities."""
This class encapsulates the BertModel and provides an interface for
embedding operations and customized pooling functions.
Attributes:
model: An instance of BertModel used for forward operations.
_pooler: An instance of Pooler used for pooling operations.
"""
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
""
):
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
super
().
__init__
(
vllm_config
=
vllm_config
,
prefix
=
prefix
)
...
@@ -212,20 +204,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
...
@@ -212,20 +204,14 @@ class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding):
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
"token_classify"
:
Pooler
.
for_token_classify
(
"classify"
:
ClassifierPooler
(
pooler_config
=
pooler_config
,
classifier
=
self
.
classifier
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
act_fn
=
ClassifierPooler
.
act_fn_for_seq_cls
(
vllm_config
.
model_config
),
),
"classify"
:
ClassifierPooler
(
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
act_fn
=
"classify"
),
),
"score"
:
ClassifierPooler
(
"score"
:
ClassifierPooler
(
pooling
=
CLSPool
(),
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
act_fn
=
"score"
classifier
=
self
.
classifier
,
act_fn
=
ClassifierPooler
.
act_fn_for_cross_encoder
(
vllm_config
.
model_config
),
),
),
}
}
)
)
...
...
vllm/model_executor/models/terratorch.py
View file @
f54f8512
...
@@ -250,7 +250,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
...
@@ -250,7 +250,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
assert
pooler_config
is
not
None
assert
pooler_config
is
not
None
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
"
encode"
:
Pooler
.
for_encode
(
pooler_config
)}
,
{
"
token_classify"
:
Pooler
.
for_token_classify
(
pooler_config
)}
)
)
def
get_input_embeddings
(
def
get_input_embeddings
(
...
...
vllm/model_executor/models/transformers_pooling.py
View file @
f54f8512
...
@@ -135,7 +135,7 @@ class TransformersEmbeddingModel(TransformersPoolingBase):
...
@@ -135,7 +135,7 @@ class TransformersEmbeddingModel(TransformersPoolingBase):
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"
encode
"
:
Pooler
.
for_
encode
(
pooler_config
),
"
token_embed
"
:
Pooler
.
for_
token_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
"embed"
:
Pooler
.
for_embed
(
pooler_config
),
}
}
)
)
...
@@ -190,20 +190,14 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
...
@@ -190,20 +190,14 @@ class TransformersForSequenceClassification(TransformersPoolingBase):
self
.
pooler
=
DispatchPooler
(
self
.
pooler
=
DispatchPooler
(
{
{
"encode"
:
Pooler
.
for_encode
(
pooler_config
),
"token_classify"
:
Pooler
.
for_token_classify
(
"classify"
:
ClassifierPooler
(
pooler_config
,
classifier
=
self
.
classifier
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
act_fn
=
ClassifierPooler
.
act_fn_for_seq_cls
(
vllm_config
.
model_config
),
),
"classify"
:
ClassifierPooler
(
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
act_fn
=
"classify"
),
),
"score"
:
ClassifierPooler
(
"score"
:
ClassifierPooler
(
pooling
=
CLSPool
(),
pooling
=
CLSPool
(),
classifier
=
self
.
classifier
,
act_fn
=
"score"
classifier
=
self
.
classifier
,
act_fn
=
ClassifierPooler
.
act_fn_for_cross_encoder
(
vllm_config
.
model_config
),
),
),
}
}
)
)
...
...
vllm/pooling_params.py
View file @
f54f8512
...
@@ -10,7 +10,7 @@ from vllm.sampling_params import RequestOutputKind
...
@@ -10,7 +10,7 @@ from vllm.sampling_params import RequestOutputKind
from
vllm.tasks
import
PoolingTask
from
vllm.tasks
import
PoolingTask
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
,
PoolerConfig
class
PoolingParams
(
class
PoolingParams
(
...
@@ -30,7 +30,6 @@ class PoolingParams(
...
@@ -30,7 +30,6 @@ class PoolingParams(
if model support matryoshka representation.
if model support matryoshka representation.
activation: Whether to apply activation function to
activation: Whether to apply activation function to
the classification outputs.
the classification outputs.
softmax: Whether to apply softmax to the reward outputs.
"""
"""
# --8<-- [start:common-pooling-params]
# --8<-- [start:common-pooling-params]
...
@@ -48,32 +47,19 @@ class PoolingParams(
...
@@ -48,32 +47,19 @@ class PoolingParams(
activation
:
bool
|
None
=
None
activation
:
bool
|
None
=
None
# --8<-- [end:classification-pooling-params]
# --8<-- [end:classification-pooling-params]
## for reward models
## for step pooling models
softmax
:
bool
|
None
=
None
step_tag_id
:
int
|
None
=
None
step_tag_id
:
int
|
None
=
None
returned_token_ids
:
list
[
int
]
|
None
=
None
returned_token_ids
:
list
[
int
]
|
None
=
None
## Internal use only
task
:
PoolingTask
|
None
=
None
task
:
PoolingTask
|
None
=
None
"""Internal use only."""
requires_token_ids
:
bool
=
False
requires_token_ids
:
bool
=
False
"""Internal use only."""
extra_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
extra_kwargs
:
dict
[
str
,
Any
]
|
None
=
None
"""Internal use only."""
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
FINAL_ONLY
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
FINAL_ONLY
@
property
@
property
def
all_parameters
(
self
)
->
list
[
str
]:
def
all_parameters
(
self
)
->
list
[
str
]:
return
[
return
[
"dimensions"
,
"normalize"
,
"activation"
]
"dimensions"
,
"normalize"
,
"activation"
,
"softmax"
,
"step_tag_id"
,
"returned_token_ids"
,
]
@
property
@
property
def
valid_parameters
(
self
):
def
valid_parameters
(
self
):
...
@@ -81,7 +67,8 @@ class PoolingParams(
...
@@ -81,7 +67,8 @@ class PoolingParams(
"embed"
:
[
"dimensions"
,
"normalize"
],
"embed"
:
[
"dimensions"
,
"normalize"
],
"classify"
:
[
"activation"
],
"classify"
:
[
"activation"
],
"score"
:
[
"activation"
],
"score"
:
[
"activation"
],
"encode"
:
[
"softmax"
,
"step_tag_id"
,
"returned_token_ids"
],
"token_embed"
:
[
"dimensions"
,
"normalize"
],
"token_classify"
:
[
"activation"
],
}
}
def
clone
(
self
)
->
"PoolingParams"
:
def
clone
(
self
)
->
"PoolingParams"
:
...
@@ -100,7 +87,6 @@ class PoolingParams(
...
@@ -100,7 +87,6 @@ class PoolingParams(
# NOTE: Task validation needs to done against the model instance,
# NOTE: Task validation needs to done against the model instance,
# which is not available in model config. So, it's not included
# which is not available in model config. So, it's not included
# in this method
# in this method
self
.
_merge_default_parameters
(
model_config
)
self
.
_merge_default_parameters
(
model_config
)
self
.
_set_default_parameters
(
model_config
)
self
.
_set_default_parameters
(
model_config
)
self
.
_verify_valid_parameters
()
self
.
_verify_valid_parameters
()
...
@@ -125,8 +111,34 @@ class PoolingParams(
...
@@ -125,8 +111,34 @@ class PoolingParams(
if
getattr
(
self
,
k
,
None
)
is
None
:
if
getattr
(
self
,
k
,
None
)
is
None
:
setattr
(
self
,
k
,
getattr
(
pooler_config
,
k
))
setattr
(
self
,
k
,
getattr
(
pooler_config
,
k
))
self
.
_verify_step_pooling
(
pooler_config
,
valid_parameters
)
def
_verify_step_pooling
(
self
,
pooler_config
:
"PoolerConfig"
,
valid_parameters
:
list
[
str
]
):
step_pooling_parameters
=
[
"step_tag_id"
,
"returned_token_ids"
]
if
pooler_config
.
pooling_type
!=
"STEP"
:
invalid_parameters
=
[]
for
k
in
step_pooling_parameters
:
if
getattr
(
self
,
k
,
None
)
is
not
None
:
invalid_parameters
.
append
(
k
)
if
invalid_parameters
:
raise
ValueError
(
f
"Task
{
self
.
task
}
only supports
{
valid_parameters
}
"
f
"parameters, does not support "
f
"
{
invalid_parameters
}
parameters"
)
else
:
for
k
in
step_pooling_parameters
:
if
getattr
(
pooler_config
,
k
,
None
)
is
None
:
continue
if
getattr
(
self
,
k
,
None
)
is
None
:
setattr
(
self
,
k
,
getattr
(
pooler_config
,
k
))
def
_set_default_parameters
(
self
,
model_config
:
Optional
[
"ModelConfig"
]):
def
_set_default_parameters
(
self
,
model_config
:
Optional
[
"ModelConfig"
]):
if
self
.
task
==
"
embed"
:
if
self
.
task
in
[
"embed"
,
"token_
embed"
]
:
if
self
.
normalize
is
None
:
if
self
.
normalize
is
None
:
self
.
normalize
=
True
self
.
normalize
=
True
...
@@ -150,13 +162,9 @@ class PoolingParams(
...
@@ -150,13 +162,9 @@ class PoolingParams(
elif
self
.
dimensions
<
1
:
elif
self
.
dimensions
<
1
:
raise
ValueError
(
"Dimensions must be greater than 0"
)
raise
ValueError
(
"Dimensions must be greater than 0"
)
elif
self
.
task
in
[
"classify"
,
"score"
]:
elif
self
.
task
in
[
"classify"
,
"score"
,
"token_classify"
]:
if
self
.
activation
is
None
:
if
self
.
activation
is
None
:
self
.
activation
=
True
self
.
activation
=
True
elif
self
.
task
==
"encode"
:
if
self
.
softmax
is
None
:
self
.
softmax
=
True
else
:
else
:
raise
ValueError
(
f
"Unknown pooling task:
{
self
.
task
}
"
)
raise
ValueError
(
f
"Unknown pooling task:
{
self
.
task
}
"
)
...
@@ -185,7 +193,6 @@ class PoolingParams(
...
@@ -185,7 +193,6 @@ class PoolingParams(
f
"normalize=
{
self
.
normalize
}
, "
f
"normalize=
{
self
.
normalize
}
, "
f
"dimensions=
{
self
.
dimensions
}
, "
f
"dimensions=
{
self
.
dimensions
}
, "
f
"activation=
{
self
.
activation
}
, "
f
"activation=
{
self
.
activation
}
, "
f
"softmax=
{
self
.
softmax
}
, "
f
"step_tag_id=
{
self
.
step_tag_id
}
, "
f
"step_tag_id=
{
self
.
step_tag_id
}
, "
f
"returned_token_ids=
{
self
.
returned_token_ids
}
, "
f
"returned_token_ids=
{
self
.
returned_token_ids
}
, "
f
"requires_token_ids=
{
self
.
requires_token_ids
}
, "
f
"requires_token_ids=
{
self
.
requires_token_ids
}
, "
...
...
vllm/tasks.py
View file @
f54f8512
...
@@ -5,7 +5,7 @@ from typing import Literal, get_args
...
@@ -5,7 +5,7 @@ from typing import Literal, get_args
GenerationTask
=
Literal
[
"generate"
,
"transcription"
]
GenerationTask
=
Literal
[
"generate"
,
"transcription"
]
GENERATION_TASKS
=
get_args
(
GenerationTask
)
GENERATION_TASKS
=
get_args
(
GenerationTask
)
PoolingTask
=
Literal
[
"encode"
,
"embed"
,
"classify"
,
"score"
]
PoolingTask
=
Literal
[
"embed"
,
"classify"
,
"score"
,
"token_embed"
,
"token_classify"
]
POOLING_TASKS
=
get_args
(
PoolingTask
)
POOLING_TASKS
=
get_args
(
PoolingTask
)
SupportedTask
=
Literal
[
GenerationTask
,
PoolingTask
]
SupportedTask
=
Literal
[
GenerationTask
,
PoolingTask
]
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment