import enum
import time
from typing import Dict, List, Optional, Tuple

from cacheflow.config import CacheConfig, SchedulerConfig
from cacheflow.core.block_manager import BlockSpaceManager
from cacheflow.core.policy import PolicyFactory
from cacheflow.logger import init_logger
from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
                                SequenceGroupMetadata, SequenceOutputs,
                                SequenceStatus)

logger = init_logger(__name__)

_LOGGING_INTERVAL_SEC = 10


class PreemptionMode(enum.Enum):
    """Preemption modes.

    1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
    and swap them back in when the sequences are resumed.
    2. Recomputation: Discard the blocks of the preempted sequences and
    recompute them when the sequences are resumed, treating the sequences as
    new prompts.
    """
    SWAP = enum.auto()
    RECOMPUTE = enum.auto()


class SchedulerOutputs:
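    """Block swap and copy operations produced by one scheduling step."""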

    def __init__(
        self,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        self.blocks_to_swap_in = blocks_to_swap_in
        self.blocks_to_swap_out = blocks_to_swap_out
        self.blocks_to_copy = blocks_to_copy
        # Swap in and swap out should never happen at the same time.
        assert not (blocks_to_swap_in and blocks_to_swap_out)

    def is_empty(self) -> bool:
        return (not self.blocks_to_swap_in
                and not self.blocks_to_swap_out
                and not self.blocks_to_copy)


class Scheduler:
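    """Manages the WAITING, RUNNING, and SWAPPED sequence group queues and
    decides which sequence groups run in each step."""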

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        log_stats: bool,
    ) -> None:
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        self.log_stats = log_stats

        # Instantiate the scheduling policy.
        self.policy = PolicyFactory.get_policy(policy_name='fcfs')
        # Create the block space manager.
        self.block_manager = BlockSpaceManager(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=self.cache_config.num_gpu_blocks,
            num_cpu_blocks=self.cache_config.num_cpu_blocks,
        )

        # Sequence groups in the WAITING state.
        self.waiting: List[SequenceGroup] = []
        # Sequence groups in the RUNNING state.
        self.running: List[SequenceGroup] = []
        # Sequence groups in the SWAPPED state.
        self.swapped: List[SequenceGroup] = []

        self.last_logging_time: float = 0.0
        # List of (timestamp, num_tokens) tuples.
        self.num_input_tokens: List[Tuple[float, int]] = []

    def add_seq_group(self, seq_group: SequenceGroup) -> None:
        # Add the sequence group to the waiting queue.
        self.waiting.append(seq_group)

    def has_unfinished_seqs(self) -> bool:
        return bool(self.waiting or self.running or self.swapped)

    def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
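        # Scheduling proceeds in three phases:
        # 1) Reserve new token slots for the RUNNING groups, preempting the
        #    lowest-priority groups if slots run out.
        # 2) Swap in the SWAPPED groups if possible.
        # 3) Admit WAITING groups, subject to the batched-token and
        #    max-num-seqs limits.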
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}
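        # blocks_to_copy maps a source block number to the list of destination
        # block numbers that must receive a copy of it (copy-on-write).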

        # Fix the current time.
        now = time.time()

        # NOTE(woosuk): We prioritize the sequence groups in the RUNNING state
        # in order to minimize the preemption overheads.
        # Preemption happens only when there is no available slot to keep all
        # the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        self.running = self.policy.sort_by_priority(now, self.running)

        # Reserve new token slots for the running sequence groups.
        running: List[SequenceGroup] = []
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.pop(0)
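            # Reserve a slot for every sequence in the group. If the block
            # manager has no free slot, preempt the lowest-priority groups
            # until one frees up, or preempt this group as a last resort.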
            while not self.block_manager.can_append_slot(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = self.running.pop(-1)
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq_group, blocks_to_swap_out)
                    preempted.append(seq_group)
                    break
            else:
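                # This while-else branch runs only if the loop above finished
                # without a break, i.e., a slot could be reserved.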
                # Append new slots to the sequence group.
                self._append_slot(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

        # Swap in the sequence groups in the SWAPPED state if possible.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
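        # Do not swap in while blocks are being swapped out in this step;
        # SchedulerOutputs asserts the two never happen together.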
        while self.swapped and not blocks_to_swap_out:
            seq_group = self.swapped[0]
            # If the sequence group has been preempted in this step, stop.
            if seq_group in preempted:
                break
            # If the sequence group cannot be swapped in, stop.
            if not self.block_manager.can_swap_in(seq_group):
                break

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
            num_curr_seqs = len(self.running)
            if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
                break

            seq_group = self.swapped.pop(0)
            self._swap_in(seq_group, blocks_to_swap_in)
            self._append_slot(seq_group, blocks_to_copy)
            self.running.append(seq_group)

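        # At this point, the running groups are all in the generation phase,
        # so each RUNNING sequence contributes one token to the batch.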
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running
        )

        # Join waiting sequences if possible.
        prompt_group_ids: List[str] = []
        # NOTE(woosuk): The sequence groups in the SWAPPED state are strictly
        # prioritized over the sequence groups in the WAITING state.
        # This is because we want to bound the amount of CPU memory taken by
        # the swapped sequence groups.
        if not self.swapped:
            # Optimization: We do not sort the waiting queue since the preempted
            # sequence groups are added to the front and the new sequence groups
            # are added to the back.
            while self.waiting:
                seq_group = self.waiting[0]
                # If the sequence group has been preempted in this step, stop.
                if seq_group in preempted:
                    break
                # If the sequence group cannot be allocated, stop.
                if not self.block_manager.can_allocate(seq_group):
                    break

                # If the number of batched tokens exceeds the limit, stop.
                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
                if (num_batched_tokens + num_prompt_tokens
                    > self.scheduler_config.max_num_batched_tokens):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_new_seqs = seq_group.num_seqs(status=SequenceStatus.WAITING)
                num_curr_seqs = len(self.running)
                if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs:
                    break

                seq_group = self.waiting.pop(0)
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_batched_tokens += num_prompt_tokens
                prompt_group_ids.append(seq_group.request_id)

        scheduler_outputs = SchedulerOutputs(
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
        )
        if not self.log_stats:
            return scheduler_outputs, prompt_group_ids

        # TODO(woosuk): Move the code below to the server.
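        # Log a throughput estimate over a sliding window of
        # _LOGGING_INTERVAL_SEC seconds, along with KV cache usage.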
        now = time.time()
        if num_batched_tokens > 0:
            self.num_input_tokens.append((now, num_batched_tokens))
        elapsed_time = now - self.last_logging_time
        if elapsed_time > _LOGGING_INTERVAL_SEC:
            self.last_logging_time = now
            self.num_input_tokens = [
                (t, n) for t, n in self.num_input_tokens
                if now - t < _LOGGING_INTERVAL_SEC
            ]
            if len(self.num_input_tokens) > 1:
                total_num_tokens = sum(n for _, n in self.num_input_tokens[:-1])
                window = now - self.num_input_tokens[0][0]
                avg_throughput = total_num_tokens / window
            else:
                avg_throughput = 0.0

            total_num_gpu_blocks = self.cache_config.num_gpu_blocks
            num_free_gpu_blocks = self.block_manager.get_num_free_gpu_blocks()
            num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks
            gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks

            total_num_cpu_blocks = self.cache_config.num_cpu_blocks
            if total_num_cpu_blocks > 0:
                num_free_cpu_blocks = self.block_manager.get_num_free_cpu_blocks()
                num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks
                cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks
            else:
                cpu_cache_usage = 0.0

            logger.info(
                f"Throughput: {avg_throughput:.1f} tokens/s, "
                f"Running: {len(self.running)} reqs, "
                f"Swapped: {len(self.swapped)} reqs, "
                f"Pending: {len(self.waiting)} reqs, "
                f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
                f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
        return scheduler_outputs, prompt_group_ids

    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_outputs, prompt_group_ids = self._schedule()

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for seq_group in self.running:
            is_prompt = seq_group.request_id in prompt_group_ids

            seq_data: Dict[int, SequenceData] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data
                block_tables[seq_id] = self.block_manager.get_block_table(seq)

            seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group.request_id,
                is_prompt=is_prompt,
                seq_data=seq_data,
                sampling_params=seq_group.sampling_params,
                block_tables=block_tables,
            )
            seq_group_metadata_list.append(seq_group_metadata)
        return seq_group_metadata_list, scheduler_outputs

    def update(
        self,
        seq_outputs: Dict[int, SequenceOutputs],
    ) -> List[SequenceGroup]:
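        # seq_outputs maps a sequence id to that sequence's output for this
        # step: the sampled token, its logprobs, and the parent sequence id,
        # which differs from seq_id when beam search forked the sequence.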
        # Update the running sequences and free blocks.
        for seq_group in self.running:
            # Process beam search results before processing the new tokens.
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                output = seq_outputs[seq.seq_id]
                if seq.seq_id != output.parent_seq_id:
                    # The sequence is a fork of the parent sequence (beam search).
                    # Free the current sequence.
                    self.block_manager.free(seq)
                    # Fork the parent sequence.
                    parent_seq = seq_group.find(output.parent_seq_id)
                    parent_seq.fork(seq)
                    self.block_manager.fork(parent_seq, seq)

            # Process the new tokens.
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                # Append a new token to the sequence.
                output = seq_outputs[seq.seq_id]
                seq.append_token(output.output_token, output.logprobs)
        return self.running.copy()

    def free_seq(self, seq: Sequence) -> None:
        seq.status = SequenceStatus.FINISHED
        self.block_manager.free(seq)

    def free_finished_seq_groups(self) -> None:
        self.running = [
            seq_group for seq_group in self.running
            if not seq_group.is_finished()
        ]

    def _allocate(self, seq_group: SequenceGroup) -> None:
        self.block_manager.allocate(seq_group)
        for seq in seq_group.get_seqs():
            seq.status = SequenceStatus.RUNNING

    def _append_slot(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            ret = self.block_manager.append_slot(seq)
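            # append_slot returns a (src, dst) block pair when the sequence's
            # last block is shared and must be copied on write; record the
            # pending copy in blocks_to_copy.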
            if ret is not None:
                src_block, dst_block = ret
                if src_block in blocks_to_copy:
                    blocks_to_copy[src_block].append(dst_block)
                else:
                    blocks_to_copy[src_block] = [dst_block]

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
        # (e.g., beam search), recomputation is not supported. In such a case,
        # we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if preemption_mode is None:
            seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
            if len(seqs) == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            raise ValueError('Invalid preemption mode.')

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        assert len(seqs) == 1
        for seq in seqs:
            seq.status = SequenceStatus.WAITING
            self.block_manager.free(seq)
        # NOTE: For FCFS, we insert the preempted sequence group to the front
        # of the waiting queue.
        self.waiting.insert(0, seq_group)

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        for seq in seqs:
            seq.status = SequenceStatus.SWAPPED
        self._swap_out(seq_group, blocks_to_swap_out)
        self.swapped.append(seq_group)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        assert self.block_manager.can_swap_out(seq_group)
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED