import copy
from typing import Dict, List, Tuple

from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.frontend import Frontend
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import SequenceStatus

10
11
_MAX_NUM_BATCHED_TOKENS = 2048

Woosuk Kwon's avatar
Woosuk Kwon committed
12
13
14

class Scheduler:
    """Per-step scheduler that moves sequence groups between three queues.

    * ``pending`` (FIFO): groups fetched from the frontend that have not been
      allocated blocks yet.
    * ``running`` (FIFO): groups holding blocks in the block manager; their
      unfinished sequences are fed to the model on every ``step``.
    * ``swapped`` (LIFO): groups evicted by ``step`` because the block manager
      could not reserve a slot for a running group.

    ``step`` decides what runs and launches the first pipeline stage;
    ``post_step`` consumes the sampled tokens and retires finished groups.
    """

    def __init__(
        self,
        frontend: Frontend,
        controllers: List,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        """Create an empty scheduler.

        Args:
            frontend: Source of new requests (``get_inputs``) and sink for
                finished groups (``print_response``).
            controllers: Pipeline stage controllers; only ``controllers[0]``
                is invoked directly by this class.
            block_size: Forwarded to ``BlockSpaceManager`` (presumably the
                number of tokens per block).
            num_gpu_blocks: Number of GPU blocks given to the block manager.
            num_cpu_blocks: Number of CPU blocks given to the block manager.
        """
        self.frontend = frontend
        self.controllers = controllers
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
        self.num_cpu_blocks = num_cpu_blocks

        # Create the block space manager.
        self.block_manager = BlockSpaceManager(
            block_size=block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
        )

        # Running sequence groups (FIFO).
        self.running: List[SequenceGroup] = []
        # Mapping: group_id -> num_steps.
        self.num_steps: Dict[int, int] = {}
        # Mapping: group_id -> sampling params.
        self.sampling_params: Dict[int, SamplingParams] = {}

        # Swapped sequence groups (LIFO).
        self.swapped: List[SequenceGroup] = []
        # Pending sequence groups (FIFO).
        self.pending: List[SequenceGroup] = []

    def _fetch_inputs(self) -> None:
        """Queue new ``(seq_group, sampling_params)`` pairs from the frontend."""
        inputs = self.frontend.get_inputs()
        for seq_group, sampling_params in inputs:
            self.pending.append(seq_group)
            self.sampling_params[seq_group.group_id] = sampling_params

    def _free_seq(self, seq: Sequence) -> None:
        """Mark ``seq`` finished and release its blocks."""
        seq.status = SequenceStatus.FINISHED
        self.block_manager.free(seq)

    def _allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate blocks for a pending group and move it to ``running``.

        Its step counter starts at 0, which ``step`` treats as "prompt phase".
        """
        self.block_manager.allocate(seq_group)
        for seq in seq_group.seqs:
            seq.status = SequenceStatus.RUNNING
        self.running.append(seq_group)
        # FIXME(woosuk): Support interactive generation.
        self.num_steps[seq_group.group_id] = 0

    def _append(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, int],
    ) -> None:
        """Reserve a slot for the next token of each unfinished sequence.

        Whenever ``block_manager.append`` returns a ``(src, dst)`` block pair,
        it is recorded in ``blocks_to_copy`` so the copy happens before model
        execution (presumably a copy-on-write of a shared block).
        """
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.FINISHED:
                continue
            ret = self.block_manager.append(seq)
            if ret is not None:
                src_block, dst_block = ret
                blocks_to_copy[src_block] = dst_block

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        """Swap ``seq_group`` back in and append it to ``running``.

        The block mapping returned by the block manager is merged into
        ``blocks_to_swap_in``. Only SWAPPED sequences flip to RUNNING;
        finished sequences keep their status.
        """
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.SWAPPED:
                seq.status = SequenceStatus.RUNNING
        self.running.append(seq_group)

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Swap ``seq_group`` out and push it onto ``swapped``.

        The caller must ensure the block manager can swap the group out (an
        ``assert`` checks this). Only RUNNING sequences flip to SWAPPED.
        """
        assert self.block_manager.can_swap_out(seq_group)
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
        for seq in seq_group.seqs:
            if seq.status == SequenceStatus.RUNNING:
                seq.status = SequenceStatus.SWAPPED
        self.swapped.append(seq_group)

    def step(self) -> None:
        """Schedule one iteration and launch the first pipeline stage.

        The method runs five phases in order:

        1. Reserve a slot in every running group, evicting the most recently
           added groups when the block manager runs out of room.
        2. If nothing was evicted in phase 1, swap groups back in (LIFO).
        3. If no group remains swapped out, admit pending groups. Admission
           is limited by free blocks and by ``_MAX_NUM_BATCHED_TOKENS``.
        4. Build the per-sequence inputs for all running groups.
        5. Call ``execute_stage`` on ``controllers[0]``.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, int] = {}

        # 1. Reserve new slots for the running sequences.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # That is, the most recently added sequence group is the first
        # to be swapped out.
        victim_idx = len(self.running) - 1
        for i, seq_group in enumerate(self.running):
            if i > victim_idx:
                # The i-th sequence group has already been swapped out.
                break
            # OOM. Swap out the victim sequence groups.
            while not self.block_manager.can_append(seq_group):
                victim_seq_group = self.running[victim_idx]
                self._swap_out(victim_seq_group, blocks_to_swap_out)
                victim_idx -= 1
                if i > victim_idx:
                    # No other sequence groups can be swapped out
                    # (seq_group itself was just evicted).
                    break
            else:
                # while/else: only reached when the loop exits because a
                # slot became available, not via the `break` above.
                self._append(seq_group, blocks_to_copy)
        self.running = self.running[:victim_idx + 1]

        # 2. Swap in the swapped sequences if possible.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # The swapped sequences are in LIFO order.
        # Only attempt this when phase 1 swapped nothing out. Otherwise
        # memory is already tight, and the groups just evicted sit at the end
        # of `self.swapped`, so they would be tried first. A swap-in here
        # could also trip the swap-in/swap-out exclusivity assertion below.
        if not blocks_to_swap_out:
            for i, seq_group in enumerate(reversed(self.swapped)):
                if self.block_manager.can_swap_in(seq_group):
                    self._swap_in(seq_group, blocks_to_swap_in)
                    self._append(seq_group, blocks_to_copy)
                else:
                    # OOM. Stop swapping; keep the groups not swapped in.
                    self.swapped = self.swapped[:len(self.swapped) - i]
                    break
            else:
                # All swapped sequences are swapped in.
                self.swapped.clear()

        # Token budget already consumed: one token per RUNNING sequence in
        # the groups scheduled so far.
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running
        )

        # 3. Join new sequences if possible.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # TODO(woosuk): Add a batching policy to control the batch size.
        # NOTE(review): a prompt longer than _MAX_NUM_BATCHED_TOKENS is never
        # admitted and blocks every pending group queued behind it.
        if not self.swapped:
            # FIXME(woosuk): Acquire a lock to protect pending.
            self._fetch_inputs()
            for i, seq_group in enumerate(self.pending):
                num_prompt_tokens = seq_group.seqs[0].get_len()
                if self.block_manager.can_allocate(seq_group):
                    if (num_batched_tokens + num_prompt_tokens
                        <= _MAX_NUM_BATCHED_TOKENS):
                        self._allocate(seq_group)
                        num_batched_tokens += num_prompt_tokens
                        continue

                # Out of blocks or token budget: keep this group and the
                # rest pending (FIFO) and stop admitting.
                self.pending = self.pending[i:]
                break
            else:
                self.pending.clear()

        # Ensure that swap-in and swap-out never happen at the same timestep.
        if blocks_to_swap_in:
            assert not blocks_to_swap_out

        # 4. Create input data structures.
        prompt_tokens: Dict[int, List[int]] = {}
        generation_tokens: Dict[int, int] = {}
        context_lens: Dict[int, int] = {}
        block_tables: Dict[int, List[int]] = {}
        for seq_group in self.running:
            group_id = seq_group.group_id
            num_steps = self.num_steps[group_id]
            # NOTE(woosuk): We assume that the number of steps is 0
            # for the prompt sequences.
            is_prompt = num_steps == 0
            for seq in seq_group.seqs:
                if seq.status != SequenceStatus.RUNNING:
                    continue

                seq_id = seq.seq_id
                block_tables[seq_id] = self.block_manager.get_block_table(seq)
                if is_prompt:
                    # Prompt phase: send every token.
                    prompt_tokens[seq_id] = seq.get_token_ids()
                else:
                    # Generation phase: send the last token and the length.
                    generation_tokens[seq_id] = seq.get_token_ids()[-1]
                    context_lens[seq_id] = seq.get_len()

        # 5. Execute the first stage of the pipeline.
        self.controllers[0].execute_stage(
            prompt_tokens,
            generation_tokens,
            context_lens,
            block_tables,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
        )

    def post_step(
        self,
        next_tokens: Dict[int, Tuple[int, int]],
    ) -> None:
        """Apply the tokens sampled in the last step and retire finished groups.

        Args:
            next_tokens: Maps the ``seq_id`` of every unfinished running
                sequence to ``(parent_seq_id, next_token)``. When
                ``parent_seq_id`` differs from the sequence's own id (beam
                search), the sequence is first replaced by a fork of that
                parent, and then ``next_token`` is appended.
        """
        for seq_group in self.running:
            group_id = seq_group.group_id
            self.num_steps[group_id] += 1
            sampling_params = self.sampling_params[group_id]
            stop_token_ids = sampling_params.stop_token_ids
            max_num_steps = sampling_params.max_num_steps

            # Pass 1: apply all beam-search forks before touching any token.
            # With a single interleaved pass, a child forked from a parent
            # handled earlier in the loop would copy a parent that had
            # already been extended with its own new token. It could also
            # fork from a parent whose blocks were just freed on a stop token.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                parent_seq_id, _ = next_tokens[seq.seq_id]
                if seq.seq_id != parent_seq_id:
                    # The sequence is a fork of the parent sequence.
                    # Free the current sequence.
                    self.block_manager.free(seq)
                    # Fork the parent sequence. Deep-copy the logical blocks:
                    # list.copy() would leave parent and child sharing the
                    # same block objects, so (if blocks are mutated on
                    # append, presumably) one sequence's tokens would leak
                    # into the other.
                    parent_seq = seq_group.find(parent_seq_id)
                    seq.logical_token_blocks = copy.deepcopy(
                        parent_seq.logical_token_blocks)
                    self.block_manager.fork(parent_seq, seq)

            # Pass 2: append the sampled tokens and check stop conditions.
            for seq in seq_group.seqs:
                if seq.status == SequenceStatus.FINISHED:
                    continue

                _, next_token = next_tokens[seq.seq_id]
                # Append a new token to the sequence.
                seq.append([next_token])

                # Check if the sequence has generated a stop token.
                if next_token in stop_token_ids:
                    self._free_seq(seq)
                    continue

                # Check if the sequence has reached the maximum number of steps.
                if self.num_steps[group_id] == max_num_steps:
                    self._free_seq(seq)

        # Retire groups whose sequences have all finished; keep the others
        # in their original (FIFO) order.
        running: List[SequenceGroup] = []
        for seq_group in self.running:
            if seq_group.is_finished():
                self._return(seq_group)
            else:
                running.append(seq_group)
        self.running = running

    def _return(self, seq_group: SequenceGroup) -> None:
        """Drop per-group bookkeeping and hand the finished group to the frontend."""
        group_id = seq_group.group_id
        del self.num_steps[group_id]
        del self.sampling_params[group_id]
        self.frontend.print_response(seq_group)