Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
86ac7bcf
Unverified
Commit
86ac7bcf
authored
Feb 27, 2026
by
Woosuk Kwon
Committed by
GitHub
Feb 27, 2026
Browse files
[Model Runner V2] Support pooling models (#35120)
Signed-off-by:
Woosuk Kwon
<
woosuk@inferact.ai
>
parent
405f28d3
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
209 additions
and
14 deletions
+209
-14
vllm/v1/worker/gpu/async_utils.py
vllm/v1/worker/gpu/async_utils.py
+36
-0
vllm/v1/worker/gpu/input_batch.py
vllm/v1/worker/gpu/input_batch.py
+32
-0
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+90
-14
vllm/v1/worker/gpu/pool/__init__.py
vllm/v1/worker/gpu/pool/__init__.py
+0
-0
vllm/v1/worker/gpu/pool/pooling_runner.py
vllm/v1/worker/gpu/pool/pooling_runner.py
+45
-0
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+6
-0
No files found.
vllm/v1/worker/gpu/async_utils.py
View file @
86ac7bcf
...
...
@@ -70,6 +70,42 @@ class AsyncOutput(AsyncModelRunnerOutput):
return
self
.
model_runner_output
class AsyncPoolingOutput(AsyncModelRunnerOutput):
    """Pooling output whose device-to-host copy is enqueued on a side stream.

    The constructor issues non-blocking copies of the pooled tensors to the
    CPU and records ``copy_event`` on ``copy_stream``; ``get_output`` waits on
    that event before reading the host copies.
    """

    def __init__(
        self,
        model_runner_output: ModelRunnerOutput,
        pooler_output: torch.Tensor,
        is_valid: torch.Tensor | None,
        main_stream: torch.cuda.Stream,
        copy_stream: torch.cuda.Stream,
        copy_event: torch.cuda.Event,
    ):
        """Start the asynchronous copy of the pooling results to the host.

        Args:
            model_runner_output: Output object that ``get_output`` fills in
                and returns.
            pooler_output: Pooled tensor; it is split along dim 0 in
                ``get_output`` (presumably one row per request -- confirm).
            is_valid: Optional per-row mask; rows whose entry is falsy are
                reported as ``None``. ``None`` keeps every row.
            main_stream: Stream the copy stream waits on before copying.
            copy_stream: Stream on which ``copy_event`` is recorded.
            copy_event: Event recorded after the copies are enqueued.
        """
        self.model_runner_output = model_runner_output
        # Keep references to the source tensors on the instance.
        self.pooler_output = pooler_output
        self.is_valid = is_valid
        self.copy_event = copy_event

        # NOTE(review): `stream` is a helper defined elsewhere in this file;
        # presumably it makes `copy_stream` current inside the block -- confirm.
        with stream(copy_stream, main_stream):
            # Do not start copying before the work queued on the main stream
            # has finished.
            copy_stream.wait_stream(main_stream)
            self.pooler_output_cpu = self.pooler_output.to("cpu", non_blocking=True)
            if self.is_valid is not None:
                self.is_valid_cpu = self.is_valid.to("cpu", non_blocking=True)
            else:
                self.is_valid_cpu = None
            self.copy_event.record(copy_stream)

    def get_output(self) -> ModelRunnerOutput:
        """Wait for the host copy and return the completed runner output.

        Returns:
            ``model_runner_output`` with ``pooler_output`` set to a list that
            holds one CPU tensor per row, or ``None`` for rows flagged invalid.
        """
        self.copy_event.synchronize()

        # Tensor.unbind returns an immutable tuple; build a list so that the
        # invalid rows can be replaced with None below.
        pooler_output: list[torch.Tensor | None] = list(
            self.pooler_output_cpu.unbind(dim=0)
        )
        if self.is_valid_cpu is not None:
            for i, is_valid in enumerate(self.is_valid_cpu.tolist()):
                if not is_valid:
                    pooler_output[i] = None

        self.model_runner_output.pooler_output = pooler_output
        return self.model_runner_output
def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
    """Issue a non-blocking copy of ``x`` to the CPU and return it as NumPy.

    The copy is requested with ``non_blocking=True``; this function does not
    synchronize before returning.
    """
    host_tensor = x.to("cpu", non_blocking=True)
    return host_tensor.numpy()
...
...
vllm/v1/worker/gpu/input_batch.py
View file @
86ac7bcf
...
...
@@ -499,6 +499,38 @@ def post_update(
)
@triton.jit
def _post_update_pool_kernel(
    idx_mapping_ptr,
    num_computed_tokens_ptr,
    query_start_loc_ptr,
):
    # Each program handles the batch entry given by its program id on axis 0
    # and advances that entry's computed-token counter by its query length.
    pid = tl.program_id(0)

    # Query length is the gap between consecutive query_start_loc entries.
    q_begin = tl.load(query_start_loc_ptr + pid)
    q_end = tl.load(query_start_loc_ptr + pid + 1)

    # idx_mapping translates the batch position into the counter slot.
    slot_ptr = num_computed_tokens_ptr + tl.load(idx_mapping_ptr + pid)
    tl.store(slot_ptr, tl.load(slot_ptr) + (q_end - q_begin))
def post_update_pool(
    # [num_reqs]
    idx_mapping: torch.Tensor,
    # [max_num_reqs]
    num_computed_tokens: torch.Tensor,
    # [num_reqs + 1]
    query_start_loc: torch.Tensor,
) -> None:
    """Launch ``_post_update_pool_kernel`` with one program per request.

    The grid size is taken from the first dimension of ``idx_mapping``.
    """
    grid = (idx_mapping.shape[0],)
    _post_update_pool_kernel[grid](idx_mapping, num_computed_tokens, query_start_loc)
@
triton
.
jit
def
_expand_idx_mapping_kernel
(
idx_mapping_ptr
,
...
...
vllm/v1/worker/gpu/model_runner.py
View file @
86ac7bcf
...
...
@@ -38,13 +38,14 @@ from vllm.logger import init_logger
from
vllm.model_executor.model_loader
import
get_model_loader
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
from
vllm.tasks
import
SupportedTask
from
vllm.utils.mem_utils
import
DeviceMemoryProfiler
,
format_gib
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.outputs
import
DraftTokenIds
,
ModelRunnerOutput
from
vllm.v1.worker.cp_utils
import
check_attention_cp_compatibility
from
vllm.v1.worker.gpu.async_utils
import
AsyncOutput
from
vllm.v1.worker.gpu.async_utils
import
AsyncOutput
,
AsyncPoolingOutput
from
vllm.v1.worker.gpu.attn_utils
import
(
build_slot_mappings_by_layer
,
get_kv_cache_spec
,
...
...
@@ -66,6 +67,7 @@ from vllm.v1.worker.gpu.input_batch import (
expand_idx_mapping
,
get_num_sampled_and_rejected
,
post_update
,
post_update_pool
,
prepare_pos_seq_lens
,
prepare_prefill_inputs
,
)
...
...
@@ -77,6 +79,7 @@ from vllm.v1.worker.gpu.kv_connector import (
from
vllm.v1.worker.gpu.lora_utils
import
LoraState
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.model_states
import
ModelState
from
vllm.v1.worker.gpu.pool.pooling_runner
import
PoolingRunner
from
vllm.v1.worker.gpu.pp_utils
import
pp_broadcast
,
pp_receive
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
from
vllm.v1.worker.gpu.sample.prompt_logprob
import
PromptLogprobsWorker
...
...
@@ -119,7 +122,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
kv_cache_dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
self
.
cache_config
.
cache_dtype
]
self
.
is_pooling_model
=
False
self
.
vocab_size
=
self
.
model_config
.
get_vocab_size
()
self
.
max_model_len
=
self
.
model_config
.
max_model_len
...
...
@@ -217,6 +219,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# KV Connector if configured.
self
.
kv_connector
:
KVConnector
=
NO_OP_KV_CONNECTOR
# Pooling models.
self
.
is_pooling_model
=
self
.
model_config
.
runner_type
==
"pooling"
self
.
pooling_runner
:
PoolingRunner
|
None
=
None
# For transferring state from execute_model to subsequent sample_tokens call.
self
.
execute_model_state
:
tuple
|
None
=
None
...
...
@@ -224,9 +230,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
max_model_len
=
max_model_len
self
.
req_states
.
max_model_len
=
max_model_len
@
staticmethod
def
get_supported_tasks
()
->
tuple
[
str
]:
return
(
"generate"
,)
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
    """Return the tasks this runner can serve.

    ``"generate"`` is included when the model config's runner type is
    ``"generate"``; the pooling runner's tasks follow when one is set.
    """
    generation: list[SupportedTask] = (
        ["generate"] if self.model_config.runner_type == "generate" else []
    )
    pooling_runner = self.pooling_runner
    if pooling_runner is None:
        return tuple(generation)
    return (*generation, *pooling_runner.get_supported_pooling_tasks())
def
load_model
(
self
,
*
args
,
**
kwargs
)
->
None
:
time_before_load
=
time
.
perf_counter
()
...
...
@@ -263,6 +273,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Initialize the components that require the model.
self
.
model_state
=
ModelState
(
self
.
vllm_config
,
self
.
model
,
self
.
device
)
if
self
.
is_pooling_model
:
self
.
pooling_runner
=
PoolingRunner
(
self
.
model
)
def get_model(self) -> nn.Module:
    """Return the module stored in ``self.model``."""
    model = self.model
    return model
...
...
@@ -388,16 +400,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
expanded_local_pos
,
)
@
torch
.
inference_mode
()
def
_dummy_pooler_run
(
self
,
hidden_states
:
torch
.
Tensor
)
->
None
:
assert
self
.
pooling_runner
is
not
None
self
.
pooling_runner
.
dummy_pooler_run
(
hidden_states
)
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
hidden_states
,
sample_hidden_states
=
self
.
_dummy_run
(
self
.
max_num_tokens
,
skip_attn
=
True
)
# Only run sampler on last PP rank (non-last ranks return None).
# Only run sampler
/pooler
on last PP rank (non-last ranks return None).
if
self
.
is_last_pp_rank
:
assert
sample_hidden_states
is
not
None
self
.
_dummy_sampler_run
(
sample_hidden_states
)
if
self
.
pooling_runner
is
None
:
self
.
_dummy_sampler_run
(
sample_hidden_states
)
else
:
self
.
_dummy_pooler_run
(
hidden_states
)
if
self
.
speculator
is
not
None
:
num_tokens_across_dp
=
make_num_tokens_across_dp
(
...
...
@@ -505,7 +525,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
new_req_data
in
scheduler_output
.
scheduled_new_reqs
:
assert
new_req_data
.
prompt_token_ids
is
not
None
assert
new_req_data
.
prefill_token_ids
is
not
None
assert
new_req_data
.
sampling_params
is
not
None
req_id
=
new_req_data
.
req_id
prompt_len
=
len
(
new_req_data
.
prompt_token_ids
)
self
.
req_states
.
add_request
(
...
...
@@ -523,14 +542,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self
.
block_tables
.
append_block_ids
(
req_index
,
new_req_data
.
block_ids
,
overwrite
=
True
)
self
.
sampler
.
add_request
(
req_index
,
prompt_len
,
new_req_data
.
sampling_params
)
self
.
prompt_logprobs_worker
.
add_request
(
req_id
,
req_index
,
new_req_data
.
sampling_params
)
self
.
lora_state
.
add_request
(
req_id
,
req_index
,
new_req_data
.
lora_request
)
if
new_req_data
.
sampling_params
is
not
None
:
self
.
sampler
.
add_request
(
req_index
,
prompt_len
,
new_req_data
.
sampling_params
)
self
.
prompt_logprobs_worker
.
add_request
(
req_id
,
req_index
,
new_req_data
.
sampling_params
)
if
scheduler_output
.
scheduled_new_reqs
:
self
.
req_states
.
apply_staged_writes
()
self
.
sampler
.
apply_staged_writes
()
...
...
@@ -1083,3 +1104,58 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def take_draft_token_ids(self) -> DraftTokenIds | None:
    """Return whatever the draft-tokens handler's ``get_draft_tokens`` yields."""
    handler = self.draft_tokens_handler
    return handler.get_draft_tokens()
@torch.inference_mode()
def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None:
    """Pool the hidden states saved in ``self.execute_model_state``.

    Returns:
        ``None`` when no state was saved or when this is not the last PP rank;
        otherwise an ``AsyncPoolingOutput`` when async scheduling is enabled,
        or the ``ModelRunnerOutput`` obtained from it when it is not.
    """
    saved_state = self.execute_model_state
    if saved_state is None:
        # The prior execute_model call must have failed.
        return None

    input_batch, _, _, _, hidden_states, _, kv_connector_output = saved_state
    # Clear the saved state once it has been unpacked.
    self.execute_model_state = None

    if not self.is_last_pp_rank:
        # Non-last ranks only update the request bookkeeping.
        self.postprocess_pool(input_batch)
        return None

    assert self.pooling_runner is not None
    pooler_output, is_valid = self.pooling_runner.pool(
        hidden_states, input_batch, self.req_states
    )
    self.postprocess_pool(input_batch)

    # Build the model runner output.
    req_ids = input_batch.req_ids
    model_runner_output = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index={rid: pos for pos, rid in enumerate(req_ids)},
        kv_connector_output=kv_connector_output,
    )
    async_output = AsyncPoolingOutput(
        model_runner_output=model_runner_output,
        pooler_output=pooler_output,
        is_valid=is_valid,
        main_stream=self.main_stream,
        copy_stream=self.output_copy_stream,
        copy_event=self.output_copy_event,
    )
    # Without async scheduling, resolve the output right away.
    return async_output if self.use_async_scheduling else async_output.get_output()
def postprocess_pool(self, input_batch: InputBatch) -> None:
    """Update the per-request token counters after a pooling step."""
    req_states = self.req_states

    # Update the number of computed tokens.
    post_update_pool(
        input_batch.idx_mapping,
        req_states.num_computed_tokens.gpu,
        input_batch.query_start_loc,
    )

    # Update the number of computed prefill tokens: add this step's scheduled
    # tokens, then cap each entry at the request's prefill length (in place).
    prefill_done = req_states.num_computed_prefill_tokens
    rows = input_batch.idx_mapping_np
    prefill_done[rows] += input_batch.num_scheduled_tokens
    np.minimum(prefill_done, req_states.prefill_len.np, out=prefill_done)
vllm/v1/worker/gpu/pool/__init__.py
0 → 100644
View file @
86ac7bcf
vllm/v1/worker/gpu/pool/pooling_runner.py
0 → 100644
View file @
86ac7bcf
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
cast
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
vllm.model_executor.models
import
VllmModelForPooling
,
is_pooling_model
from
vllm.tasks
import
PoolingTask
from
vllm.v1.worker.gpu.input_batch
import
InputBatch
from
vllm.v1.worker.gpu.states
import
RequestState
# NOTE(woosuk): Currently, this class only supports the "LAST" pooling task
# on decoder-only models. How to support other pooling tasks and models
# is to be determined.
class PoolingRunner:
    """Runs the pooling step on top of a model's hidden states."""

    def __init__(self, model: nn.Module):
        # cast() only informs the type checker; nothing is verified at runtime.
        self.model = cast(VllmModelForPooling, model)

    def get_supported_pooling_tasks(self) -> list[PoolingTask]:
        """Return ``["embed"]`` for a pooling model, otherwise an empty list."""
        if is_pooling_model(self.model):
            assert "embed" in self.model.pooler.get_supported_tasks()
            return ["embed"]
        return []

    def pool(
        self,
        hidden_states: torch.Tensor,
        input_batch: InputBatch,
        req_states: RequestState,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Select, L2-normalize, and validity-flag the pooled hidden states.

        Returns:
            A pair ``(embeddings, is_valid)``. ``embeddings`` are the rows of
            ``hidden_states`` at ``input_batch.logits_indices``, normalized
            along the last dim. ``is_valid`` is true where the batch's
            sequence length equals the request's prompt length (presumably
            meaning the whole prompt has been processed -- confirm).
        """
        # TODO(woosuk): Support different types of pooling tasks.
        selected = hidden_states[input_batch.logits_indices]
        # TODO(woosuk): Make normalization optional.
        embeddings = F.normalize(selected, p=2, dim=-1)

        expected_len = req_states.prompt_len.gpu[input_batch.idx_mapping]
        return embeddings, input_batch.seq_lens == expected_len

    def dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
        """Run the normalization once on ``hidden_states`` and discard it."""
        F.normalize(hidden_states, p=2, dim=-1)
vllm/v1/worker/gpu_worker.py
View file @
86ac7bcf
...
...
@@ -700,6 +700,12 @@ class Worker(WorkerBase):
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
)
if
(
self
.
use_v2_model_runner
and
self
.
model_runner
.
is_pooling_model
and
output
is
None
):
output
=
self
.
model_runner
.
pool
()
# type: ignore
if
isinstance
(
output
,
ModelRunnerOutput
|
AsyncModelRunnerOutput
|
NoneType
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment