Unverified commit dc78c2c9, authored by Dan Blanaru and committed by GitHub
Browse files

[Core] add option to schedule requests based on full ISL (#37307)


Signed-off-by: Dan Blanaru <48605845+DanBlanaru@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
parent 47318847
......@@ -135,6 +135,12 @@ class SchedulerConfig:
and starting configuration.
"""
scheduler_reserve_full_isl: bool = True
"""If True, the scheduler checks whether the full input sequence length
fits in the KV cache before admitting a new request, rather than only
checking the first chunk. Prevents over-admission and KV cache thrashing
with chunked prefill."""
async_scheduling: bool | None = Field(default=None)
"""If set to False, disable async scheduling. Async scheduling helps to
avoid gaps in GPU utilization, leading to better latency and throughput.
......
......@@ -531,6 +531,8 @@ class EngineArgs:
enable_chunked_prefill: bool | None = None
disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
scheduler_reserve_full_isl: bool = SchedulerConfig.scheduler_reserve_full_isl
disable_hybrid_kv_cache_manager: bool | None = (
SchedulerConfig.disable_hybrid_kv_cache_manager
)
......@@ -1234,6 +1236,10 @@ class EngineArgs:
scheduler_group.add_argument(
"--scheduler-cls", **scheduler_kwargs["scheduler_cls"]
)
scheduler_group.add_argument(
"--scheduler-reserve-full-isl",
**scheduler_kwargs["scheduler_reserve_full_isl"],
)
scheduler_group.add_argument(
"--disable-hybrid-kv-cache-manager",
**scheduler_kwargs["disable_hybrid_kv_cache_manager"],
......@@ -1810,6 +1816,7 @@ class EngineArgs:
max_num_partial_prefills=self.max_num_partial_prefills,
max_long_partial_prefills=self.max_long_partial_prefills,
long_prefill_token_threshold=self.long_prefill_token_threshold,
scheduler_reserve_full_isl=self.scheduler_reserve_full_isl,
disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager,
async_scheduling=self.async_scheduling,
stream_interval=self.stream_interval,
......
......@@ -215,6 +215,45 @@ class KVCacheManager:
return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens
def can_fit_full_sequence(
    self,
    request: Request,
    num_new_computed_tokens: int = 0,
    new_computed_blocks: KVCacheBlocks | None = None,
    num_external_computed_tokens: int = 0,
    num_encoder_tokens: int = 0,
) -> bool:
    """Return True if the free block pool can hold the request's full sequence.

    Serves as an admission gate: with chunked prefill, only the first chunk
    would otherwise be checked, which can over-admit requests. The block
    count itself is delegated to the coordinator, which is expected to
    account for prefix cache hits and sliding window.

    Args:
        request: The request being considered for admission.
        num_new_computed_tokens: Newly found locally computed tokens, added
            on top of ``request.num_computed_tokens``.
        new_computed_blocks: Blocks backing the newly computed tokens; when
            ``None``, the manager's empty block set is used instead.
        num_external_computed_tokens: Tokens already computed externally.
        num_encoder_tokens: Encoder tokens, forwarded to the coordinator.

    Returns:
        Whether the number of blocks still to allocate does not exceed the
        number of free blocks in the pool.
    """
    cached_blocks = (
        self.empty_kv_cache_blocks.blocks
        if new_computed_blocks is None
        else new_computed_blocks.blocks
    )

    # Both the computed-token count and the sequence length are capped at
    # the model's maximum length.
    length_cap = self.max_model_len
    already_computed = min(
        request.num_computed_tokens
        + num_new_computed_tokens
        + num_external_computed_tokens,
        length_cap,
    )
    sequence_len = min(request.num_tokens, length_cap)

    blocks_required = self.coordinator.get_num_blocks_to_allocate(
        request_id=request.request_id,
        num_tokens=sequence_len,
        new_computed_blocks=cached_blocks,
        num_encoder_tokens=num_encoder_tokens,
        total_computed_tokens=already_computed,
        num_tokens_main_model=sequence_len,
    )
    blocks_available = self.block_pool.get_num_free_blocks()
    return blocks_required <= blocks_available
def allocate_slots(
self,
request: Request,
......
......@@ -236,6 +236,9 @@ class Scheduler(SchedulerInterface):
)
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
self.scheduler_reserve_full_isl = (
self.scheduler_config.scheduler_reserve_full_isl
)
self.has_mamba_layers = kv_cache_config.has_mamba_layers
self.needs_kv_cache_zeroing = kv_cache_config.needs_kv_cache_zeroing
......@@ -719,6 +722,20 @@ class Scheduler(SchedulerInterface):
for i in encoder_inputs_to_schedule
)
if (
self.scheduler_reserve_full_isl
and not self.kv_cache_manager.can_fit_full_sequence(
request,
num_new_computed_tokens=num_new_local_computed_tokens,
new_computed_blocks=new_computed_blocks,
num_external_computed_tokens=num_external_computed_tokens,
num_encoder_tokens=num_encoder_tokens,
)
):
if request.has_encoder_inputs:
self.encoder_cache_manager.free(request)
break
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment