Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
da436f86
Unverified
Commit
da436f86
authored
Jan 04, 2026
by
Nick Hill
Committed by
GitHub
Jan 04, 2026
Browse files
[Minor] Small pooler output processing optimization (#31667)
Signed-off-by:
njhill
<
nickhill123@gmail.com
>
parent
f099cd55
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
11 deletions
+8
-11
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+8
-11
No files found.
vllm/v1/worker/gpu_model_runner.py
View file @
da436f86
...
...
@@ -263,7 +263,6 @@ class AsyncGPUPoolingModelRunnerOutput(AsyncModelRunnerOutput):
async_output_copy_stream
:
torch
.
cuda
.
Stream
,
):
self
.
_model_runner_output
=
model_runner_output
self
.
_finished_mask
=
finished_mask
# Event on the copy stream so we can synchronize the non-blocking copy.
self
.
async_copy_ready_event
=
torch
.
Event
()
...
...
@@ -276,11 +275,15 @@ class AsyncGPUPoolingModelRunnerOutput(AsyncModelRunnerOutput):
default_stream
=
torch
.
cuda
.
current_stream
()
with
torch
.
cuda
.
stream
(
async_output_copy_stream
):
async_output_copy_stream
.
wait_stream
(
default_stream
)
self
.
_
raw_pooler_output_cpu
=
json_map_leaves
(
raw_pooler_output_cpu
=
json_map_leaves
(
lambda
x
:
None
if
x
is
None
else
x
.
to
(
"cpu"
,
non_blocking
=
True
),
self
.
_raw_pooler_output
,
)
self
.
async_copy_ready_event
.
record
()
self
.
_model_runner_output
.
pooler_output
=
[
out
if
include
else
None
for
out
,
include
in
zip
(
raw_pooler_output_cpu
,
finished_mask
)
]
def
get_output
(
self
)
->
ModelRunnerOutput
:
"""Copy the device tensors to the host and return a ModelRunnerOutput.
...
...
@@ -290,11 +293,6 @@ class AsyncGPUPoolingModelRunnerOutput(AsyncModelRunnerOutput):
# Release the device tensors once the copy has completed.
del
self
.
_raw_pooler_output
self
.
_model_runner_output
.
pooler_output
=
[
out
if
include
else
None
for
out
,
include
in
zip
(
self
.
_raw_pooler_output_cpu
,
self
.
_finished_mask
)
]
return
self
.
_model_runner_output
...
...
@@ -2537,8 +2535,7 @@ class GPUModelRunner(
model
=
cast
(
VllmModelForPooling
,
self
.
model
)
raw_pooler_output
:
PoolerOutput
=
model
.
pooler
(
hidden_states
=
hidden_states
,
pooling_metadata
=
pooling_metadata
,
hidden_states
=
hidden_states
,
pooling_metadata
=
pooling_metadata
)
finished_mask
=
[
...
...
@@ -2568,12 +2565,12 @@ class GPUModelRunner(
lambda
x
:
None
if
x
is
None
else
x
.
to
(
"cpu"
,
non_blocking
=
True
),
raw_pooler_output
,
)
self
.
_sync_device
()
model_runner_output
.
pooler_output
=
[
out
if
include
else
None
for
out
,
include
in
zip
(
raw_pooler_output
,
finished_mask
)
]
self
.
_sync_device
()
return
model_runner_output
def
_pad_for_sequence_parallelism
(
self
,
num_scheduled_tokens
:
int
)
->
int
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment