Unverified commit 7a7929ab, authored by Woosuk Kwon and committed by GitHub
Browse files

Implement preemption via recomputation & Refactor scheduling logic (#12)

parent 88c0268a
...@@ -84,8 +84,9 @@ class FastAPIFrontend: ...@@ -84,8 +84,9 @@ class FastAPIFrontend:
seq = Sequence(seq_id, token_ids, block_size=self.block_size) seq = Sequence(seq_id, token_ids, block_size=self.block_size)
seqs.append(seq) seqs.append(seq)
arrival_time = time.time()
group_id = next(self.seq_group_counter) group_id = next(self.seq_group_counter)
seq_group = SequenceGroup(group_id, seqs) seq_group = SequenceGroup(group_id, seqs, arrival_time)
group_event = asyncio.Event() group_event = asyncio.Event()
self.sequence_group_events[group_id] = group_event self.sequence_group_events[group_id] = group_event
await self.server.add_sequence_groups.remote([(seq_group, sampling_params)]) await self.server.add_sequence_groups.remote([(seq_group, sampling_params)])
......
...@@ -76,7 +76,8 @@ class BlockSpaceManager: ...@@ -76,7 +76,8 @@ class BlockSpaceManager:
self.block_tables: Dict[int, BlockTable] = {} self.block_tables: Dict[int, BlockTable] = {}
def can_allocate(self, seq_group: SequenceGroup) -> bool: def can_allocate(self, seq_group: SequenceGroup) -> bool:
# NOTE: Here we assume that all sequences in the group have the same prompt. # FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
seq = seq_group.seqs[0] seq = seq_group.seqs[0]
num_required_blocks = len(seq.logical_token_blocks) num_required_blocks = len(seq.logical_token_blocks)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
......
from typing import List
from cacheflow.sequence import SequenceGroup
class Policy:
    """Base class for scheduling policies that rank sequence groups.

    A subclass supplies `get_priority`; `sort_by_priority` then uses it to
    order a list of sequence groups.
    """

    def get_priority(
        self,
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
        """Return the priority of `seq_group` as of time `now`.

        The base implementation always raises NotImplementedError, so a
        subclass must override it.
        """
        raise NotImplementedError

    def sort_by_priority(
        self,
        now: float,
        seq_groups: List[SequenceGroup],
    ) -> List[SequenceGroup]:
        """Return a new list of `seq_groups`, highest priority first.

        The input list is not modified. Groups with equal priority keep
        their relative input order, since Python's sort is stable even
        when `reverse=True`.
        """
        def priority_of(group):
            # Every group is scored against the same `now`.
            return self.get_priority(now, group)

        ranked = list(seq_groups)
        ranked.sort(key=priority_of, reverse=True)
        return ranked
class FCFS(Policy):
    """First-come, first-served policy.

    The priority of a sequence group is how long ago it arrived, so the
    earliest arrival ranks highest.
    """

    def get_priority(
        self,
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
        """Return the time elapsed between the group's arrival and `now`."""
        elapsed = now - seq_group.arrival_time
        return elapsed
class PolicyFactory:
    """Creates `Policy` instances from their registered names."""

    # Maps a policy name to the class that implements it.
    _POLICY_REGISTRY = {
        'fcfs': FCFS,
    }

    @classmethod
    def get_policy(cls, policy_name: str, **kwargs) -> Policy:
        """Instantiate the policy registered under `policy_name`.

        Any extra keyword arguments are forwarded to the policy's
        constructor. An unregistered name raises KeyError from the registry
        lookup.
        """
        policy_cls = cls._POLICY_REGISTRY[policy_name]
        return policy_cls(**kwargs)
from typing import Dict, List, Tuple import enum
import time
from typing import Dict, List, Optional, Tuple
from cacheflow.master.block_manager import BlockSpaceManager from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.policy import PolicyFactory
from cacheflow.sampling_params import SamplingParams from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup from cacheflow.sequence import SequenceGroup
...@@ -9,6 +12,19 @@ from cacheflow.sequence import SequenceOutputs ...@@ -9,6 +12,19 @@ from cacheflow.sequence import SequenceOutputs
from cacheflow.sequence import SequenceStatus from cacheflow.sequence import SequenceStatus
class PreemptionMode(enum.Enum):
    """Preemption modes.

    1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
       and swap them back in when the sequences are resumed.
    2. Recomputation: Discard the blocks of the preempted sequences and
       recompute them when the sequences are resumed, treating the sequences
       as new prompts.
    """
    # Handled by `Scheduler._preempt_by_swap`.
    SWAP = enum.auto()
    # Handled by `Scheduler._preempt_by_recompute`.
    RECOMPUTE = enum.auto()
class Scheduler: class Scheduler:
def __init__( def __init__(
...@@ -25,6 +41,8 @@ class Scheduler: ...@@ -25,6 +41,8 @@ class Scheduler:
self.num_cpu_blocks = num_cpu_blocks self.num_cpu_blocks = num_cpu_blocks
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
# Instantiate the scheduling policy.
self.policy = PolicyFactory.get_policy(policy_name='fcfs')
# Create the block space manager. # Create the block space manager.
self.block_manager = BlockSpaceManager( self.block_manager = BlockSpaceManager(
block_size=block_size, block_size=block_size,
...@@ -32,158 +50,140 @@ class Scheduler: ...@@ -32,158 +50,140 @@ class Scheduler:
num_cpu_blocks=num_cpu_blocks, num_cpu_blocks=num_cpu_blocks,
) )
# Running sequence groups (FIFO). # Sequence groups in the WAITING state.
self.waiting: List[SequenceGroup] = []
# Sequence groups in the RUNNING state.
self.running: List[SequenceGroup] = [] self.running: List[SequenceGroup] = []
# Mapping: group_id -> num_steps. # Mapping: group_id -> num_steps.
self.num_steps: Dict[int, int] = {} self.num_steps: Dict[int, int] = {}
# Mapping: group_id -> sampling params. # Mapping: group_id -> sampling params.
self.sampling_params: Dict[int, SamplingParams] = {} self.sampling_params: Dict[int, SamplingParams] = {}
# Sequence groups in the SWAPPED state.
# Swapped sequence groups (LIFO).
self.swapped: List[SequenceGroup] = [] self.swapped: List[SequenceGroup] = []
# Pending sequence groups (FIFO).
self.pending: List[SequenceGroup] = []
def add_sequence_groups( def add_sequence_groups(
self, self,
sequence_groups: List[Tuple[SequenceGroup, SamplingParams]], seq_groups: List[Tuple[SequenceGroup, SamplingParams]],
) -> None: ) -> None:
# Add sequence groups to the pending queue. # Add sequence groups to the waiting queue.
for seq_group, sampling_params in sequence_groups: for seq_group, sampling_params in seq_groups:
self.pending.append(seq_group) self.waiting.append(seq_group)
self.sampling_params[seq_group.group_id] = sampling_params self.sampling_params[seq_group.group_id] = sampling_params
def _free_seq(self, seq: Sequence) -> None: def _schedule(
seq.status = SequenceStatus.FINISHED
self.block_manager.free(seq)
def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group)
for seq in seq_group.seqs:
seq.status = SequenceStatus.RUNNING
self.running.append(seq_group)
# FIXME(woosuk): Support interactive generation.
self.num_steps[seq_group.group_id] = 0
def _append(
self,
seq_group: SequenceGroup,
blocks_to_copy: Dict[int, List[int]],
) -> None:
for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED:
continue
ret = self.block_manager.append(seq)
if ret is not None:
src_block, dst_block = ret
if src_block in blocks_to_copy:
blocks_to_copy[src_block].append(dst_block)
else:
blocks_to_copy[src_block] = [dst_block]
def _swap_in(
self, self,
seq_group: SequenceGroup, ) -> Tuple[Dict[int, int], Dict[int, int], Dict[int, List[int]], List[int]]:
blocks_to_swap_in: Dict[int, int],
) -> None:
mapping = self.block_manager.swap_in(seq_group)
blocks_to_swap_in.update(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
seq.status = SequenceStatus.RUNNING
self.running.append(seq_group)
def _swap_out(
self,
seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int],
) -> None:
assert self.block_manager.can_swap_out(seq_group)
mapping = self.block_manager.swap_out(seq_group)
blocks_to_swap_out.update(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED
self.swapped.append(seq_group)
def step(self) -> List[SequenceGroup]:
# Blocks that need to be swapped or copied before model execution. # Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {} blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {} blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, List[int]] = {} blocks_to_copy: Dict[int, List[int]] = {}
# 1. Reserve new slots for the running sequences. # Fix the current time.
# NOTE: Here we implicitly assume FCFS scheduling. now = time.time()
# That is, the most recently added sequence group is the first
# to be swapped out. # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
victim_idx = len(self.running) - 1 # in order to minimize the preemption overheads.
for i, seq_group in enumerate(self.running): # Preemption happens only when there is no available slot to keep all
if i > victim_idx: # the sequence groups in the RUNNING state.
# The i-th sequence group has already been swapped out. # In this case, the policy is responsible for deciding which sequence
break # groups to preempt.
# OOM. Swap out the victim sequence groups. self.running = self.policy.sort_by_priority(now, self.running)
# Reserve new token slots for the running sequence groups.
running: List[SequenceGroup] = []
preempted: List[SequenceGroup] = []
while self.running:
seq_group = self.running.pop(0)
while not self.block_manager.can_append(seq_group): while not self.block_manager.can_append(seq_group):
victim_seq_group = self.running[victim_idx] if self.running:
self._swap_out(victim_seq_group, blocks_to_swap_out) # Preempt the lowest-priority sequence groups.
victim_idx -= 1 victim_seq_group = self.running.pop(-1)
if i > victim_idx: self._preempt(victim_seq_group, blocks_to_swap_out)
# No other sequence groups can be swapped out. preempted.append(victim_seq_group)
else:
# No other sequence groups can be preempted.
# Preempt the current sequence group.
self._preempt(seq_group, blocks_to_swap_out)
preempted.append(seq_group)
break break
else: else:
# Append new slots to the sequence group.
self._append(seq_group, blocks_to_copy) self._append(seq_group, blocks_to_copy)
self.running = self.running[:victim_idx + 1] running.append(seq_group)
self.running = running
# 2. Swap in the swapped sequences if possible. # Swap in the sequence groups in the SWAPPED state if possible.
# NOTE: Here we implicitly assume FCFS scheduling. self.swapped = self.policy.sort_by_priority(now, self.swapped)
# The swapped sequences are in LIFO order. while self.swapped:
for i, seq_group in enumerate(reversed(self.swapped)): seq_group = self.swapped[0]
if self.block_manager.can_swap_in(seq_group): # If the sequence group has been preempted in this step, stop.
self._swap_in(seq_group, blocks_to_swap_in) if seq_group in preempted:
self._append(seq_group, blocks_to_copy) break
else: # If the sequence group cannot be swapped in, stop.
# OOM. Stop swapping. if not self.block_manager.can_swap_in(seq_group):
self.swapped = self.swapped[:len(self.swapped) - i]
break break
else:
# All swapped sequences are swapped in.
self.swapped.clear()
# Ensure that swap-in and swap-out never happen at the same timestep. seq_group = self.swapped.pop(0)
if blocks_to_swap_in: self._swap_in(seq_group, blocks_to_swap_in)
assert not blocks_to_swap_out self._append(seq_group, blocks_to_copy)
self.running.append(seq_group)
num_batched_tokens = sum( num_batched_tokens = sum(
seq_group.num_seqs(status=SequenceStatus.RUNNING) seq_group.num_seqs(status=SequenceStatus.RUNNING)
for seq_group in self.running for seq_group in self.running
) )
# 3. Join new sequences if possible. # Join waiting sequences if possible.
# NOTE: Here we implicitly assume FCFS scheduling. prompt_group_ids: List[int] = []
# TODO(woosuk): Add a batching policy to control the batch size. # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
# prioritized over the sequence groups in the WAITING state.
# This is because we want to bound the amount of CPU memory taken by
# the swapped sequence groups.
if not self.swapped: if not self.swapped:
for i, seq_group in enumerate(self.pending): self.waiting = self.policy.sort_by_priority(now, self.waiting)
while self.waiting:
seq_group = self.waiting[0]
# If the sequence group has been preempted in this step, stop.
if seq_group in preempted:
break
# If the sequence group cannot be allocated, stop.
if not self.block_manager.can_allocate(seq_group):
break
# If the number of batched tokens exceeds the limit, stop.
num_prompt_tokens = seq_group.seqs[0].get_len() num_prompt_tokens = seq_group.seqs[0].get_len()
if self.block_manager.can_allocate(seq_group):
if (num_batched_tokens + num_prompt_tokens if (num_batched_tokens + num_prompt_tokens
<= self.max_num_batched_tokens): > self.max_num_batched_tokens):
break
seq_group = self.waiting.pop(0)
self._allocate(seq_group) self._allocate(seq_group)
self.running.append(seq_group)
num_batched_tokens += num_prompt_tokens num_batched_tokens += num_prompt_tokens
continue prompt_group_ids.append(seq_group.group_id)
self.pending = self.pending[i:] return (blocks_to_swap_in,
break blocks_to_swap_out,
else: blocks_to_copy,
self.pending.clear() prompt_group_ids)
# 4. Create input data structures. def step(self) -> List[SequenceGroup]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
scheduler_output = self._schedule()
blocks_to_swap_in = scheduler_output[0]
blocks_to_swap_out = scheduler_output[1]
blocks_to_copy = scheduler_output[2]
prompt_group_ids = scheduler_output[3]
# Create input data structures.
input_seq_groups: List[SequenceGroupInputs] = [] input_seq_groups: List[SequenceGroupInputs] = []
updated_seq_groups: List[SequenceGroup] = self.running.copy() updated_seq_groups: List[SequenceGroup] = self.running.copy()
for seq_group in self.running: for seq_group in self.running:
group_id = seq_group.group_id group_id = seq_group.group_id
num_steps = self.num_steps[group_id] is_prompt = group_id in prompt_group_ids
# NOTE(woosuk): We assume that the number of steps is 0
# for the prompt sequences.
is_prompt = num_steps == 0
input_tokens: Dict[int, List[int]] = {} input_tokens: Dict[int, List[int]] = {}
seq_logprobs: Dict[int, float] = {} seq_logprobs: Dict[int, float] = {}
...@@ -211,13 +211,15 @@ class Scheduler: ...@@ -211,13 +211,15 @@ class Scheduler:
) )
input_seq_groups.append(input_seq_group) input_seq_groups.append(input_seq_group)
# 5. Execute the first stage of the pipeline. # Execute the first stage of the pipeline.
if (input_seq_groups or blocks_to_swap_in or blocks_to_swap_out): if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
# Swap in and swap out should never happen at the same time.
assert not (blocks_to_swap_in and blocks_to_swap_out)
self.controllers[0].execute_stage( self.controllers[0].execute_stage(
input_seq_groups, input_seq_groups,
blocks_to_swap_in, blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy, blocks_to_copy=blocks_to_copy,
) )
return updated_seq_groups return updated_seq_groups
...@@ -276,7 +278,106 @@ class Scheduler: ...@@ -276,7 +278,106 @@ class Scheduler:
running.append(seq_group) running.append(seq_group)
self.running = running self.running = running
def _allocate(self, seq_group: SequenceGroup) -> None:
    """Allocate blocks for `seq_group` and mark its sequences RUNNING.

    This method does not add the group to `self.running`; the caller does
    that.
    """
    self.block_manager.allocate(seq_group)
    for member in seq_group.seqs:
        member.status = SequenceStatus.RUNNING
    # FIXME(woosuk): Support interactive generation.
    # An existing step count is kept as is; only a group seen for the first
    # time starts at 0. NOTE(review): presumably this preserves the count of
    # a group that is re-allocated after preemption by recomputation.
    self.num_steps.setdefault(seq_group.group_id, 0)
def _append(
    self,
    seq_group: SequenceGroup,
    blocks_to_copy: Dict[int, List[int]],
) -> None:
    """Call the block manager's `append` for each RUNNING sequence.

    Whenever `append` returns a `(src_block, dst_block)` pair, the
    destination is recorded under its source in `blocks_to_copy`, which is
    updated in place.
    """
    for running_seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
        copy_pair = self.block_manager.append(running_seq)
        if copy_pair is None:
            # Nothing to copy for this sequence.
            continue
        src_block, dst_block = copy_pair
        blocks_to_copy.setdefault(src_block, []).append(dst_block)
def _preempt(
    self,
    seq_group: SequenceGroup,
    blocks_to_swap_out: Dict[int, int],
    preemption_mode: Optional[PreemptionMode] = None,
) -> None:
    """Preempt `seq_group`, either by recomputation or by swapping.

    Args:
        seq_group: The sequence group to preempt.
        blocks_to_swap_out: Mapping passed on to the swap path, which
            updates it in place. Unused when recomputation is chosen.
        preemption_mode: The mode to use. When None, the mode is chosen
            from the number of RUNNING sequences in the group.

    Raises:
        AssertionError: If `preemption_mode` is neither
            `PreemptionMode.RECOMPUTE` nor `PreemptionMode.SWAP`.
    """
    # If preemption mode is not specified, we determine the mode as follows:
    # We use recomputation by default since it incurs lower overhead than
    # swapping. However, when the sequence group has multiple sequences
    # (e.g., beam search), recomputation is not supported. In such a case,
    # we use swapping instead.
    # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
    # As swapped sequences are prioritized over waiting sequences,
    # sequence groups with multiple sequences are implicitly prioritized
    # over sequence groups with a single sequence.
    # TODO(woosuk): Support recomputation for sequence groups with multiple
    # sequences. This may require a more sophisticated CUDA kernel.
    if preemption_mode is None:
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        if len(seqs) == 1:
            preemption_mode = PreemptionMode.RECOMPUTE
        else:
            preemption_mode = PreemptionMode.SWAP

    if preemption_mode == PreemptionMode.RECOMPUTE:
        self._preempt_by_recompute(seq_group)
    elif preemption_mode == PreemptionMode.SWAP:
        self._preempt_by_swap(seq_group, blocks_to_swap_out)
    else:
        # Raise explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would let an invalid mode fall
        # through without preempting anything.
        raise AssertionError('Invalid preemption mode.')
def _preempt_by_recompute(
    self,
    seq_group: SequenceGroup,
) -> None:
    """Preempt `seq_group` by discarding its blocks.

    The group's RUNNING sequence is set back to WAITING, its blocks are
    freed through the block manager, and the group is appended to
    `self.waiting`. Exactly one RUNNING sequence is expected.
    """
    running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
    assert len(running_seqs) == 1
    for running_seq in running_seqs:
        running_seq.status = SequenceStatus.WAITING
        self.block_manager.free(running_seq)
    self.waiting.append(seq_group)
def _preempt_by_swap(
    self,
    seq_group: SequenceGroup,
    blocks_to_swap_out: Dict[int, int],
) -> None:
    """Preempt `seq_group` by swapping it out.

    RUNNING sequences are marked SWAPPED, the group is swapped out via
    `_swap_out` (which updates `blocks_to_swap_out` in place), and the
    group is appended to `self.swapped`.
    """
    # The statuses are flipped before `_swap_out` runs, so its own
    # RUNNING -> SWAPPED pass finds nothing left to change.
    for running_seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
        running_seq.status = SequenceStatus.SWAPPED
    self._swap_out(seq_group, blocks_to_swap_out)
    self.swapped.append(seq_group)
def _free_seq(self, seq: Sequence) -> None:
    """Mark `seq` as FINISHED and free it through the block manager."""
    # The status is updated before the sequence is handed to the block manager.
    seq.status = SequenceStatus.FINISHED
    self.block_manager.free(seq)
def _free_seq_group(self, seq_group: SequenceGroup) -> None: def _free_seq_group(self, seq_group: SequenceGroup) -> None:
group_id = seq_group.group_id group_id = seq_group.group_id
del self.num_steps[group_id] del self.num_steps[group_id]
del self.sampling_params[group_id] del self.sampling_params[group_id]
def _swap_in(
    self,
    seq_group: SequenceGroup,
    blocks_to_swap_in: Dict[int, int],
) -> None:
    """Swap `seq_group` back in and mark its SWAPPED sequences RUNNING.

    The mapping returned by the block manager's `swap_in` is merged into
    `blocks_to_swap_in` in place. The group is not added to `self.running`
    here.
    """
    blocks_to_swap_in.update(self.block_manager.swap_in(seq_group))
    for swapped_seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
        swapped_seq.status = SequenceStatus.RUNNING
def _swap_out(
    self,
    seq_group: SequenceGroup,
    blocks_to_swap_out: Dict[int, int],
) -> None:
    """Swap `seq_group` out and mark its RUNNING sequences SWAPPED.

    The block manager must report that the group can be swapped out; this
    is checked with an assertion. The mapping returned by the block
    manager's `swap_out` is merged into `blocks_to_swap_out` in place. The
    group is not added to `self.swapped` here.
    """
    assert self.block_manager.can_swap_out(seq_group)
    blocks_to_swap_out.update(self.block_manager.swap_out(seq_group))
    for running_seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
        running_seq.status = SequenceStatus.SWAPPED
...@@ -10,6 +10,7 @@ from cacheflow.worker.controller import Controller, DeviceID ...@@ -10,6 +10,7 @@ from cacheflow.worker.controller import Controller, DeviceID
from cacheflow.sequence import SequenceGroup from cacheflow.sequence import SequenceGroup
from cacheflow.sampling_params import SamplingParams from cacheflow.sampling_params import SamplingParams
class Server: class Server:
def __init__( def __init__(
self, self,
...@@ -91,7 +92,7 @@ class Server: ...@@ -91,7 +92,7 @@ class Server:
return self.scheduler.step() return self.scheduler.step()
def has_unfinished_requests(self): def has_unfinished_requests(self):
return (self.scheduler.pending or self.scheduler.running or return (self.scheduler.waiting or self.scheduler.running or
self.scheduler.swapped) self.scheduler.swapped)
......
import time
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
from transformers import AutoTokenizer from transformers import AutoTokenizer
...@@ -39,6 +40,7 @@ class SimpleFrontend: ...@@ -39,6 +40,7 @@ class SimpleFrontend:
token_ids: List[int], token_ids: List[int],
sampling_params: SamplingParams, sampling_params: SamplingParams,
) -> None: ) -> None:
arrival_time = time.time()
seqs: List[Sequence] = [] seqs: List[Sequence] = []
for _ in range(sampling_params.n): for _ in range(sampling_params.n):
seq_id = next(self.seq_counter) seq_id = next(self.seq_counter)
...@@ -46,7 +48,7 @@ class SimpleFrontend: ...@@ -46,7 +48,7 @@ class SimpleFrontend:
seqs.append(seq) seqs.append(seq)
group_id = next(self.seq_group_counter) group_id = next(self.seq_group_counter)
seq_group = SequenceGroup(group_id, seqs) seq_group = SequenceGroup(group_id, seqs, arrival_time)
self.inputs.append((seq_group, sampling_params)) self.inputs.append((seq_group, sampling_params))
def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]: def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:
......
...@@ -7,7 +7,7 @@ from cacheflow.sampling_params import SamplingParams ...@@ -7,7 +7,7 @@ from cacheflow.sampling_params import SamplingParams
class SequenceStatus(enum.Enum): class SequenceStatus(enum.Enum):
PENDING = enum.auto() WAITING = enum.auto()
RUNNING = enum.auto() RUNNING = enum.auto()
SWAPPED = enum.auto() SWAPPED = enum.auto()
FINISHED = enum.auto() FINISHED = enum.auto()
...@@ -28,7 +28,7 @@ class Sequence: ...@@ -28,7 +28,7 @@ class Sequence:
# Initialize the logical token blocks with the given token ids. # Initialize the logical token blocks with the given token ids.
self.add(token_ids) self.add(token_ids)
self.status = SequenceStatus.PENDING self.status = SequenceStatus.WAITING
self.output_logprobs: List[Dict[int, float]] = [] self.output_logprobs: List[Dict[int, float]] = []
self.cumulative_logprobs = 0.0 self.cumulative_logprobs = 0.0
...@@ -88,9 +88,11 @@ class SequenceGroup: ...@@ -88,9 +88,11 @@ class SequenceGroup:
self, self,
group_id: int, group_id: int,
seqs: List[Sequence], seqs: List[Sequence],
arrival_time: float,
) -> None: ) -> None:
self.group_id = group_id self.group_id = group_id
self.seqs = seqs self.seqs = seqs
self.arrival_time = arrival_time
def get_seqs( def get_seqs(
self, self,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment