Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
45badd05
Unverified
Commit
45badd05
authored
Jul 18, 2025
by
Cyrus Leung
Committed by
GitHub
Jul 18, 2025
Browse files
[Core] Set pooling params based on task and model (#21128)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
4adc66f6
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
45 additions
and
3 deletions
+45
-3
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+13
-1
vllm/v1/worker/tpu_worker.py
vllm/v1/worker/tpu_worker.py
+4
-0
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+13
-1
vllm/worker/pooling_model_runner.py
vllm/worker/pooling_model_runner.py
+15
-1
No files found.
vllm/v1/worker/tpu_model_runner.py
View file @
45badd05
...
...
@@ -3,7 +3,7 @@
import
bisect
import
gc
import
time
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
cast
,
get_args
from
unittest.mock
import
patch
import
numpy
as
np
...
...
@@ -25,10 +25,12 @@ from vllm.logger import init_logger
from
vllm.lora.layers
import
BaseLayerWithLoRA
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.model_executor.model_loader.tpu
import
TPUModelLoader
from
vllm.model_executor.models.interfaces_base
import
is_pooling_model
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
BatchedTensorInputs
,
MultiModalKwargs
,
PlaceholderRange
)
from
vllm.multimodal.utils
import
group_mm_inputs_by_modality
from
vllm.pooling_params
import
PoolingTask
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
LayerBlockType
,
cdiv
,
is_pin_memory_available
,
prev_power_of_2
)
...
...
@@ -483,6 +485,16 @@ class TPUModelRunner(LoRAModelRunnerMixin):
def get_model(self) -> nn.Module:
    """Return the model object held by this runner.

    Returns:
        The value currently stored in ``self.model``.
    """
    loaded_model = self.model
    return loaded_model
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
    """List the pooling tasks that the current model's pooler accepts.

    Returns:
        An empty list when ``is_pooling_model`` rejects the model.
        Otherwise, every member of ``PoolingTask`` (enumerated through
        ``typing.get_args``) for which
        ``model.pooler.get_pooling_updates(task)`` returns a truthy value,
        in the order ``get_args`` yields them.
    """
    model = self.get_model()
    if not is_pooling_model(model):
        # Non-pooling models support no pooling task at all.
        return []

    supported_tasks = []
    for task in get_args(PoolingTask):
        # A falsy result (e.g. ``None``) marks the task as unsupported.
        if model.pooler.get_pooling_updates(task):
            supported_tasks.append(task)
    return supported_tasks
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
...
...
vllm/v1/worker/tpu_worker.py
View file @
45badd05
...
...
@@ -19,6 +19,7 @@ from vllm.logger import init_logger
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.platforms
import
current_platform
from
vllm.pooling_params
import
PoolingTask
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
cdiv
from
vllm.v1.attention.backends.pallas
import
TPU_HEAD_SIZE_ALIGNMENT
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -275,6 +276,9 @@ class TPUWorker:
def get_model(self) -> nn.Module:
    """Return the model exposed by this worker's model runner.

    Returns:
        Whatever ``self.model_runner.get_model()`` returns.
    """
    runner = self.model_runner
    return runner.get_model()
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
    """Report the pooling tasks supported by this worker's model runner.

    Returns:
        Whatever ``self.model_runner.get_supported_pooling_tasks()``
        returns; this method adds no filtering of its own.
    """
    runner = self.model_runner
    return runner.get_supported_pooling_tasks()
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
    """Return the KV cache specification produced by the model runner.

    Returns:
        Whatever ``self.model_runner.get_kv_cache_spec()`` returns,
        passed through unchanged.
    """
    runner = self.model_runner
    return runner.get_kv_cache_spec()
...
...
vllm/worker/model_runner_base.py
View file @
45badd05
...
...
@@ -4,7 +4,7 @@
import
dataclasses
from
abc
import
ABC
,
abstractmethod
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Optional
,
Type
,
TypeVar
)
TypeVar
,
get_args
)
import
torch
import
torch.nn
as
nn
...
...
@@ -12,6 +12,8 @@ import torch.nn as nn
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.models.interfaces_base
import
is_pooling_model
from
vllm.pooling_params
import
PoolingTask
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
if
TYPE_CHECKING
:
...
...
@@ -223,6 +225,16 @@ class ModelRunnerBase(ABC, Generic[T]):
def get_model(self) -> nn.Module:
    """Return the underlying model; this base version has no implementation.

    Raises:
        NotImplementedError: Always raised by this base version.
    """
    raise NotImplementedError()
def get_supported_pooling_tasks(self) -> list[PoolingTask]:
    """List the pooling tasks that the current model's pooler accepts.

    Returns:
        An empty list when ``is_pooling_model`` rejects the model returned
        by ``self.get_model()``. Otherwise, every member of ``PoolingTask``
        (enumerated through ``typing.get_args``) for which
        ``model.pooler.get_pooling_updates(task)`` returns a truthy value.
    """
    model = self.get_model()
    if not is_pooling_model(model):
        # Non-pooling models support no pooling task at all.
        return []

    accepted = []
    for candidate in get_args(PoolingTask):
        # A falsy result (e.g. ``None``) marks the task as unsupported.
        if model.pooler.get_pooling_updates(candidate):
            accepted.append(candidate)
    return accepted
def
execute_model
(
self
,
model_input
:
T
,
...
...
vllm/worker/pooling_model_runner.py
View file @
45badd05
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
dataclasses
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
,
Union
,
cast
import
torch
...
...
@@ -10,6 +10,7 @@ from vllm.config import VllmConfig
from
vllm.distributed
import
get_pp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.models.interfaces_base
import
VllmModelForPooling
from
vllm.model_executor.pooling_metadata
import
PoolingMetadata
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.pooling_params
import
PoolingParams
...
...
@@ -195,7 +196,20 @@ class PoolingModelRunner(
seq_groups
:
List
[
Tuple
[
List
[
int
],
PoolingParams
]]
=
[]
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
pooling_params
=
seq_group_metadata
.
pooling_params
assert
pooling_params
is
not
None
assert
pooling_params
.
task
is
not
None
,
(
"You did not set `task` in the API"
)
to_update
=
(
cast
(
VllmModelForPooling
,
self
.
model
).
pooler
.
get_pooling_updates
(
pooling_params
.
task
))
assert
to_update
is
not
None
,
(
f
"
{
pooling_params
.
task
=
}
is not supported by the model"
)
to_update
.
apply
(
pooling_params
)
seq_groups
.
append
((
seq_ids
,
pooling_params
))
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment