Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5eb36575
Commit
5eb36575
authored
Apr 26, 2026
by
khluu
Browse files
Revert "[Frontend] Remove frontend pooling multi task support. (#37861)"
This reverts commit
d2e2e856
.
parent
4d51588e
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
40 additions
and
55 deletions
+40
-55
vllm/entrypoints/pooling/factories.py
vllm/entrypoints/pooling/factories.py
+24
-33
vllm/entrypoints/pooling/pooling/serving.py
vllm/entrypoints/pooling/pooling/serving.py
+6
-4
vllm/entrypoints/pooling/scoring/serving.py
vllm/entrypoints/pooling/scoring/serving.py
+1
-7
vllm/entrypoints/pooling/utils.py
vllm/entrypoints/pooling/utils.py
+2
-6
vllm/pooling_params.py
vllm/pooling_params.py
+7
-0
vllm/tasks.py
vllm/tasks.py
+0
-5
No files found.
vllm/entrypoints/pooling/factories.py
View file @
5eb36575
...
@@ -10,7 +10,7 @@ from vllm.entrypoints.chat_utils import ChatTemplateConfig
...
@@ -10,7 +10,7 @@ from vllm.entrypoints.chat_utils import ChatTemplateConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.plugins.io_processors
import
has_io_processor
from
vllm.plugins.io_processors
import
has_io_processor
from
vllm.renderers
import
BaseRenderer
from
vllm.renderers
import
BaseRenderer
from
vllm.tasks
import
POOLING_TASKS
,
SCORE_TYPE_MAP
,
SupportedTask
from
vllm.tasks
import
POOLING_TASKS
,
SupportedTask
from
.base.io_processor
import
PoolingIOProcessor
from
.base.io_processor
import
PoolingIOProcessor
from
.utils
import
enable_scoring_api
from
.utils
import
enable_scoring_api
...
@@ -43,24 +43,23 @@ def init_pooling_io_processors(
...
@@ -43,24 +43,23 @@ def init_pooling_io_processors(
)
->
dict
[
str
,
PoolingIOProcessor
]:
)
->
dict
[
str
,
PoolingIOProcessor
]:
model_config
=
vllm_config
.
model_config
model_config
=
vllm_config
.
model_config
processors
:
dict
[
str
,
type
[
PoolingIOProcessor
]]
=
{}
processors
:
dict
[
str
,
type
[
PoolingIOProcessor
]]
=
{}
pooling_task
=
model_config
.
get_pooling_task
(
supported_tasks
)
if
pooling_task
==
"classify"
:
if
"classify"
in
supported_tasks
:
from
.classify.io_processor
import
ClassifyIOProcessor
from
.classify.io_processor
import
ClassifyIOProcessor
processors
[
"classify"
]
=
ClassifyIOProcessor
processors
[
"classify"
]
=
ClassifyIOProcessor
if
pooling_task
==
"token_classify"
:
if
"token_classify"
in
supported_tasks
:
from
.classify.io_processor
import
TokenClassifyIOProcessor
from
.classify.io_processor
import
TokenClassifyIOProcessor
processors
[
"token_classify"
]
=
TokenClassifyIOProcessor
processors
[
"token_classify"
]
=
TokenClassifyIOProcessor
if
pooling_task
==
"embed"
:
if
"embed"
in
supported_tasks
:
from
.embed.io_processor
import
EmbedIOProcessor
from
.embed.io_processor
import
EmbedIOProcessor
processors
[
"embed"
]
=
EmbedIOProcessor
processors
[
"embed"
]
=
EmbedIOProcessor
if
pooling_task
==
"token_embed"
:
if
"token_embed"
in
supported_tasks
:
from
.embed.io_processor
import
TokenEmbedIOProcessor
from
.embed.io_processor
import
TokenEmbedIOProcessor
processors
[
"token_embed"
]
=
TokenEmbedIOProcessor
processors
[
"token_embed"
]
=
TokenEmbedIOProcessor
...
@@ -72,15 +71,15 @@ def init_pooling_io_processors(
...
@@ -72,15 +71,15 @@ def init_pooling_io_processors(
from
.pooling.io_processor
import
PluginWithIOProcessorPlugins
from
.pooling.io_processor
import
PluginWithIOProcessorPlugins
processors
[
"plugin"
]
=
PluginWithIOProcessorPlugins
processors
[
"plugin"
]
=
PluginWithIOProcessorPlugins
elif
pooling_task
==
"plugin"
:
elif
"plugin"
in
supported_tasks
:
from
.pooling.io_processor
import
PluginWithoutIOProcessorPlugins
from
.pooling.io_processor
import
PluginWithoutIOProcessorPlugins
processors
[
"plugin"
]
=
PluginWithoutIOProcessorPlugins
processors
[
"plugin"
]
=
PluginWithoutIOProcessorPlugins
if
enable_scoring_api
(
supported_tasks
,
model_config
):
if
enable_scoring_api
(
supported_tasks
,
model_config
):
score_type
=
model_config
.
score_type
from
.scoring.io_processor
import
ScoringIOProcessors
from
.scoring.io_processor
import
ScoringIOProcessors
score_type
:
str
|
None
=
SCORE_TYPE_MAP
.
get
(
pooling_task
,
None
)
# type: ignore[arg-type]
if
score_type
is
not
None
and
score_type
in
ScoringIOProcessors
:
if
score_type
is
not
None
and
score_type
in
ScoringIOProcessors
:
processors
[
score_type
]
=
ScoringIOProcessors
[
score_type
]
processors
[
score_type
]
=
ScoringIOProcessors
[
score_type
]
...
@@ -141,10 +140,6 @@ def init_pooling_state(
...
@@ -141,10 +140,6 @@ def init_pooling_state(
request_logger
:
RequestLogger
|
None
,
request_logger
:
RequestLogger
|
None
,
supported_tasks
:
tuple
[
"SupportedTask"
,
...],
supported_tasks
:
tuple
[
"SupportedTask"
,
...],
):
):
model_config
=
engine_client
.
model_config
if
model_config
is
None
:
return
from
vllm.entrypoints.chat_utils
import
load_chat_template
from
vllm.entrypoints.chat_utils
import
load_chat_template
from
vllm.tasks
import
POOLING_TASKS
from
vllm.tasks
import
POOLING_TASKS
...
@@ -153,14 +148,8 @@ def init_pooling_state(
...
@@ -153,14 +148,8 @@ def init_pooling_state(
from
.pooling.serving
import
ServingPooling
from
.pooling.serving
import
ServingPooling
from
.scoring.serving
import
ServingScores
from
.scoring.serving
import
ServingScores
model_config
=
engine_client
.
model_config
resolved_chat_template
=
load_chat_template
(
args
.
chat_template
)
resolved_chat_template
=
load_chat_template
(
args
.
chat_template
)
pooling_task
=
model_config
.
get_pooling_task
(
supported_tasks
)
chat_template_config
=
ChatTemplateConfig
(
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
trust_request_chat_template
=
args
.
trust_request_chat_template
,
)
state
.
serving_pooling
=
(
state
.
serving_pooling
=
(
(
(
...
@@ -169,7 +158,9 @@ def init_pooling_state(
...
@@ -169,7 +158,9 @@ def init_pooling_state(
state
.
openai_serving_models
,
state
.
openai_serving_models
,
supported_tasks
=
supported_tasks
,
supported_tasks
=
supported_tasks
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template_config
=
chat_template_config
,
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
trust_request_chat_template
=
args
.
trust_request_chat_template
,
)
)
)
)
if
any
(
t
in
supported_tasks
for
t
in
POOLING_TASKS
)
if
any
(
t
in
supported_tasks
for
t
in
POOLING_TASKS
)
...
@@ -180,9 +171,11 @@ def init_pooling_state(
...
@@ -180,9 +171,11 @@ def init_pooling_state(
engine_client
,
engine_client
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template_config
=
chat_template_config
,
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
trust_request_chat_template
=
args
.
trust_request_chat_template
,
)
)
if
pooling_task
==
"embed"
if
"embed"
in
supported_tasks
else
None
else
None
)
)
state
.
serving_classification
=
(
state
.
serving_classification
=
(
...
@@ -190,18 +183,21 @@ def init_pooling_state(
...
@@ -190,18 +183,21 @@ def init_pooling_state(
engine_client
,
engine_client
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template_config
=
chat_template_config
,
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
trust_request_chat_template
=
args
.
trust_request_chat_template
,
)
)
if
pooling_task
==
"classify"
if
"classify"
in
supported_tasks
else
None
else
None
)
)
state
.
serving_scores
=
(
state
.
serving_scores
=
(
ServingScores
(
ServingScores
(
engine_client
,
engine_client
,
state
.
openai_serving_models
,
state
.
openai_serving_models
,
supported_tasks
=
supported_tasks
,
request_logger
=
request_logger
,
request_logger
=
request_logger
,
chat_template_config
=
chat_template_config
,
chat_template
=
resolved_chat_template
,
chat_template_content_format
=
args
.
chat_template_content_format
,
trust_request_chat_template
=
args
.
trust_request_chat_template
,
enable_flash_late_interaction
=
getattr
(
enable_flash_late_interaction
=
getattr
(
args
,
"enable_flash_late_interaction"
,
True
args
,
"enable_flash_late_interaction"
,
True
),
),
...
@@ -218,12 +214,7 @@ def get_pooling_invocation_types(
...
@@ -218,12 +214,7 @@ def get_pooling_invocation_types(
# NOTE: Items defined earlier take higher priority
# NOTE: Items defined earlier take higher priority
invocation_types
:
list
[
tuple
[
RequestType
,
tuple
[
GetHandlerFn
,
EndpointFn
]]]
=
[]
invocation_types
:
list
[
tuple
[
RequestType
,
tuple
[
GetHandlerFn
,
EndpointFn
]]]
=
[]
if
model_config
is
None
:
if
"embed"
in
supported_tasks
:
return
invocation_types
pooling_task
=
model_config
.
get_pooling_task
(
supported_tasks
)
if
pooling_task
==
"embed"
:
from
.embed.api_router
import
create_embedding
,
embedding
from
.embed.api_router
import
create_embedding
,
embedding
from
.embed.protocol
import
EmbeddingRequest
from
.embed.protocol
import
EmbeddingRequest
...
@@ -231,7 +222,7 @@ def get_pooling_invocation_types(
...
@@ -231,7 +222,7 @@ def get_pooling_invocation_types(
(
EmbeddingRequest
,
(
embedding
,
create_embedding
)),
(
EmbeddingRequest
,
(
embedding
,
create_embedding
)),
]
]
if
pooling_task
==
"classify"
:
if
"classify"
in
supported_tasks
:
from
.classify.api_router
import
classify
,
create_classify
from
.classify.api_router
import
classify
,
create_classify
from
.classify.protocol
import
ClassificationRequest
from
.classify.protocol
import
ClassificationRequest
...
...
vllm/entrypoints/pooling/pooling/serving.py
View file @
5eb36575
...
@@ -78,15 +78,17 @@ class ServingPooling(PoolingServingBase):
...
@@ -78,15 +78,17 @@ class ServingPooling(PoolingServingBase):
# plugin task uses io_processor.parse_request to verify inputs
# plugin task uses io_processor.parse_request to verify inputs
if
pooling_task
!=
"plugin"
and
pooling_task
!=
self
.
pooling_task
:
if
pooling_task
!=
"plugin"
and
pooling_task
!=
self
.
pooling_task
:
if
pooling_task
not
in
self
.
supported_task
s
:
if
pooling_task
not
in
self
.
io_processor
s
:
raise
ValueError
(
raise
ValueError
(
f
"Unsupported task:
{
pooling_task
!
r
}
"
f
"Unsupported task:
{
pooling_task
!
r
}
"
f
"Supported tasks:
{
self
.
supported_tasks
}
"
f
"Supported tasks:
{
self
.
supported_tasks
}
"
)
)
else
:
else
:
raise
ValueError
(
logger
.
warning_once
(
"Try switching the model's pooling_task "
"Pooling multitask support is deprecated and will be removed "
f
"via --pooler-config.task
{
request
.
task
}
."
"in v0.20. When the default pooling task is not what you want, you "
"need to manually specify it via --pooler-config.task %s. "
,
pooling_task
,
)
)
if
pooling_task
==
"plugin"
and
"plugin"
not
in
self
.
io_processors
:
if
pooling_task
==
"plugin"
and
"plugin"
not
in
self
.
io_processors
:
...
...
vllm/entrypoints/pooling/scoring/serving.py
View file @
5eb36575
...
@@ -8,7 +8,6 @@ from vllm.engine.protocol import EngineClient
...
@@ -8,7 +8,6 @@ from vllm.engine.protocol import EngineClient
from
vllm.entrypoints.openai.engine.protocol
import
UsageInfo
from
vllm.entrypoints.openai.engine.protocol
import
UsageInfo
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.outputs
import
PoolingRequestOutput
,
ScoringRequestOutput
from
vllm.outputs
import
PoolingRequestOutput
,
ScoringRequestOutput
from
vllm.tasks
import
SCORE_TYPE_MAP
,
SupportedTask
from
vllm.v1.pool.late_interaction
import
(
from
vllm.v1.pool.late_interaction
import
(
build_late_interaction_doc_params
,
build_late_interaction_doc_params
,
build_late_interaction_query_params
,
build_late_interaction_query_params
,
...
@@ -39,15 +38,10 @@ class ServingScores(PoolingServing):
...
@@ -39,15 +38,10 @@ class ServingScores(PoolingServing):
self
,
self
,
engine_client
:
EngineClient
,
engine_client
:
EngineClient
,
*
args
,
*
args
,
supported_tasks
:
tuple
[
SupportedTask
,
...],
enable_flash_late_interaction
:
bool
=
True
,
enable_flash_late_interaction
:
bool
=
True
,
**
kwargs
,
**
kwargs
,
):
):
pooling_task
=
engine_client
.
model_config
.
get_pooling_task
(
supported_tasks
)
self
.
io_processor_name
:
str
=
engine_client
.
model_config
.
score_type
score_type
=
SCORE_TYPE_MAP
.
get
(
pooling_task
,
None
)
# type: ignore[arg-type]
assert
score_type
is
not
None
self
.
io_processor_name
:
str
=
score_type
self
.
enable_flash_late_interaction
=
(
self
.
enable_flash_late_interaction
=
(
self
.
io_processor_name
==
"late-interaction"
self
.
io_processor_name
==
"late-interaction"
and
enable_flash_late_interaction
and
enable_flash_late_interaction
...
...
vllm/entrypoints/pooling/utils.py
View file @
5eb36575
...
@@ -141,14 +141,10 @@ def enable_scoring_api(
...
@@ -141,14 +141,10 @@ def enable_scoring_api(
supported_tasks
:
tuple
[
"SupportedTask"
,
...],
supported_tasks
:
tuple
[
"SupportedTask"
,
...],
model_config
:
ModelConfig
|
None
=
None
,
model_config
:
ModelConfig
|
None
=
None
,
)
->
bool
:
)
->
bool
:
if
model_config
is
None
:
if
any
(
t
in
supported_tasks
for
t
in
(
"embed"
,
"token_embed"
)):
return
False
pooling_task
=
model_config
.
get_pooling_task
(
supported_tasks
)
if
pooling_task
in
(
"embed"
,
"token_embed"
):
return
True
return
True
if
pooling_task
==
"classify"
:
if
model_config
is
not
None
and
"classify"
in
supported_tasks
:
num_labels
=
getattr
(
model_config
.
hf_config
,
"num_labels"
,
0
)
num_labels
=
getattr
(
model_config
.
hf_config
,
"num_labels"
,
0
)
if
num_labels
!=
1
:
if
num_labels
!=
1
:
logger
.
debug_once
(
"Scoring API is only enabled for num_labels == 1."
)
logger
.
debug_once
(
"Scoring API is only enabled for num_labels == 1."
)
...
...
vllm/pooling_params.py
View file @
5eb36575
...
@@ -87,6 +87,13 @@ class PoolingParams(
...
@@ -87,6 +87,13 @@ class PoolingParams(
return
deepcopy
(
self
)
return
deepcopy
(
self
)
def
verify
(
self
,
model_config
:
ModelConfig
)
->
None
:
def
verify
(
self
,
model_config
:
ModelConfig
)
->
None
:
if
self
.
task
==
"score"
:
logger
.
warning_once
(
"`score` task is deprecated and will be removed in v0.20. "
"Please use `classify` instead."
)
self
.
task
=
"classify"
# plugin task uses io_processor.parse_request to verify inputs,
# plugin task uses io_processor.parse_request to verify inputs,
# skipping PoolingParams verify
# skipping PoolingParams verify
if
self
.
task
==
"plugin"
:
if
self
.
task
==
"plugin"
:
...
...
vllm/tasks.py
View file @
5eb36575
...
@@ -16,11 +16,6 @@ PoolingTask = Literal[
...
@@ -16,11 +16,6 @@ PoolingTask = Literal[
POOLING_TASKS
:
tuple
[
PoolingTask
,
...]
=
get_args
(
PoolingTask
)
POOLING_TASKS
:
tuple
[
PoolingTask
,
...]
=
get_args
(
PoolingTask
)
ScoreType
=
Literal
[
"bi-encoder"
,
"cross-encoder"
,
"late-interaction"
]
ScoreType
=
Literal
[
"bi-encoder"
,
"cross-encoder"
,
"late-interaction"
]
SCORE_TYPE_MAP
:
dict
[
PoolingTask
,
ScoreType
]
=
{
"embed"
:
"bi-encoder"
,
"classify"
:
"cross-encoder"
,
"token_embed"
:
"late-interaction"
,
}
FrontendTask
=
Literal
[
"render"
]
FrontendTask
=
Literal
[
"render"
]
FRONTEND_TASKS
:
tuple
[
FrontendTask
,
...]
=
get_args
(
FrontendTask
)
FRONTEND_TASKS
:
tuple
[
FrontendTask
,
...]
=
get_args
(
FrontendTask
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment