Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4fe58953
Unverified
Commit
4fe58953
authored
Oct 28, 2025
by
Nick Hill
Committed by
GitHub
Oct 28, 2025
Browse files
[AsyncScheduling] Make async overlap work with logprobs (#27615)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
111faf11
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
65 additions
and
10 deletions
+65
-10
tests/conftest.py
tests/conftest.py
+8
-2
tests/v1/e2e/test_async_sched_and_preempt.py
tests/v1/e2e/test_async_sched_and_preempt.py
+33
-4
vllm/v1/outputs.py
vllm/v1/outputs.py
+9
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+15
-4
No files found.
tests/conftest.py
View file @
4fe58953
...
@@ -831,8 +831,9 @@ class VllmRunner:
...
@@ -831,8 +831,9 @@ class VllmRunner:
images
:
PromptImageInput
|
None
=
None
,
images
:
PromptImageInput
|
None
=
None
,
videos
:
PromptVideoInput
|
None
=
None
,
videos
:
PromptVideoInput
|
None
=
None
,
audios
:
PromptAudioInput
|
None
=
None
,
audios
:
PromptAudioInput
|
None
=
None
,
return_logprobs
:
bool
=
False
,
**
kwargs
:
Any
,
**
kwargs
:
Any
,
)
->
list
[
tuple
[
list
[
list
[
int
]],
list
[
str
]]]:
)
->
list
[
tuple
[
list
[
list
[
int
]],
list
[
str
]]]
|
tuple
[
list
,
list
]
:
inputs
=
self
.
get_inputs
(
prompts
,
images
=
images
,
videos
=
videos
,
audios
=
audios
)
inputs
=
self
.
get_inputs
(
prompts
,
images
=
images
,
videos
=
videos
,
audios
=
audios
)
req_outputs
=
self
.
llm
.
generate
(
req_outputs
=
self
.
llm
.
generate
(
...
@@ -840,18 +841,23 @@ class VllmRunner:
...
@@ -840,18 +841,23 @@ class VllmRunner:
)
)
outputs
:
list
[
tuple
[
list
[
list
[
int
]],
list
[
str
]]]
=
[]
outputs
:
list
[
tuple
[
list
[
list
[
int
]],
list
[
str
]]]
=
[]
logprobs
=
[]
for
req_output
in
req_outputs
:
for
req_output
in
req_outputs
:
prompt_str
=
req_output
.
prompt
prompt_str
=
req_output
.
prompt
prompt_ids
=
req_output
.
prompt_token_ids
prompt_ids
=
req_output
.
prompt_token_ids
req_sample_output_ids
:
list
[
list
[
int
]]
=
[]
req_sample_output_ids
:
list
[
list
[
int
]]
=
[]
req_sample_output_strs
:
list
[
str
]
=
[]
req_sample_output_strs
:
list
[
str
]
=
[]
req_logprobs
=
[]
for
sample
in
req_output
.
outputs
:
for
sample
in
req_output
.
outputs
:
output_str
=
sample
.
text
output_str
=
sample
.
text
output_ids
=
list
(
sample
.
token_ids
)
output_ids
=
list
(
sample
.
token_ids
)
req_sample_output_ids
.
append
(
prompt_ids
+
output_ids
)
req_sample_output_ids
.
append
(
prompt_ids
+
output_ids
)
req_sample_output_strs
.
append
((
prompt_str
or
""
)
+
output_str
)
req_sample_output_strs
.
append
((
prompt_str
or
""
)
+
output_str
)
if
sample
.
logprobs
:
req_logprobs
.
extend
(
sample
.
logprobs
)
outputs
.
append
((
req_sample_output_ids
,
req_sample_output_strs
))
outputs
.
append
((
req_sample_output_ids
,
req_sample_output_strs
))
return
outputs
logprobs
.
append
(
req_logprobs
)
return
outputs
if
not
return_logprobs
else
(
outputs
,
logprobs
)
@
staticmethod
@
staticmethod
def
_final_steps_generate_w_logprobs
(
def
_final_steps_generate_w_logprobs
(
...
...
tests/v1/e2e/test_async_sched_and_preempt.py
View file @
4fe58953
...
@@ -6,6 +6,7 @@ import pytest
...
@@ -6,6 +6,7 @@ import pytest
import
torch._dynamo.config
as
dynamo_config
import
torch._dynamo.config
as
dynamo_config
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.logprobs
import
Logprob
from
...conftest
import
VllmRunner
from
...conftest
import
VllmRunner
from
...models.utils
import
check_outputs_equal
from
...models.utils
import
check_outputs_equal
...
@@ -32,6 +33,8 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
...
@@ -32,6 +33,8 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
# dict(min_tokens=20),
# dict(min_tokens=20),
dict
(
presence_penalty
=-
1.0
),
dict
(
presence_penalty
=-
1.0
),
dict
(
bad_words
=
[
"the"
,
" the"
]),
dict
(
bad_words
=
[
"the"
,
" the"
]),
dict
(
logprobs
=
2
),
dict
(
logprobs
=
2
,
presence_penalty
=-
1.0
),
]
]
default_params
=
dict
(
default_params
=
dict
(
...
@@ -77,29 +80,33 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
...
@@ -77,29 +80,33 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
**
default_params
,
**
override_params
**
default_params
,
**
override_params
),
),
return_logprobs
=
True
,
)
)
)
)
if
not
outputs
:
if
not
outputs
:
# First check that the different parameter configs
# First check that the different parameter configs
# actually result in different output.
# actually result in different output.
for
other_test
,
params
in
zip
(
for
(
other_test
_outs
,
other_test_logprobs
)
,
params
in
zip
(
results
[
1
:],
sampling_param_tests
[
1
:]
results
[
1
:],
sampling_param_tests
[
1
:]
):
):
with
pytest
.
raises
(
AssertionError
):
with
pytest
.
raises
(
AssertionError
):
check_outputs_equal
(
check_outputs_equal
(
outputs_0_lst
=
results
[
0
],
outputs_0_lst
=
results
[
0
]
[
0
]
,
outputs_1_lst
=
other_test
,
outputs_1_lst
=
other_test
_outs
,
name_0
=
f
"baseline params=
{
params
}
"
,
name_0
=
f
"baseline params=
{
params
}
"
,
name_1
=
f
"other params=
{
params
}
"
,
name_1
=
f
"other params=
{
params
}
"
,
)
)
assert
_all_logprobs_match
(
results
[
0
][
1
],
other_test_logprobs
)
outputs
.
append
((
test_config
,
results
))
outputs
.
append
((
test_config
,
results
))
baseline_config
,
baseline_tests
=
outputs
[
0
]
baseline_config
,
baseline_tests
=
outputs
[
0
]
for
test_config
,
test_outputs
in
outputs
[
1
:]:
for
test_config
,
test_outputs
in
outputs
[
1
:]:
for
base_outs
,
test_outs
,
params
in
zip
(
for
(
base_outs
,
base_logprobs
),
(
test_outs
,
test_logprobs
)
,
params
in
zip
(
baseline_tests
,
test_outputs
,
sampling_param_tests
baseline_tests
,
test_outputs
,
sampling_param_tests
):
):
check_outputs_equal
(
check_outputs_equal
(
...
@@ -108,5 +115,27 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
...
@@ -108,5 +115,27 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
name_0
=
f
"baseline=[
{
baseline_config
}
], params=
{
params
}
"
,
name_0
=
f
"baseline=[
{
baseline_config
}
], params=
{
params
}
"
,
name_1
=
f
"config=[
{
test_config
}
], params=
{
params
}
"
,
name_1
=
f
"config=[
{
test_config
}
], params=
{
params
}
"
,
)
)
assert
_all_logprobs_match
(
base_logprobs
,
test_logprobs
)
print
(
f
"PASSED: config=[
{
test_config
}
], params=
{
params
}
"
)
print
(
f
"PASSED: config=[
{
test_config
}
], params=
{
params
}
"
)
def
_all_logprobs_match
(
req_a
,
req_b
)
->
bool
:
return
(
req_a
==
req_b
or
len
(
req_a
)
==
len
(
req_b
)
and
all
(
len
(
seq_a
)
==
len
(
seq_b
)
and
all
(
_logprobs_match
(
a
,
b
)
for
a
,
b
in
zip
(
seq_a
,
seq_b
))
for
seq_a
,
seq_b
in
zip
(
req_a
,
req_b
)
)
)
def
_logprobs_match
(
lps_a
:
dict
[
int
,
Logprob
],
lps_b
:
dict
[
int
,
Logprob
])
->
bool
:
return
len
(
lps_a
)
==
len
(
lps_b
)
and
all
(
a
.
decoded_token
==
b
.
decoded_token
and
a
.
rank
==
b
.
rank
and
a
.
logprob
==
pytest
.
approx
(
b
.
logprob
,
rel
=
1e-3
,
abs
=
1e-6
)
for
a
,
b
in
((
lps_a
[
x
],
lps_b
[
x
])
for
x
in
lps_a
)
)
vllm/v1/outputs.py
View file @
4fe58953
...
@@ -59,6 +59,15 @@ class LogprobsTensors(NamedTuple):
...
@@ -59,6 +59,15 @@ class LogprobsTensors(NamedTuple):
cu_num_generated_tokens
,
cu_num_generated_tokens
,
)
)
def
to_cpu_nonblocking
(
self
)
->
"LogprobsTensors"
:
if
self
.
logprob_token_ids
.
device
.
type
==
"cpu"
:
return
self
return
LogprobsTensors
(
self
.
logprob_token_ids
.
to
(
"cpu"
,
non_blocking
=
True
),
self
.
logprobs
.
to
(
"cpu"
,
non_blocking
=
True
),
self
.
selected_token_ranks
.
to
(
"cpu"
,
non_blocking
=
True
),
)
@
staticmethod
@
staticmethod
def
empty_cpu
(
def
empty_cpu
(
num_positions
:
int
,
num_tokens_per_position
:
int
num_positions
:
int
,
num_tokens_per_position
:
int
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
4fe58953
...
@@ -164,6 +164,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
...
@@ -164,6 +164,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
self
,
self
,
model_runner_output
:
ModelRunnerOutput
,
model_runner_output
:
ModelRunnerOutput
,
sampled_token_ids
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
logprobs_tensors
:
torch
.
Tensor
|
None
,
invalid_req_indices
:
list
[
int
],
invalid_req_indices
:
list
[
int
],
async_output_copy_stream
:
torch
.
cuda
.
Stream
,
async_output_copy_stream
:
torch
.
cuda
.
Stream
,
):
):
...
@@ -176,6 +177,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
...
@@ -176,6 +177,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
# Keep a reference to the device tensor to avoid it being
# Keep a reference to the device tensor to avoid it being
# deallocated until we finish copying it to the host.
# deallocated until we finish copying it to the host.
self
.
_sampled_token_ids
=
sampled_token_ids
self
.
_sampled_token_ids
=
sampled_token_ids
self
.
_logprobs_tensors
=
logprobs_tensors
# Initiate the copy on a separate stream, but do not synchronize it.
# Initiate the copy on a separate stream, but do not synchronize it.
default_stream
=
torch
.
cuda
.
current_stream
()
default_stream
=
torch
.
cuda
.
current_stream
()
...
@@ -184,6 +186,11 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
...
@@ -184,6 +186,11 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
self
.
sampled_token_ids_cpu
=
self
.
_sampled_token_ids
.
to
(
self
.
sampled_token_ids_cpu
=
self
.
_sampled_token_ids
.
to
(
"cpu"
,
non_blocking
=
True
"cpu"
,
non_blocking
=
True
)
)
self
.
_logprobs_tensors_cpu
=
(
self
.
_logprobs_tensors
.
to_cpu_nonblocking
()
if
self
.
_logprobs_tensors
else
None
)
self
.
async_copy_ready_event
.
record
()
self
.
async_copy_ready_event
.
record
()
def
get_output
(
self
)
->
ModelRunnerOutput
:
def
get_output
(
self
)
->
ModelRunnerOutput
:
...
@@ -193,7 +200,8 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
...
@@ -193,7 +200,8 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
"""
"""
self
.
async_copy_ready_event
.
synchronize
()
self
.
async_copy_ready_event
.
synchronize
()
# Release the device tensor once the copy has completed
# Release the device tensors once the copy has completed.
del
self
.
_logprobs_tensors
del
self
.
_sampled_token_ids
del
self
.
_sampled_token_ids
valid_sampled_token_ids
=
self
.
sampled_token_ids_cpu
.
tolist
()
valid_sampled_token_ids
=
self
.
sampled_token_ids_cpu
.
tolist
()
...
@@ -202,6 +210,10 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
...
@@ -202,6 +210,10 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
output
=
self
.
_model_runner_output
output
=
self
.
_model_runner_output
output
.
sampled_token_ids
=
valid_sampled_token_ids
output
.
sampled_token_ids
=
valid_sampled_token_ids
if
self
.
_logprobs_tensors_cpu
:
# NOTE(nick): this will need to be updated to use cu_num_accepted_tokens
# for async sched + spec decode + logprobs compatibility.
output
.
logprobs
=
self
.
_logprobs_tensors_cpu
.
tolists
()
return
output
return
output
...
@@ -2334,11 +2346,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2334,11 +2346,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cu_num_accepted_tokens
[
-
1
]
+
len
(
sampled_ids
)
cu_num_accepted_tokens
[
-
1
]
+
len
(
sampled_ids
)
)
)
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
logprobs_lists
=
(
logprobs_lists
=
(
logprobs_tensors
.
tolists
(
cu_num_accepted_tokens
)
logprobs_tensors
.
tolists
(
cu_num_accepted_tokens
)
if
logprobs_tensors
is
not
None
if
not
self
.
use_async_scheduling
and
logprobs_tensors
is
not
None
else
None
else
None
)
)
...
@@ -2664,6 +2674,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2664,6 +2674,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
async_output
=
AsyncGPUModelRunnerOutput
(
async_output
=
AsyncGPUModelRunnerOutput
(
model_runner_output
=
output
,
model_runner_output
=
output
,
sampled_token_ids
=
sampler_output
.
sampled_token_ids
,
sampled_token_ids
=
sampler_output
.
sampled_token_ids
,
logprobs_tensors
=
sampler_output
.
logprobs_tensors
,
invalid_req_indices
=
invalid_req_indices
,
invalid_req_indices
=
invalid_req_indices
,
async_output_copy_stream
=
self
.
async_output_copy_stream
,
async_output_copy_stream
=
self
.
async_output_copy_stream
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment