Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8ed5421a
Unverified
Commit
8ed5421a
authored
Mar 07, 2025
by
Nick Hill
Committed by
GitHub
Mar 07, 2025
Browse files
[V1] Eagerly remove finished requests from the batch (#14388)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
c6359e8c
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
58 additions
and
16 deletions
+58
-16
tests/v1/engine/test_engine_core.py
tests/v1/engine/test_engine_core.py
+10
-0
tests/v1/engine/test_engine_core_client.py
tests/v1/engine/test_engine_core_client.py
+2
-2
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+10
-1
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+3
-3
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+4
-2
vllm/v1/metrics/loggers.py
vllm/v1/metrics/loggers.py
+8
-4
vllm/v1/outputs.py
vllm/v1/outputs.py
+10
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+6
-3
vllm/v1/worker/tpu_model_runner.py
vllm/v1/worker/tpu_model_runner.py
+5
-1
No files found.
tests/v1/engine/test_engine_core.py
View file @
8ed5421a
...
...
@@ -102,14 +102,24 @@ def test_engine_core(monkeypatch):
engine_core
.
add_request
(
req
)
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
1
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
assert
engine_core
.
scheduler
.
has_unfinished_requests
()
assert
not
engine_core
.
scheduler
.
has_finished_requests
()
_
=
engine_core
.
step
()
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
1
assert
engine_core
.
scheduler
.
has_unfinished_requests
()
assert
not
engine_core
.
scheduler
.
has_finished_requests
()
engine_core
.
abort_requests
([
request_id
])
assert
len
(
engine_core
.
scheduler
.
waiting
)
==
0
assert
len
(
engine_core
.
scheduler
.
running
)
==
0
assert
not
engine_core
.
scheduler
.
has_unfinished_requests
()
assert
engine_core
.
scheduler
.
has_finished_requests
()
_
=
engine_core
.
step
()
assert
not
engine_core
.
scheduler
.
has_unfinished_requests
()
assert
not
engine_core
.
scheduler
.
has_finished_requests
()
# Add, step, abort 1 of the 3.
req0
=
make_request
()
...
...
tests/v1/engine/test_engine_core_client.py
View file @
8ed5421a
...
...
@@ -50,7 +50,7 @@ def loop_until_done(client: EngineCoreClient, outputs: dict):
engine_core_outputs
=
client
.
get_output
().
outputs
if
len
(
engine_core_outputs
)
==
0
:
break
continue
all_finished
=
True
for
out
in
engine_core_outputs
:
...
...
@@ -68,7 +68,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
engine_core_outputs
=
(
await
client
.
get_output_async
()).
outputs
if
len
(
engine_core_outputs
)
==
0
:
break
continue
all_finished
=
True
for
out
in
engine_core_outputs
:
...
...
vllm/v1/core/scheduler.py
View file @
8ed5421a
...
...
@@ -682,7 +682,8 @@ class Scheduler:
assert
RequestStatus
.
is_finished
(
finished_status
)
if
isinstance
(
request_ids
,
str
):
request_ids
=
(
request_ids
,
)
request_ids
=
set
(
request_ids
)
else
:
request_ids
=
set
(
request_ids
)
for
req_id
in
request_ids
:
request
=
self
.
requests
.
get
(
req_id
)
...
...
@@ -714,6 +715,14 @@ class Scheduler:
def
has_unfinished_requests
(
self
)
->
bool
:
return
self
.
get_num_unfinished_requests
()
>
0
def
has_finished_requests
(
self
)
->
bool
:
return
len
(
self
.
finished_req_ids
)
>
0
def
has_requests
(
self
):
"""Returns True if there are unfinished requests, or finished requests
not yet returned in SchedulerOutputs."""
return
self
.
has_unfinished_requests
()
or
self
.
has_finished_requests
()
def
get_num_unscheduled_requests
(
self
)
->
int
:
"""Number of requests that are not being processed by the executor."""
return
self
.
get_num_unfinished_requests
()
-
len
(
self
.
scheduled_req_ids
)
...
...
vllm/v1/engine/async_llm.py
View file @
8ed5421a
...
...
@@ -253,13 +253,14 @@ class AsyncLLM(EngineClient):
while
True
:
# 1) Pull EngineCoreOutputs from the EngineCore.
outputs
=
await
self
.
engine_core
.
get_output_async
()
num_outputs
=
len
(
outputs
.
outputs
)
iteration_stats
=
IterationStats
()
if
self
.
log_stats
else
None
iteration_stats
=
IterationStats
()
if
(
self
.
log_stats
and
num_outputs
)
else
None
# Split outputs into chunks of at most
# VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
# event loop for too long.
num_outputs
=
len
(
outputs
.
outputs
)
if
num_outputs
<=
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE
:
slices
=
(
outputs
.
outputs
,
)
else
:
...
...
@@ -313,7 +314,6 @@ class AsyncLLM(EngineClient):
return
assert
scheduler_stats
is
not
None
assert
iteration_stats
is
not
None
for
stat_logger
in
self
.
stat_loggers
:
stat_logger
.
record
(
scheduler_stats
=
scheduler_stats
,
iteration_stats
=
iteration_stats
)
...
...
vllm/v1/engine/core.py
View file @
8ed5421a
...
...
@@ -153,7 +153,9 @@ class EngineCore:
def
step
(
self
)
->
EngineCoreOutputs
:
"""Schedule, execute, and make output."""
if
not
self
.
scheduler
.
has_unfinished_requests
():
# Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch.
if
not
self
.
scheduler
.
has_requests
():
return
EngineCoreOutputs
(
outputs
=
[],
scheduler_stats
=
self
.
scheduler
.
make_stats
(),
...
...
@@ -335,7 +337,7 @@ class EngineCoreProc(EngineCore):
# Loop until process is sent a SIGINT or SIGTERM
while
True
:
# 1) Poll the input queue until there is work to do.
while
not
self
.
scheduler
.
has_
unfinished_
requests
():
while
not
self
.
scheduler
.
has_requests
():
logger
.
debug
(
"EngineCore busy loop waiting."
)
req
=
self
.
input_queue
.
get
()
self
.
_handle_client_request
(
*
req
)
...
...
vllm/v1/metrics/loggers.py
View file @
8ed5421a
...
...
@@ -22,7 +22,7 @@ class StatLoggerBase(ABC):
@
abstractmethod
def
record
(
self
,
scheduler_stats
:
SchedulerStats
,
iteration_stats
:
IterationStats
):
iteration_stats
:
Optional
[
IterationStats
]
):
...
def
log
(
self
):
# noqa
...
...
@@ -56,10 +56,11 @@ class LoggingStatLogger(StatLoggerBase):
return
float
(
np
.
sum
(
tracked_stats
)
/
(
now
-
self
.
last_log_time
))
def
record
(
self
,
scheduler_stats
:
SchedulerStats
,
iteration_stats
:
IterationStats
):
iteration_stats
:
Optional
[
IterationStats
]
):
"""Log Stats to standard output."""
self
.
_track_iteration_stats
(
iteration_stats
)
if
iteration_stats
:
self
.
_track_iteration_stats
(
iteration_stats
)
self
.
prefix_caching_metrics
.
observe
(
scheduler_stats
.
prefix_cache_stats
)
...
...
@@ -319,7 +320,7 @@ class PrometheusStatLogger(StatLoggerBase):
info_gauge
.
set
(
1
)
def
record
(
self
,
scheduler_stats
:
SchedulerStats
,
iteration_stats
:
IterationStats
):
iteration_stats
:
Optional
[
IterationStats
]
):
"""Log to prometheus."""
self
.
gauge_scheduler_running
.
set
(
scheduler_stats
.
num_running_reqs
)
self
.
gauge_scheduler_waiting
.
set
(
scheduler_stats
.
num_waiting_reqs
)
...
...
@@ -331,6 +332,9 @@ class PrometheusStatLogger(StatLoggerBase):
self
.
counter_gpu_prefix_cache_hits
.
inc
(
scheduler_stats
.
prefix_cache_stats
.
hits
)
if
iteration_stats
is
None
:
return
self
.
counter_num_preempted_reqs
.
inc
(
iteration_stats
.
num_preempted_reqs
)
self
.
counter_prompt_tokens
.
inc
(
iteration_stats
.
num_prompt_tokens
)
self
.
counter_generation_tokens
.
inc
(
...
...
vllm/v1/outputs.py
View file @
8ed5421a
...
...
@@ -80,3 +80,13 @@ class ModelRunnerOutput:
# [prompt_len, num_prompt_logprobs]
# [prompt_len]
prompt_logprobs_dict
:
dict
[
str
,
Optional
[
LogprobsTensors
]]
EMPTY_MODEL_RUNNER_OUTPUT
=
ModelRunnerOutput
(
req_ids
=
[],
req_id_to_index
=
{},
sampled_token_ids
=
[],
spec_token_ids
=
None
,
logprobs
=
None
,
prompt_logprobs_dict
=
{},
)
vllm/v1/worker/gpu_model_runner.py
View file @
8ed5421a
...
...
@@ -32,7 +32,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from
vllm.v1.engine.mm_input_cache
import
MMInputCacheClient
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
from
vllm.v1.outputs
import
LogprobsTensors
,
ModelRunnerOutput
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
)
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.rejection_sampler
import
INVALID_TOKEN_ID
,
RejectionSampler
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
...
...
@@ -919,6 +920,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
ModelRunnerOutput
,
torch
.
Tensor
]:
self
.
_update_states
(
scheduler_output
)
if
not
scheduler_output
.
total_num_scheduled_tokens
:
# Return empty ModelRunnerOuptut if there's no work to do.
return
EMPTY_MODEL_RUNNER_OUTPUT
if
self
.
is_multimodal_model
:
# Run the multimodal encoder if any.
...
...
@@ -1069,7 +1073,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids
=
self
.
generate_draft_token_ids
(
valid_sampled_token_ids
)
model_runner_output
=
ModelRunnerOutput
(
return
ModelRunnerOutput
(
req_ids
=
self
.
input_batch
.
req_ids
,
req_id_to_index
=
self
.
input_batch
.
req_id_to_index
,
sampled_token_ids
=
valid_sampled_token_ids
,
...
...
@@ -1077,7 +1081,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs
=
logprobs_lists
,
prompt_logprobs_dict
=
prompt_logprobs_dict
,
)
return
model_runner_output
def
generate_draft_token_ids
(
self
,
...
...
vllm/v1/worker/tpu_model_runner.py
View file @
8ed5421a
...
...
@@ -29,7 +29,8 @@ from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
)
from
vllm.v1.outputs
import
LogprobsTensors
,
ModelRunnerOutput
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
)
from
vllm.v1.utils
import
bind_kv_cache
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
...
...
@@ -546,6 +547,9 @@ class TPUModelRunner:
)
->
ModelRunnerOutput
:
# Update cached state
self
.
_update_states
(
scheduler_output
)
if
not
scheduler_output
.
total_num_scheduled_tokens
:
# Return empty ModelRunnerOuptut if there's no work to do.
return
EMPTY_MODEL_RUNNER_OUTPUT
if
self
.
is_multimodal_model
:
# Run the multimodal encoder if any.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment