Commit 1ff7856a authored by MaYuhang's avatar MaYuhang Committed by Ceng23333
Browse files

scheduler: add block-based admission check before moving requests to running

parent 338b35f5
......@@ -261,6 +261,12 @@ class BlockManager:
def get_num_free_blocks(self) -> int:
return len(self.free_block_ids)
def get_total_usable_blocks(self) -> int:
freeable_used_blocks = sum(
1 for bid in self.used_block_ids if self.blocks[bid].ref_count == 0
)
return len(self.free_block_ids) + freeable_used_blocks
def __repr__(self):
return (
f"BlockManager(blocks={self.num_blocks}, block_size={self.block_size}, "
......
......@@ -154,6 +154,11 @@ class Scheduler:
req = self.waiting_queue.sync_q.get_nowait()
except queue.Empty:
break
if not self.can_accept_request(req):
self.waiting_queue.sync_q.put(req)
break
# Skip requests that were already finished (e.g., timed out/canceled while waiting)
if req.is_finished():
self.complete_requests([req])
......@@ -164,7 +169,7 @@ class Scheduler:
if not self.cache_manager.can_allocate(num_required_blocks):
if not self.cache_manager.try_free_blocks(num_required_blocks):
raise RuntimeError("No available cache blocks")
raise RuntimeError("No available cache blocks for new request")
# Allocate blocks with automatic prefix caching support
req.block_table, req.slot_mapping, req.num_cached_tokens = (
......@@ -205,7 +210,7 @@ class Scheduler:
scheduled_requests.append(req)
except RuntimeError as e:
raise RuntimeError("No available cache blocks") from e
raise RuntimeError("No available cache blocks for new token") from e
# Return decode batch if any running requests were scheduled
if scheduled_requests:
......@@ -245,6 +250,31 @@ class Scheduler:
# Still running, put back in running queue
self.running_queue.sync_q.put(req)
def can_accept_request(self, request: InferenceRequest) -> bool:
total_required_blocks = 0
# Calculate blocks needed for running requests
running_queue_size = self.running_queue.sync_q.qsize()
for _ in range(running_queue_size):
req = self.running_queue.sync_q.get()
remaining_tokens = (
req.sampling_params.max_tokens - req.get_num_generated_tokens()
)
num_blocks_needed = (
remaining_tokens + self.block_size - 1
) // self.block_size
total_required_blocks += num_blocks_needed
self.running_queue.sync_q.put(req)
# Calculate blocks needed for the new request
total_length = request.get_prompt_length()
total_length += request.sampling_params.max_tokens
num_blocks_needed = (total_length + self.block_size - 1) // self.block_size
total_required_blocks += num_blocks_needed
# Compare with total usable blocks in cache manager
return total_required_blocks <= self.cache_manager.get_total_usable_blocks()
def get_cache_stats(self) -> dict:
"""Get cache statistics."""
return {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment