Commit 53f70e73 authored by Woosuk Kwon
Browse files

Reduce the number of states in scheduler

parent 762fd1c3
...@@ -43,11 +43,6 @@ class Scheduler: ...@@ -43,11 +43,6 @@ class Scheduler:
# Pending sequence groups (FIFO). # Pending sequence groups (FIFO).
self.pending: List[SequenceGroup] = [] self.pending: List[SequenceGroup] = []
# Blocks that need to be swapped or copied before model execution.
self.blocks_to_swap_in: Dict[int, int] = {}
self.blocks_to_swap_out: Dict[int, int] = {}
self.blocks_to_copy: Dict[int, int] = {}
def _free_seq(self, seq: Sequence) -> None: def _free_seq(self, seq: Sequence) -> None:
seq.status = SequenceStatus.FINISHED seq.status = SequenceStatus.FINISHED
self.block_manager.free(seq) self.block_manager.free(seq)
...@@ -57,36 +52,53 @@ class Scheduler: ...@@ -57,36 +52,53 @@ class Scheduler:
for seq in seq_group.seqs: for seq in seq_group.seqs:
seq.status = SequenceStatus.RUNNING seq.status = SequenceStatus.RUNNING
self.running.append(seq_group) self.running.append(seq_group)
# FIXME # FIXME(woosuk): Support interactive generation.
self.num_steps[seq_group.group_id] = 0 self.num_steps[seq_group.group_id] = 0
def _append(self, seq_group: SequenceGroup) -> None: def _append(
self,
seq_group: SequenceGroup,
blocks_to_copy: Dict[int, int],
) -> None:
for seq in seq_group.seqs: for seq in seq_group.seqs:
if seq.status == SequenceStatus.FINISHED: if seq.status == SequenceStatus.FINISHED:
continue continue
ret = self.block_manager.append(seq) ret = self.block_manager.append(seq)
if ret is not None: if ret is not None:
src_block, dst_block = ret src_block, dst_block = ret
self.blocks_to_copy[src_block] = dst_block blocks_to_copy[src_block] = dst_block
def _swap_in(self, seq_group: SequenceGroup) -> None: def _swap_in(
self,
seq_group: SequenceGroup,
blocks_to_swap_in: Dict[int, int],
) -> None:
mapping = self.block_manager.swap_in(seq_group) mapping = self.block_manager.swap_in(seq_group)
self.blocks_to_swap_in.update(mapping) blocks_to_swap_in.update(mapping)
for seq in seq_group.seqs: for seq in seq_group.seqs:
if seq.status == SequenceStatus.SWAPPED: if seq.status == SequenceStatus.SWAPPED:
seq.status = SequenceStatus.RUNNING seq.status = SequenceStatus.RUNNING
self.running.append(seq_group) self.running.append(seq_group)
def _swap_out(self, seq_group: SequenceGroup) -> None: def _swap_out(
self,
seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int],
) -> None:
assert self.block_manager.can_swap_out(seq_group) assert self.block_manager.can_swap_out(seq_group)
mapping = self.block_manager.swap_out(seq_group) mapping = self.block_manager.swap_out(seq_group)
self.blocks_to_swap_out.update(mapping) blocks_to_swap_out.update(mapping)
for seq in seq_group.seqs: for seq in seq_group.seqs:
if seq.status == SequenceStatus.RUNNING: if seq.status == SequenceStatus.RUNNING:
seq.status = SequenceStatus.SWAPPED seq.status = SequenceStatus.SWAPPED
self.swapped.append(seq_group) self.swapped.append(seq_group)
def prepare(self) -> None: def pre_step(self) -> None:
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_copy: Dict[int, int] = {}
# 1. Prepare new slots for the running sequences. # 1. Prepare new slots for the running sequences.
# NOTE: Here we implicitly assume FCFS scheduling. # NOTE: Here we implicitly assume FCFS scheduling.
# That is, the most recently added sequence group is the first # That is, the most recently added sequence group is the first
...@@ -99,13 +111,13 @@ class Scheduler: ...@@ -99,13 +111,13 @@ class Scheduler:
# OOM. Swap out the victim sequence groups. # OOM. Swap out the victim sequence groups.
while not self.block_manager.can_append(seq_group): while not self.block_manager.can_append(seq_group):
victim_seq_group = self.running[victim_idx] victim_seq_group = self.running[victim_idx]
self._swap_out(victim_seq_group) self._swap_out(victim_seq_group, blocks_to_swap_out)
victim_idx -= 1 victim_idx -= 1
if i > victim_idx: if i > victim_idx:
# No other sequence groups can be swapped out. # No other sequence groups can be swapped out.
break break
else: else:
self._append(seq_group) self._append(seq_group, blocks_to_copy)
self.running = self.running[:victim_idx + 1] self.running = self.running[:victim_idx + 1]
# 2. Swap in the swapped sequences if possible. # 2. Swap in the swapped sequences if possible.
...@@ -113,8 +125,8 @@ class Scheduler: ...@@ -113,8 +125,8 @@ class Scheduler:
# The swapped sequences are in LIFO order. # The swapped sequences are in LIFO order.
for i, seq_group in enumerate(reversed(self.swapped)): for i, seq_group in enumerate(reversed(self.swapped)):
if self.block_manager.can_swap_in(seq_group): if self.block_manager.can_swap_in(seq_group):
self._swap_in(seq_group) self._swap_in(seq_group, blocks_to_swap_in)
self._append(seq_group) self._append(seq_group, blocks_to_copy)
else: else:
# OOM. Stop swapping. # OOM. Stop swapping.
self.swapped = self.swapped[:len(self.swapped) - i] self.swapped = self.swapped[:len(self.swapped) - i]
...@@ -147,10 +159,18 @@ class Scheduler: ...@@ -147,10 +159,18 @@ class Scheduler:
else: else:
self.pending.clear() self.pending.clear()
def step(self) -> None: # Execute step.
# Ensure that either swap-in or swap-out is performed. self.step(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
if self.blocks_to_swap_in:
assert not self.blocks_to_swap_out def step(
self,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, int],
) -> None:
# Ensure that swap-in and swap-out never happen at the same timestep.
if blocks_to_swap_in:
assert not blocks_to_swap_out
# Create input data structures. # Create input data structures.
prompt_tokens: Dict[int, List[int]] = {} prompt_tokens: Dict[int, List[int]] = {}
...@@ -181,16 +201,11 @@ class Scheduler: ...@@ -181,16 +201,11 @@ class Scheduler:
generation_tokens, generation_tokens,
context_lens, context_lens,
block_tables, block_tables,
self.blocks_to_swap_in.copy(), blocks_to_swap_in,
self.blocks_to_swap_out.copy(), blocks_to_swap_out,
self.blocks_to_copy.copy(), blocks_to_copy,
) )
# Clear for the next step.
self.blocks_to_swap_in.clear()
self.blocks_to_swap_out.clear()
self.blocks_to_copy.clear()
def post_step( def post_step(
self, self,
next_tokens: Dict[int, Tuple[int, int]], next_tokens: Dict[int, Tuple[int, int]],
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment