from typing import Dict, List, Tuple

from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceStatus

_MAX_NUM_BATCHED_TOKENS = 2048


class Scheduler:
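    # Schedules sequence groups across three queues (pending, running, and
    # swapped), delegates block allocation, swapping, and copying to the
    # BlockSpaceManager, and dispatches each batch to the first pipeline
    # stage via controllers[0].
    #
    # A typical driver loop (a sketch; the frontend code that enqueues
    # requests is not part of this file) would be:
    #   scheduler.pending.append(seq_group)
    #   scheduler.prepare()
    #   scheduler.step()
    #   scheduler.post_step(next_tokens)  # tokens sampled by the workers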

    def __init__(
        self,
        controllers: List,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        self.controllers = controllers
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks

        # Create the block space manager.
        self.block_manager = BlockSpaceManager(
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
        )

        # Running sequence groups (FIFO).
        self.running: List[SequenceGroup] = []
        # Mapping: group_id -> num_steps.
        self.num_steps: Dict[int, int] = {}
        # Mapping: group_id -> max_num_steps.
        self.max_num_steps: Dict[int, int] = {}
        # Mapping: group_id -> stop_token_ids.
        self.stop_token_ids: Dict[int, List[int]] = {}

        # Swapped sequence groups (LIFO).
        self.swapped: List[SequenceGroup] = []
        # Pending sequence groups (FIFO).
        self.pending: List[SequenceGroup] = []

        # Blocks that need to be swapped or copied before model execution.
        self.blocks_to_swap_in: Dict[int, int] = {}
        self.blocks_to_swap_out: Dict[int, int] = {}
        self.blocks_to_copy: Dict[int, int] = {}

    def _free_seq(self, seq: Sequence) -> None:
        seq.status = SequenceStatus.FINISHED
        self.block_manager.free(seq)

    def _allocate(self, seq_group: SequenceGroup) -> None:
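        # Allocate blocks for a newly admitted sequence group, mark all of its
        # sequences as RUNNING, and start its step counter at zero.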
        self.block_manager.allocate(seq_group)
        for seq in seq_group.seqs:
            seq.status = SequenceStatus.RUNNING
        self.running.append(seq_group)
        # FIXME
        self.num_steps[seq_group.group_id] = 0

    def _append(self, seq_group: SequenceGroup) -> None:
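        # Reserve a slot for the next token of every unfinished sequence; if
        # the block manager returns a (src, dst) block pair, record that the
        # source block must be copied to the destination block before the next
        # model execution.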
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.FINISHED:
                continue
            ret = self.block_manager.append(seq)
            if ret is not None:
                src_block, dst_block = ret
                self.blocks_to_copy[src_block] = dst_block

    def _swap_in(self, seq_group: SequenceGroup) -> None:
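        # Bring a swapped-out sequence group back into GPU blocks: record the
        # swap-in mapping, mark its sequences RUNNING, and put the group back
        # on the running queue.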
        mapping = self.block_manager.swap_in(seq_group)
        self.blocks_to_swap_in.update(mapping)
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.SWAPPED:
                seq.status = SequenceStatus.RUNNING
        self.running.append(seq_group)

    def _swap_out(self, seq_group: SequenceGroup) -> None:
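        # Evict a running sequence group to CPU blocks: record the swap-out
        # mapping, mark its sequences SWAPPED, and push the group onto the
        # swapped queue.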
        assert self.block_manager.can_swap_out(seq_group)
        mapping = self.block_manager.swap_out(seq_group)
        self.blocks_to_swap_out.update(mapping)
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.RUNNING:
                seq.status = SequenceStatus.SWAPPED
        self.swapped.append(seq_group)

    def prepare(self) -> None:
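        # Plan the next batch in three phases: (1) reserve slots for the
        # running groups, swapping out victims on OOM; (2) swap groups back in
        # when space allows; (3) admit pending groups while staying under the
        # _MAX_NUM_BATCHED_TOKENS budget.
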
        # 1. Prepare new slots for the running sequences.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # That is, the most recently added sequence group is the first
        # to be swapped out.
        victim_idx = len(self.running) - 1
        for i, seq_group in enumerate(self.running):
            if i > victim_idx:
                # The i-th sequence group has already been swapped out.
                break
            # OOM. Swap out the victim sequence groups.
            while not self.block_manager.can_append(seq_group):
                victim_seq_group = self.running[victim_idx]
                self._swap_out(victim_seq_group)
                victim_idx -= 1
                if i > victim_idx:
                    # No other sequence groups can be swapped out.
                    break
            else:
                self._append(seq_group)
        self.running = self.running[:victim_idx + 1]

        # 2. Swap in the swapped sequences if possible.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # The swapped sequences are in LIFO order.
        for i, seq_group in enumerate(reversed(self.swapped)):
            if self.block_manager.can_swap_in(seq_group):
                self._swap_in(seq_group)
                self._append(seq_group)
            else:
                # OOM. Stop swapping.
                self.swapped = self.swapped[:len(self.swapped) - i]
                break
        else:
            # All swapped sequences are swapped in.
            self.swapped.clear()

        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running
        )

        # 3. Join new sequences if possible.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # TODO(woosuk): Add a batching policy to control the batch size.
        if not self.swapped:
            # FIXME(woosuk): Acquire a lock to protect pending.
            for i, seq_group in enumerate(self.pending):
                num_prompt_tokens = seq_group.seqs[0].get_len()
                if self.block_manager.can_allocate(seq_group):
                    if (num_batched_tokens + num_prompt_tokens
                        <= _MAX_NUM_BATCHED_TOKENS):
                        self._allocate(seq_group)
                        num_batched_tokens += num_prompt_tokens
                        continue

                self.pending = self.pending[i:]
                break
            else:
                self.pending.clear()

    def step(self) -> None:
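        # Build the per-sequence inputs (prompt tokens, last generated token,
        # context lengths, block tables), hand them to the first pipeline
        # stage together with the pending swap/copy operations, and then clear
        # those maps for the next step.
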
        # Make sure that swap-in and swap-out are not scheduled in the same step.
        if self.blocks_to_swap_in:
            assert not self.blocks_to_swap_out

        # Create input data structures.
        prompt_tokens: Dict[int, List[int]] = {}
        generation_tokens: Dict[int, int] = {}
        context_lens: Dict[int, int] = {}
        block_tables: Dict[int, List[int]] = {}
        for seq_group in self.running:
            group_id = seq_group.group_id
            num_steps = self.num_steps[group_id]
            # NOTE(woosuk): We assume that the number of steps is 0
            # for the prompt sequences.
            is_prompt = num_steps == 0
            for seq in seq_group.seqs:
                if seq.status != SequenceStatus.RUNNING:
                    continue

                seq_id = seq.seq_id
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                if is_prompt:
                    prompt_tokens[seq_id] = seq.get_token_ids()
                else:
                    generation_tokens[seq_id] = seq.get_token_ids()[-1]
                    context_lens[seq_id] = seq.get_len()

        # Execute the first stage of the pipeline.
        self.controllers[0].execute_stage(
            prompt_tokens,
            generation_tokens,
            context_lens,
            block_tables,
            self.blocks_to_swap_in.copy(),
            self.blocks_to_swap_out.copy(),
            self.blocks_to_copy.copy(),
        )

        # Clear for the next step.
        self.blocks_to_swap_in.clear()
        self.blocks_to_swap_out.clear()
        self.blocks_to_copy.clear()

    def post_step(
        self,
        next_tokens: Dict[int, Tuple[int, int]],
    ) -> None:
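        # Consume the tokens sampled for this step: handle beam-search forks,
        # append the new token to each sequence, finish sequences that hit a
        # stop token or their max step count, and drop fully finished groups.
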
        # Update the running sequences and free blocks.
        for seq_group in self.running:
            group_id = seq_group.group_id
            self.num_steps[group_id] += 1
            stop_token_ids = self.stop_token_ids[group_id]

            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                parent_seq_id, next_token = next_tokens[seq.seq_id]
                if seq.seq_id != parent_seq_id:
                    # The sequence is a fork of the parent sequence (beam search).
                    # Free the current sequence.
                    self.block_manager.free(seq)
                    # Fork the parent sequence.
                    parent_seq = seq_group.find(parent_seq_id)
                    seq.logical_token_blocks = parent_seq.logical_token_blocks.copy()
                    self.block_manager.fork(parent_seq, seq)

                # Append a new token to the sequence.
                seq.append(next_token)

                # Check if the sequence has generated a stop token.
                if next_token in stop_token_ids:
                    self._free_seq(seq)
                    continue

                # Check if the sequence has reached the maximum number of steps.
                if self.num_steps[group_id] == self.max_num_steps[group_id]:
                    self._free_seq(seq)
                    continue

        # Update the running sequences.
        running: List[SequenceGroup] = []
        for seq_group in self.running:
            if all(seq.status == SequenceStatus.FINISHED for seq in seq_group.seqs):
                del self.num_steps[seq_group.group_id]
                del self.max_num_steps[seq_group.group_id]
                del self.stop_token_ids[seq_group.group_id]
                # TODO: Return the seq_group to the client.
            else:
                running.append(seq_group)
        self.running = running