Unverified Commit c8a7e932 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[core][scheduler] simplify and improve scheduler (#6867)

parent 3c10591e
......@@ -183,7 +183,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
# Allow only 2 sequences of ~128 tokens in worst case.
# Note 16 = 128/block_size
"num_gpu_blocks_override": 2 * (16 + 1),
"num_gpu_blocks_override": 2 * (16 + 2),
}
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
......
This diff is collapsed.
from collections import deque
from typing import Deque
from vllm.sequence import SequenceGroup
class Policy:
    """Base class for policies that rank sequence groups.

    Subclasses supply the ranking by overriding ``get_priority``;
    ``sort_by_priority`` uses that ranking to order a queue.
    """

    def get_priority(
        self,
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
        """Return the priority of ``seq_group`` as of time ``now``.

        Args:
            now: The reference time used to compute the priority.
            seq_group: The sequence group to rank.

        Raises:
            NotImplementedError: Always, in this base class. Subclasses
                must override this method.
        """
        raise NotImplementedError

    def sort_by_priority(
        self,
        now: float,
        seq_groups: Deque[SequenceGroup],
    ) -> Deque[SequenceGroup]:
        """Return a new deque of ``seq_groups``, highest priority first.

        The input deque is left untouched. ``sorted`` is stable even with
        ``reverse=True``, so groups with equal priority keep their
        original relative order.

        Args:
            now: The reference time passed to ``get_priority`` for every
                group.
            seq_groups: The sequence groups to order.

        Returns:
            A new deque ordered by descending ``get_priority`` value.
        """
        return deque(
            sorted(
                seq_groups,
                key=lambda seq_group: self.get_priority(now, seq_group),
                reverse=True,
            ))
class FCFS(Policy):
    """First-come, first-served policy.

    The earlier a sequence group arrived, the higher its priority.
    """

    def get_priority(
        self,
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
        """Return the time elapsed between the group's arrival and ``now``.

        Args:
            now: The reference time.
            seq_group: The sequence group; its ``metrics.arrival_time``
                is read.

        Returns:
            ``now - seq_group.metrics.arrival_time``.
        """
        return now - seq_group.metrics.arrival_time
class PolicyFactory:
    """Creates ``Policy`` instances from their registered names."""

    # Maps a policy name to the Policy subclass that implements it.
    _POLICY_REGISTRY = {'fcfs': FCFS}

    @classmethod
    def get_policy(cls, policy_name: str, **kwargs) -> Policy:
        """Instantiate the policy registered under ``policy_name``.

        Args:
            policy_name: A key of ``_POLICY_REGISTRY``.
            **kwargs: Forwarded unchanged to the policy's constructor.

        Returns:
            A new instance of the registered policy class.

        Raises:
            KeyError: If ``policy_name`` is not registered.
        """
        # Look the class up separately from constructing it, so a KeyError
        # raised inside a policy constructor is never mistaken for an
        # unknown policy name.
        try:
            policy_cls = cls._POLICY_REGISTRY[policy_name]
        except KeyError:
            raise KeyError(
                f"Unknown policy {policy_name!r}; registered policies: "
                f"{sorted(cls._POLICY_REGISTRY)}") from None
        return policy_cls(**kwargs)
......@@ -8,7 +8,6 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.policy import Policy, PolicyFactory
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
......@@ -345,6 +344,16 @@ class Scheduler:
# Add sequence groups to the waiting queue.
self.waiting.append(seq_group)
def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
    """Append ``seq_group`` to the end of the running queue.

    Only for testing purposes.

    Args:
        seq_group: The sequence group to add to ``self.running``.
    """
    self.running.append(seq_group)
def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
    """Append ``seq_group`` to the end of the swapped queue.

    Only for testing purposes.

    Args:
        seq_group: The sequence group to add to ``self.swapped``.
    """
    self.swapped.append(seq_group)
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a sequence group with the given ID.
......@@ -398,32 +407,26 @@ class Scheduler:
def _schedule_running(
self,
running_queue: deque,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
policy: Policy,
enable_chunking: bool = False,
) -> Tuple[deque, SchedulerRunningOutputs]:
) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running.
Running queue should include decode and chunked prefill requests.
Args:
running_queue: The queue that contains running requests (i.e.,
decodes). The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any decodes are preempted.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any decodes are preempted.
policy: The sorting policy to sort running_queue.
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
Returns:
A tuple of remaining running queue (should be always 0) after
scheduling and SchedulerRunningOutputs.
SchedulerRunningOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out: List[Tuple[int, int]] = []
......@@ -436,10 +439,9 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence
# groups to preempt.
now = time.time()
running_queue = policy.sort_by_priority(now, running_queue)
running_queue = self.running
while running_queue:
seq_group = running_queue[0]
num_running_tokens = self._get_num_new_tokens(
......@@ -503,7 +505,7 @@ class Scheduler:
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id)
return running_queue, SchedulerRunningOutputs(
return SchedulerRunningOutputs(
decode_seq_groups=decode_seq_groups,
prefill_seq_groups=prefill_seq_groups,
preempted=preempted,
......@@ -515,12 +517,10 @@ class Scheduler:
def _schedule_swapped(
self,
swapped_queue: deque,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
policy: Policy,
enable_chunking: bool = False,
) -> Tuple[deque, SchedulerSwappedInOutputs]:
) -> SchedulerSwappedInOutputs:
"""Schedule sequence groups that are swapped out.
It schedules swapped requests as long as it fits `budget` and
......@@ -528,20 +528,16 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
swapped_queue: The queue that contains swapped out requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any requests are swapped in.
curr_loras: Currently batched lora request ids. The argument is
in-place updated when any requests are swapped in.
policy: The sorting policy to sort swapped_queue.
enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
Returns:
A tuple of remaining swapped_queue after scheduling and
SchedulerSwappedInOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
......@@ -549,10 +545,10 @@ class Scheduler:
blocks_to_copy: List[Tuple[int, int]] = []
decode_seq_groups: List[ScheduledSequenceGroup] = []
prefill_seq_groups: List[ScheduledSequenceGroup] = []
now = time.time()
swapped_queue = policy.sort_by_priority(now, swapped_queue)
infeasible_seq_groups: List[SequenceGroup] = []
swapped_queue = self.swapped
leftover_swapped: Deque[SequenceGroup] = deque()
while swapped_queue:
seq_group = swapped_queue[0]
......@@ -617,7 +613,7 @@ class Scheduler:
swapped_queue.extendleft(leftover_swapped)
return swapped_queue, SchedulerSwappedInOutputs(
return SchedulerSwappedInOutputs(
decode_seq_groups=decode_seq_groups,
prefill_seq_groups=prefill_seq_groups,
blocks_to_swap_in=blocks_to_swap_in,
......@@ -644,11 +640,10 @@ class Scheduler:
def _schedule_prefills(
self,
waiting_queue: deque,
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
) -> Tuple[deque, SchedulerPrefillOutputs]:
) -> SchedulerPrefillOutputs:
"""Schedule sequence groups that are in prefill stage.
Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
......@@ -660,8 +655,6 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups.
Args:
waiting_queue: The queue that contains prefill requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated
when any requests are scheduled.
curr_loras: Currently batched lora request ids. The argument is
......@@ -672,14 +665,12 @@ class Scheduler:
all tokens.
Returns:
A tuple of remaining waiting_queue after scheduling and
SchedulerPrefillOutputs.
"""
ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[SequenceGroup] = []
# We don't sort waiting queue because we assume it is sorted.
# Copy the queue so that the input queue is not modified.
waiting_queue = deque([s for s in waiting_queue])
waiting_queue = self.waiting
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue:
......@@ -758,7 +749,7 @@ class Scheduler:
if len(seq_groups) > 0:
self.prev_prompt = True
return waiting_queue, SchedulerPrefillOutputs(
return SchedulerPrefillOutputs(
seq_groups=seq_groups,
ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
......@@ -785,53 +776,43 @@ class Scheduler:
seq_group.lora_int_id for seq_group in self.running
if seq_group.lora_int_id > 0) if self.lora_enabled else None
remaining_waiting, prefills = (self.waiting,
SchedulerPrefillOutputs.create_empty())
remaining_running, running_scheduled = (
self.running, SchedulerRunningOutputs.create_empty())
remaining_swapped, swapped_in = (
self.swapped, SchedulerSwappedInOutputs.create_empty())
prefills = SchedulerPrefillOutputs.create_empty()
running_scheduled = SchedulerRunningOutputs.create_empty()
swapped_in = SchedulerSwappedInOutputs.create_empty()
# If any requests are swapped, prioritize swapped requests.
if not self.swapped:
remaining_waiting, prefills = self._schedule_prefills(
self.waiting, budget, curr_loras, enable_chunking=False)
prefills = self._schedule_prefills(budget,
curr_loras,
enable_chunking=False)
fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
# Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# only contains decode requests, not chunked prefills.
if len(prefills.seq_groups) == 0:
remaining_running, running_scheduled = self._schedule_running(
self.running,
budget,
curr_loras,
fcfs_policy,
enable_chunking=False)
running_scheduled = self._schedule_running(budget,
curr_loras,
enable_chunking=False)
# If any sequence group is preempted, do not swap in any sequence
# group, because it means there's no slot for new running requests.
if len(running_scheduled.preempted) + len(
running_scheduled.swapped_out) == 0:
remaining_swapped, swapped_in = self._schedule_swapped(
self.swapped, budget, curr_loras, fcfs_policy)
swapped_in = self._schedule_swapped(budget, curr_loras)
assert (budget.num_batched_tokens <=
self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests.
self.waiting = remaining_waiting
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
self.running = remaining_running
self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
self.running.extend(
[s.seq_group for s in swapped_in.decode_seq_groups])
# Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out)
preempted = (len(running_scheduled.preempted) +
len(running_scheduled.swapped_out))
......@@ -877,42 +858,32 @@ class Scheduler:
)
curr_loras: Set[int] = set()
remaining_waiting, prefills = (self.waiting,
SchedulerPrefillOutputs.create_empty())
remaining_running, running_scheduled = (
self.running, SchedulerRunningOutputs.create_empty())
remaining_swapped, swapped_in = (
self.swapped, SchedulerSwappedInOutputs.create_empty())
prefills = SchedulerPrefillOutputs.create_empty()
swapped_in = SchedulerSwappedInOutputs.create_empty()
# Decoding should be always scheduled first by fcfs.
fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
remaining_running, running_scheduled = self._schedule_running(
self.running,
budget,
curr_loras,
fcfs_policy,
enable_chunking=True)
running_scheduled = self._schedule_running(budget,
curr_loras,
enable_chunking=True)
# Schedule swapped out requests.
# If preemption happens, it means we don't have space for swap-in.
if len(running_scheduled.preempted) + len(
running_scheduled.swapped_out) == 0:
remaining_swapped, swapped_in = self._schedule_swapped(
self.swapped, budget, curr_loras, fcfs_policy)
swapped_in = self._schedule_swapped(budget, curr_loras)
# Schedule new prefills.
remaining_waiting, prefills = self._schedule_prefills(
self.waiting, budget, curr_loras, enable_chunking=True)
prefills = self._schedule_prefills(budget,
curr_loras,
enable_chunking=True)
assert (budget.num_batched_tokens <=
self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests.
self.waiting = remaining_waiting
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
self.running = remaining_running
self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups])
......@@ -923,7 +894,6 @@ class Scheduler:
self.running.extend(
[s.seq_group for s in swapped_in.prefill_seq_groups])
# Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out)
return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups +
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment