Unverified commit 8ed5421a, authored by Nick Hill, committed by GitHub
Browse files

[V1] Eagerly remove finished requests from the batch (#14388)


Signed-off-by: Nick Hill <nhill@redhat.com>
parent c6359e8c
...@@ -102,14 +102,24 @@ def test_engine_core(monkeypatch): ...@@ -102,14 +102,24 @@ def test_engine_core(monkeypatch):
engine_core.add_request(req) engine_core.add_request(req)
assert len(engine_core.scheduler.waiting) == 1 assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0 assert len(engine_core.scheduler.running) == 0
assert engine_core.scheduler.has_unfinished_requests()
assert not engine_core.scheduler.has_finished_requests()
_ = engine_core.step() _ = engine_core.step()
assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 1 assert len(engine_core.scheduler.running) == 1
assert engine_core.scheduler.has_unfinished_requests()
assert not engine_core.scheduler.has_finished_requests()
engine_core.abort_requests([request_id]) engine_core.abort_requests([request_id])
assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0 assert len(engine_core.scheduler.running) == 0
assert not engine_core.scheduler.has_unfinished_requests()
assert engine_core.scheduler.has_finished_requests()
_ = engine_core.step()
assert not engine_core.scheduler.has_unfinished_requests()
assert not engine_core.scheduler.has_finished_requests()
# Add, step, abort 1 of the 3. # Add, step, abort 1 of the 3.
req0 = make_request() req0 = make_request()
......
...@@ -50,7 +50,7 @@ def loop_until_done(client: EngineCoreClient, outputs: dict): ...@@ -50,7 +50,7 @@ def loop_until_done(client: EngineCoreClient, outputs: dict):
engine_core_outputs = client.get_output().outputs engine_core_outputs = client.get_output().outputs
if len(engine_core_outputs) == 0: if len(engine_core_outputs) == 0:
break continue
all_finished = True all_finished = True
for out in engine_core_outputs: for out in engine_core_outputs:
...@@ -68,7 +68,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict): ...@@ -68,7 +68,7 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
engine_core_outputs = (await client.get_output_async()).outputs engine_core_outputs = (await client.get_output_async()).outputs
if len(engine_core_outputs) == 0: if len(engine_core_outputs) == 0:
break continue
all_finished = True all_finished = True
for out in engine_core_outputs: for out in engine_core_outputs:
......
...@@ -682,6 +682,7 @@ class Scheduler: ...@@ -682,6 +682,7 @@ class Scheduler:
assert RequestStatus.is_finished(finished_status) assert RequestStatus.is_finished(finished_status)
if isinstance(request_ids, str): if isinstance(request_ids, str):
request_ids = (request_ids, ) request_ids = (request_ids, )
else:
request_ids = set(request_ids) request_ids = set(request_ids)
for req_id in request_ids: for req_id in request_ids:
...@@ -714,6 +715,14 @@ class Scheduler: ...@@ -714,6 +715,14 @@ class Scheduler:
def has_unfinished_requests(self) -> bool: def has_unfinished_requests(self) -> bool:
return self.get_num_unfinished_requests() > 0 return self.get_num_unfinished_requests() > 0
# True while finished request ids are still held in self.finished_req_ids.
def has_finished_requests(self) -> bool:
return len(self.finished_req_ids) > 0
def has_requests(self) -> bool:
"""Returns True if there are unfinished requests, or finished requests
not yet returned in SchedulerOutputs."""
return self.has_unfinished_requests() or self.has_finished_requests()
def get_num_unscheduled_requests(self) -> int: def get_num_unscheduled_requests(self) -> int:
"""Number of requests that are not being processed by the executor.""" """Number of requests that are not being processed by the executor."""
return self.get_num_unfinished_requests() - len(self.scheduled_req_ids) return self.get_num_unfinished_requests() - len(self.scheduled_req_ids)
......
...@@ -253,13 +253,14 @@ class AsyncLLM(EngineClient): ...@@ -253,13 +253,14 @@ class AsyncLLM(EngineClient):
while True: while True:
# 1) Pull EngineCoreOutputs from the EngineCore. # 1) Pull EngineCoreOutputs from the EngineCore.
outputs = await self.engine_core.get_output_async() outputs = await self.engine_core.get_output_async()
num_outputs = len(outputs.outputs)
iteration_stats = IterationStats() if self.log_stats else None iteration_stats = IterationStats() if (
self.log_stats and num_outputs) else None
# Split outputs into chunks of at most # Split outputs into chunks of at most
# VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
# event loop for too long. # event loop for too long.
num_outputs = len(outputs.outputs)
if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: if num_outputs <= VLLM_V1_OUTPUT_PROC_CHUNK_SIZE:
slices = (outputs.outputs, ) slices = (outputs.outputs, )
else: else:
...@@ -313,7 +314,6 @@ class AsyncLLM(EngineClient): ...@@ -313,7 +314,6 @@ class AsyncLLM(EngineClient):
return return
assert scheduler_stats is not None assert scheduler_stats is not None
assert iteration_stats is not None
for stat_logger in self.stat_loggers: for stat_logger in self.stat_loggers:
stat_logger.record(scheduler_stats=scheduler_stats, stat_logger.record(scheduler_stats=scheduler_stats,
iteration_stats=iteration_stats) iteration_stats=iteration_stats)
......
...@@ -153,7 +153,9 @@ class EngineCore: ...@@ -153,7 +153,9 @@ class EngineCore:
def step(self) -> EngineCoreOutputs: def step(self) -> EngineCoreOutputs:
"""Schedule, execute, and make output.""" """Schedule, execute, and make output."""
if not self.scheduler.has_unfinished_requests(): # Check for any requests remaining in the scheduler - unfinished,
# or finished and not yet removed from the batch.
if not self.scheduler.has_requests():
return EngineCoreOutputs( return EngineCoreOutputs(
outputs=[], outputs=[],
scheduler_stats=self.scheduler.make_stats(), scheduler_stats=self.scheduler.make_stats(),
...@@ -335,7 +337,7 @@ class EngineCoreProc(EngineCore): ...@@ -335,7 +337,7 @@ class EngineCoreProc(EngineCore):
# Loop until process is sent a SIGINT or SIGTERM # Loop until process is sent a SIGINT or SIGTERM
while True: while True:
# 1) Poll the input queue until there is work to do. # 1) Poll the input queue until there is work to do.
while not self.scheduler.has_unfinished_requests(): while not self.scheduler.has_requests():
logger.debug("EngineCore busy loop waiting.") logger.debug("EngineCore busy loop waiting.")
req = self.input_queue.get() req = self.input_queue.get()
self._handle_client_request(*req) self._handle_client_request(*req)
......
...@@ -22,7 +22,7 @@ class StatLoggerBase(ABC): ...@@ -22,7 +22,7 @@ class StatLoggerBase(ABC):
@abstractmethod @abstractmethod
def record(self, scheduler_stats: SchedulerStats, def record(self, scheduler_stats: SchedulerStats,
iteration_stats: IterationStats): iteration_stats: Optional[IterationStats]):
... ...
def log(self): # noqa def log(self): # noqa
...@@ -56,9 +56,10 @@ class LoggingStatLogger(StatLoggerBase): ...@@ -56,9 +56,10 @@ class LoggingStatLogger(StatLoggerBase):
return float(np.sum(tracked_stats) / (now - self.last_log_time)) return float(np.sum(tracked_stats) / (now - self.last_log_time))
def record(self, scheduler_stats: SchedulerStats, def record(self, scheduler_stats: SchedulerStats,
iteration_stats: IterationStats): iteration_stats: Optional[IterationStats]):
"""Log Stats to standard output.""" """Log Stats to standard output."""
if iteration_stats:
self._track_iteration_stats(iteration_stats) self._track_iteration_stats(iteration_stats)
self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats)
...@@ -319,7 +320,7 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -319,7 +320,7 @@ class PrometheusStatLogger(StatLoggerBase):
info_gauge.set(1) info_gauge.set(1)
def record(self, scheduler_stats: SchedulerStats, def record(self, scheduler_stats: SchedulerStats,
iteration_stats: IterationStats): iteration_stats: Optional[IterationStats]):
"""Log to prometheus.""" """Log to prometheus."""
self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)
...@@ -331,6 +332,9 @@ class PrometheusStatLogger(StatLoggerBase): ...@@ -331,6 +332,9 @@ class PrometheusStatLogger(StatLoggerBase):
self.counter_gpu_prefix_cache_hits.inc( self.counter_gpu_prefix_cache_hits.inc(
scheduler_stats.prefix_cache_stats.hits) scheduler_stats.prefix_cache_stats.hits)
if iteration_stats is None:
return
self.counter_num_preempted_reqs.inc(iteration_stats.num_preempted_reqs) self.counter_num_preempted_reqs.inc(iteration_stats.num_preempted_reqs)
self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens) self.counter_prompt_tokens.inc(iteration_stats.num_prompt_tokens)
self.counter_generation_tokens.inc( self.counter_generation_tokens.inc(
......
...@@ -80,3 +80,13 @@ class ModelRunnerOutput: ...@@ -80,3 +80,13 @@ class ModelRunnerOutput:
# [prompt_len, num_prompt_logprobs] # [prompt_len, num_prompt_logprobs]
# [prompt_len] # [prompt_len]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
# Pre-built ModelRunnerOutput describing "no requests, no sampled tokens".
# NOTE(review): this is a single module-level instance whose fields are
# mutable containers ([] / {}); consumers should treat it as read-only.
# Confirm that no caller mutates the returned object.
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(
req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
...@@ -32,7 +32,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget ...@@ -32,7 +32,8 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_cache import MMInputCacheClient from vllm.v1.engine.mm_input_cache import MMInputCacheClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec) KVCacheSpec)
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
from vllm.v1.spec_decode.ngram_proposer import NgramProposer from vllm.v1.spec_decode.ngram_proposer import NgramProposer
...@@ -919,6 +920,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -919,6 +920,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[ModelRunnerOutput, torch.Tensor]: ) -> Union[ModelRunnerOutput, torch.Tensor]:
self._update_states(scheduler_output) self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if self.is_multimodal_model: if self.is_multimodal_model:
# Run the multimodal encoder if any. # Run the multimodal encoder if any.
...@@ -1069,7 +1073,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1069,7 +1073,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids = self.generate_draft_token_ids( spec_token_ids = self.generate_draft_token_ids(
valid_sampled_token_ids) valid_sampled_token_ids)
model_runner_output = ModelRunnerOutput( return ModelRunnerOutput(
req_ids=self.input_batch.req_ids, req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index, req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids, sampled_token_ids=valid_sampled_token_ids,
...@@ -1077,7 +1081,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1077,7 +1081,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs=logprobs_lists, logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict, prompt_logprobs_dict=prompt_logprobs_dict,
) )
return model_runner_output
def generate_draft_token_ids( def generate_draft_token_ids(
self, self,
......
...@@ -29,7 +29,8 @@ from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK, ...@@ -29,7 +29,8 @@ from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec) KVCacheSpec)
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.utils import bind_kv_cache from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...@@ -546,6 +547,9 @@ class TPUModelRunner: ...@@ -546,6 +547,9 @@ class TPUModelRunner:
) -> ModelRunnerOutput: ) -> ModelRunnerOutput:
# Update cached state # Update cached state
self._update_states(scheduler_output) self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if self.is_multimodal_model: if self.is_multimodal_model:
# Run the multimodal encoder if any. # Run the multimodal encoder if any.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment