Unverified Commit c8a7e932 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[core][scheduler] simplify and improve scheduler (#6867)

parent 3c10591e
...@@ -183,7 +183,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator, ...@@ -183,7 +183,7 @@ def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
# Allow only 2 sequences of ~128 tokens in worst case. # Allow only 2 sequences of ~128 tokens in worst case.
# Note 16 = 128/block_size # Note 16 = 128/block_size
"num_gpu_blocks_override": 2 * (16 + 1), "num_gpu_blocks_override": 2 * (16 + 2),
} }
]) ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{ @pytest.mark.parametrize("baseline_llm_kwargs", [{
......
This diff is collapsed.
from collections import deque
from typing import Deque
from vllm.sequence import SequenceGroup
class Policy:
    """Base class for policies that rank sequence groups."""

    def get_priority(
        self,
        now: float,
        seq_group: "SequenceGroup",
    ) -> float:
        """Return the priority of ``seq_group`` evaluated at time ``now``.

        The base implementation always raises; subclasses must override it.
        ``sort_by_priority`` places larger return values first.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def sort_by_priority(
        self,
        now: float,
        seq_groups: Deque["SequenceGroup"],
    ) -> Deque["SequenceGroup"]:
        """Return a new deque of ``seq_groups`` in descending priority.

        The input deque is not modified. Every priority is computed with
        the same ``now`` value. Because ``sorted`` is stable, groups with
        equal priority keep their original relative order.
        """

        def priority_of(group):
            return self.get_priority(now, group)

        ranked = sorted(seq_groups, key=priority_of, reverse=True)
        return deque(ranked)
class FCFS(Policy):
    """First-come, first-served policy: earlier arrivals rank higher."""

    def get_priority(
        self,
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
        """Return the time elapsed between the group's arrival and ``now``.

        The value is ``now - seq_group.metrics.arrival_time``, so a group
        that arrived earlier gets a larger priority.
        """
        arrival_time = seq_group.metrics.arrival_time
        return now - arrival_time
class PolicyFactory:
    """Factory that builds a ``Policy`` instance from its registered name."""

    # Maps a policy name to the Policy subclass implementing it.
    _POLICY_REGISTRY = {'fcfs': FCFS}

    @classmethod
    def get_policy(cls, policy_name: str, **kwargs) -> Policy:
        """Instantiate the policy registered under ``policy_name``.

        Args:
            policy_name: Key into ``_POLICY_REGISTRY`` (e.g. ``'fcfs'``).
            **kwargs: Forwarded unchanged to the policy class constructor.

        Returns:
            A new instance of the registered policy class.

        Raises:
            KeyError: if ``policy_name`` is not registered. The message
                names the offending value and lists the known policies.
        """
        try:
            policy_cls = cls._POLICY_REGISTRY[policy_name]
        except KeyError:
            # Same exception type as a plain dict lookup, so existing
            # callers catching KeyError still work; only the message changes.
            supported = ", ".join(sorted(cls._POLICY_REGISTRY))
            raise KeyError(
                f"Unknown scheduling policy {policy_name!r}; "
                f"supported policies: {supported}") from None
        return policy_cls(**kwargs)
...@@ -8,7 +8,6 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union ...@@ -8,7 +8,6 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.policy import Policy, PolicyFactory
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
...@@ -345,6 +344,16 @@ class Scheduler: ...@@ -345,6 +344,16 @@ class Scheduler:
# Add sequence groups to the waiting queue. # Add sequence groups to the waiting queue.
self.waiting.append(seq_group) self.waiting.append(seq_group)
def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None:
    """Append ``seq_group`` to the running queue.

    Only for testing purposes.
    """
    self.running.append(seq_group)
def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None:
    """Append ``seq_group`` to the swapped queue.

    Only for testing purposes.
    """
    self.swapped.append(seq_group)
def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a sequence group with the given ID. """Aborts a sequence group with the given ID.
...@@ -398,32 +407,26 @@ class Scheduler: ...@@ -398,32 +407,26 @@ class Scheduler:
def _schedule_running( def _schedule_running(
self, self,
running_queue: deque,
budget: SchedulingBudget, budget: SchedulingBudget,
curr_loras: Optional[Set[int]], curr_loras: Optional[Set[int]],
policy: Policy,
enable_chunking: bool = False, enable_chunking: bool = False,
) -> Tuple[deque, SchedulerRunningOutputs]: ) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running. """Schedule sequence groups that are running.
Running queue should include decode and chunked prefill requests. Running queue should include decode and chunked prefill requests.
Args: Args:
running_queue: The queue that contains running requests (i.e.,
decodes). The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated budget: The scheduling budget. The argument is in-place updated
when any decodes are preempted. when any decodes are preempted.
curr_loras: Currently batched lora request ids. The argument is curr_loras: Currently batched lora request ids. The argument is
in-place updated when any decodes are preempted. in-place updated when any decodes are preempted.
policy: The sorting policy to sort running_queue.
enable_chunking: If True, seq group can be chunked and only a enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule `budget.num_batched_tokens` has not enough capacity to schedule
all tokens. all tokens.
Returns: Returns:
A tuple of remaining running queue (should be always 0) after SchedulerRunningOutputs.
scheduling and SchedulerRunningOutputs.
""" """
# Blocks that need to be swapped or copied before model execution. # Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out: List[Tuple[int, int]] = [] blocks_to_swap_out: List[Tuple[int, int]] = []
...@@ -436,10 +439,9 @@ class Scheduler: ...@@ -436,10 +439,9 @@ class Scheduler:
# NOTE(woosuk): Preemption happens only when there is no available slot # NOTE(woosuk): Preemption happens only when there is no available slot
# to keep all the sequence groups in the RUNNING state. # to keep all the sequence groups in the RUNNING state.
# In this case, the policy is responsible for deciding which sequence
# groups to preempt. running_queue = self.running
now = time.time()
running_queue = policy.sort_by_priority(now, running_queue)
while running_queue: while running_queue:
seq_group = running_queue[0] seq_group = running_queue[0]
num_running_tokens = self._get_num_new_tokens( num_running_tokens = self._get_num_new_tokens(
...@@ -503,7 +505,7 @@ class Scheduler: ...@@ -503,7 +505,7 @@ class Scheduler:
if curr_loras is not None and seq_group.lora_int_id > 0: if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.add(seq_group.lora_int_id) curr_loras.add(seq_group.lora_int_id)
return running_queue, SchedulerRunningOutputs( return SchedulerRunningOutputs(
decode_seq_groups=decode_seq_groups, decode_seq_groups=decode_seq_groups,
prefill_seq_groups=prefill_seq_groups, prefill_seq_groups=prefill_seq_groups,
preempted=preempted, preempted=preempted,
...@@ -515,12 +517,10 @@ class Scheduler: ...@@ -515,12 +517,10 @@ class Scheduler:
def _schedule_swapped( def _schedule_swapped(
self, self,
swapped_queue: deque,
budget: SchedulingBudget, budget: SchedulingBudget,
curr_loras: Optional[Set[int]], curr_loras: Optional[Set[int]],
policy: Policy,
enable_chunking: bool = False, enable_chunking: bool = False,
) -> Tuple[deque, SchedulerSwappedInOutputs]: ) -> SchedulerSwappedInOutputs:
"""Schedule sequence groups that are swapped out. """Schedule sequence groups that are swapped out.
It schedules swapped requests as long as it fits `budget` and It schedules swapped requests as long as it fits `budget` and
...@@ -528,20 +528,16 @@ class Scheduler: ...@@ -528,20 +528,16 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups. `budget` and `curr_loras` are updated based on scheduled seq_groups.
Args: Args:
swapped_queue: The queue that contains swapped out requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated budget: The scheduling budget. The argument is in-place updated
when any requests are swapped in. when any requests are swapped in.
curr_loras: Currently batched lora request ids. The argument is curr_loras: Currently batched lora request ids. The argument is
in-place updated when any requests are swapped in. in-place updated when any requests are swapped in.
policy: The sorting policy to sort swapped_queue.
enable_chunking: If True, seq group can be chunked and only a enable_chunking: If True, seq group can be chunked and only a
chunked number of tokens are scheduled if chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule `budget.num_batched_tokens` has not enough capacity to schedule
all tokens. all tokens.
Returns: Returns:
A tuple of remaining swapped_queue after scheduling and
SchedulerSwappedInOutputs. SchedulerSwappedInOutputs.
""" """
# Blocks that need to be swapped or copied before model execution. # Blocks that need to be swapped or copied before model execution.
...@@ -549,10 +545,10 @@ class Scheduler: ...@@ -549,10 +545,10 @@ class Scheduler:
blocks_to_copy: List[Tuple[int, int]] = [] blocks_to_copy: List[Tuple[int, int]] = []
decode_seq_groups: List[ScheduledSequenceGroup] = [] decode_seq_groups: List[ScheduledSequenceGroup] = []
prefill_seq_groups: List[ScheduledSequenceGroup] = [] prefill_seq_groups: List[ScheduledSequenceGroup] = []
now = time.time()
swapped_queue = policy.sort_by_priority(now, swapped_queue)
infeasible_seq_groups: List[SequenceGroup] = [] infeasible_seq_groups: List[SequenceGroup] = []
swapped_queue = self.swapped
leftover_swapped: Deque[SequenceGroup] = deque() leftover_swapped: Deque[SequenceGroup] = deque()
while swapped_queue: while swapped_queue:
seq_group = swapped_queue[0] seq_group = swapped_queue[0]
...@@ -617,7 +613,7 @@ class Scheduler: ...@@ -617,7 +613,7 @@ class Scheduler:
swapped_queue.extendleft(leftover_swapped) swapped_queue.extendleft(leftover_swapped)
return swapped_queue, SchedulerSwappedInOutputs( return SchedulerSwappedInOutputs(
decode_seq_groups=decode_seq_groups, decode_seq_groups=decode_seq_groups,
prefill_seq_groups=prefill_seq_groups, prefill_seq_groups=prefill_seq_groups,
blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
...@@ -644,11 +640,10 @@ class Scheduler: ...@@ -644,11 +640,10 @@ class Scheduler:
def _schedule_prefills( def _schedule_prefills(
self, self,
waiting_queue: deque,
budget: SchedulingBudget, budget: SchedulingBudget,
curr_loras: Optional[Set[int]], curr_loras: Optional[Set[int]],
enable_chunking: bool = False, enable_chunking: bool = False,
) -> Tuple[deque, SchedulerPrefillOutputs]: ) -> SchedulerPrefillOutputs:
"""Schedule sequence groups that are in prefill stage. """Schedule sequence groups that are in prefill stage.
Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE
...@@ -660,8 +655,6 @@ class Scheduler: ...@@ -660,8 +655,6 @@ class Scheduler:
`budget` and `curr_loras` are updated based on scheduled seq_groups. `budget` and `curr_loras` are updated based on scheduled seq_groups.
Args: Args:
waiting_queue: The queue that contains prefill requests.
The given arguments are NOT in-place modified.
budget: The scheduling budget. The argument is in-place updated budget: The scheduling budget. The argument is in-place updated
when any requests are scheduled. when any requests are scheduled.
curr_loras: Currently batched lora request ids. The argument is curr_loras: Currently batched lora request ids. The argument is
...@@ -672,14 +665,12 @@ class Scheduler: ...@@ -672,14 +665,12 @@ class Scheduler:
all tokens. all tokens.
Returns: Returns:
A tuple of remaining waiting_queue after scheduling and
SchedulerSwappedInOutputs. SchedulerSwappedInOutputs.
""" """
ignored_seq_groups: List[SequenceGroup] = [] ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[SequenceGroup] = [] seq_groups: List[SequenceGroup] = []
# We don't sort waiting queue because we assume it is sorted.
# Copy the queue so that the input queue is not modified. waiting_queue = self.waiting
waiting_queue = deque([s for s in waiting_queue])
leftover_waiting_sequences: Deque[SequenceGroup] = deque() leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue: while self._passed_delay(time.time()) and waiting_queue:
...@@ -758,7 +749,7 @@ class Scheduler: ...@@ -758,7 +749,7 @@ class Scheduler:
if len(seq_groups) > 0: if len(seq_groups) > 0:
self.prev_prompt = True self.prev_prompt = True
return waiting_queue, SchedulerPrefillOutputs( return SchedulerPrefillOutputs(
seq_groups=seq_groups, seq_groups=seq_groups,
ignored_seq_groups=ignored_seq_groups, ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True)) num_lookahead_slots=self._get_num_lookahead_slots(is_prefill=True))
...@@ -785,53 +776,43 @@ class Scheduler: ...@@ -785,53 +776,43 @@ class Scheduler:
seq_group.lora_int_id for seq_group in self.running seq_group.lora_int_id for seq_group in self.running
if seq_group.lora_int_id > 0) if self.lora_enabled else None if seq_group.lora_int_id > 0) if self.lora_enabled else None
remaining_waiting, prefills = (self.waiting, prefills = SchedulerPrefillOutputs.create_empty()
SchedulerPrefillOutputs.create_empty()) running_scheduled = SchedulerRunningOutputs.create_empty()
remaining_running, running_scheduled = ( swapped_in = SchedulerSwappedInOutputs.create_empty()
self.running, SchedulerRunningOutputs.create_empty())
remaining_swapped, swapped_in = (
self.swapped, SchedulerSwappedInOutputs.create_empty())
# If any requests are swapped, prioritized swapped requests. # If any requests are swapped, prioritized swapped requests.
if not self.swapped: if not self.swapped:
remaining_waiting, prefills = self._schedule_prefills( prefills = self._schedule_prefills(budget,
self.waiting, budget, curr_loras, enable_chunking=False) curr_loras,
enable_chunking=False)
fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs")
# Don't schedule decodes if prefills are scheduled. # Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# only contains decode requests, not chunked prefills. # only contains decode requests, not chunked prefills.
if len(prefills.seq_groups) == 0: if len(prefills.seq_groups) == 0:
remaining_running, running_scheduled = self._schedule_running( running_scheduled = self._schedule_running(budget,
self.running, curr_loras,
budget, enable_chunking=False)
curr_loras,
fcfs_policy,
enable_chunking=False)
# If any sequence group is preempted, do not swap in any sequence # If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests. # group. because it means there's no slot for new running requests.
if len(running_scheduled.preempted) + len( if len(running_scheduled.preempted) + len(
running_scheduled.swapped_out) == 0: running_scheduled.swapped_out) == 0:
remaining_swapped, swapped_in = self._schedule_swapped( swapped_in = self._schedule_swapped(budget, curr_loras)
self.swapped, budget, curr_loras, fcfs_policy)
assert (budget.num_batched_tokens <= assert (budget.num_batched_tokens <=
self.scheduler_config.max_num_batched_tokens) self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests. # Update waiting requests.
self.waiting = remaining_waiting
self.waiting.extendleft(running_scheduled.preempted) self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests. # Update new running requests.
self.running = remaining_running
self.running.extend([s.seq_group for s in prefills.seq_groups]) self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend( self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups]) [s.seq_group for s in running_scheduled.decode_seq_groups])
self.running.extend( self.running.extend(
[s.seq_group for s in swapped_in.decode_seq_groups]) [s.seq_group for s in swapped_in.decode_seq_groups])
# Update swapped requests. # Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out) self.swapped.extend(running_scheduled.swapped_out)
preempted = (len(running_scheduled.preempted) + preempted = (len(running_scheduled.preempted) +
len(running_scheduled.swapped_out)) len(running_scheduled.swapped_out))
...@@ -877,42 +858,32 @@ class Scheduler: ...@@ -877,42 +858,32 @@ class Scheduler:
) )
curr_loras: Set[int] = set() curr_loras: Set[int] = set()
remaining_waiting, prefills = (self.waiting, prefills = SchedulerPrefillOutputs.create_empty()
SchedulerPrefillOutputs.create_empty()) swapped_in = SchedulerSwappedInOutputs.create_empty()
remaining_running, running_scheduled = (
self.running, SchedulerRunningOutputs.create_empty())
remaining_swapped, swapped_in = (
self.swapped, SchedulerSwappedInOutputs.create_empty())
# Decoding should be always scheduled first by fcfs. # Decoding should be always scheduled first by fcfs.
fcfs_policy = PolicyFactory.get_policy(policy_name="fcfs") running_scheduled = self._schedule_running(budget,
remaining_running, running_scheduled = self._schedule_running( curr_loras,
self.running, enable_chunking=True)
budget,
curr_loras,
fcfs_policy,
enable_chunking=True)
# Schedule swapped out requests. # Schedule swapped out requests.
# If preemption happens, it means we don't have space for swap-in. # If preemption happens, it means we don't have space for swap-in.
if len(running_scheduled.preempted) + len( if len(running_scheduled.preempted) + len(
running_scheduled.swapped_out) == 0: running_scheduled.swapped_out) == 0:
remaining_swapped, swapped_in = self._schedule_swapped( swapped_in = self._schedule_swapped(budget, curr_loras)
self.swapped, budget, curr_loras, fcfs_policy)
# Schedule new prefills. # Schedule new prefills.
remaining_waiting, prefills = self._schedule_prefills( prefills = self._schedule_prefills(budget,
self.waiting, budget, curr_loras, enable_chunking=True) curr_loras,
enable_chunking=True)
assert (budget.num_batched_tokens <= assert (budget.num_batched_tokens <=
self.scheduler_config.max_num_batched_tokens) self.scheduler_config.max_num_batched_tokens)
assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs
# Update waiting requests. # Update waiting requests.
self.waiting = remaining_waiting
self.waiting.extendleft(running_scheduled.preempted) self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests. # Update new running requests.
self.running = remaining_running
self.running.extend([s.seq_group for s in prefills.seq_groups]) self.running.extend([s.seq_group for s in prefills.seq_groups])
self.running.extend( self.running.extend(
[s.seq_group for s in running_scheduled.decode_seq_groups]) [s.seq_group for s in running_scheduled.decode_seq_groups])
...@@ -923,7 +894,6 @@ class Scheduler: ...@@ -923,7 +894,6 @@ class Scheduler:
self.running.extend( self.running.extend(
[s.seq_group for s in swapped_in.prefill_seq_groups]) [s.seq_group for s in swapped_in.prefill_seq_groups])
# Update swapped requests. # Update swapped requests.
self.swapped = remaining_swapped
self.swapped.extend(running_scheduled.swapped_out) self.swapped.extend(running_scheduled.swapped_out)
return SchedulerOutputs( return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups + scheduled_seq_groups=(prefills.seq_groups +
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment