"""Sequence-group scheduler: preemption modes, per-step outputs, scheduler."""
import enum
import time
from typing import Dict, Iterable, List, Optional, Tuple, Union

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.block_manager import BlockSpaceManager
from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                           SequenceGroupMetadata, SequenceStatus)

logger = init_logger(__name__)


class PreemptionMode(enum.Enum):
    """Strategy used to take GPU resources away from a running sequence group.

    SWAP:
        The blocks of the preempted sequences are moved out to CPU memory and
        moved back in once the sequences are resumed.
    RECOMPUTE:
        The blocks of the preempted sequences are thrown away; on resumption
        the sequences are treated as brand-new prompts and recomputed.
    """
    # Explicit values match what ``enum.auto()`` assigns to plain Enum members.
    SWAP = 1
    RECOMPUTE = 2


class SchedulerOutputs:
    """Everything produced by a single scheduling step.

    Holds the sequence groups selected for execution, whether the step
    processes prompts, the number of batched tokens, and the block
    operations (swap in, swap out, copy) required before the model runs.
    Groups that were rejected outright are carried in ``ignored_seq_groups``.
    """

    def __init__(
        self,
        scheduled_seq_groups: List[SequenceGroup],
        prompt_run: bool,
        num_batched_tokens: int,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        ignored_seq_groups: List[SequenceGroup],
    ) -> None:
        # Invariant: a step never both swaps blocks in and swaps blocks out.
        assert not (blocks_to_swap_in and blocks_to_swap_out)

        # What to run.
        self.scheduled_seq_groups = scheduled_seq_groups
        self.ignored_seq_groups = ignored_seq_groups
        self.prompt_run = prompt_run
        self.num_batched_tokens = num_batched_tokens

        # Block movements to perform beforehand (src block -> dst block(s)).
        self.blocks_to_swap_in = blocks_to_swap_in
        self.blocks_to_swap_out = blocks_to_swap_out
        self.blocks_to_copy = blocks_to_copy

    def is_empty(self) -> bool:
        """Return True if the step neither runs a group nor moves a block.

        ``ignored_seq_groups`` deliberately does not count as work.
        """
        has_work = (self.scheduled_seq_groups or self.blocks_to_swap_in
                    or self.blocks_to_swap_out or self.blocks_to_copy)
        return not has_work


class Scheduler:
    """Chooses which sequence groups run at each model step.

    Sequence groups are tracked in three lists:

    * ``waiting`` -- groups holding no blocks: new requests, and groups that
      were preempted by recomputation.
    * ``running`` -- groups whose sequences have blocks allocated by
      ``self.block_manager``.
    * ``swapped`` -- groups whose blocks were swapped out to CPU memory.

    Each call to :meth:`schedule` moves groups between these lists and
    reports the block swap/copy operations needed before the model runs.
    """

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
    ) -> None:
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config

        # A prompt longer than this can never be admitted: it must not exceed
        # max_model_len, and a lone prompt of N tokens already counts as N
        # batched tokens against max_num_batched_tokens in _schedule().
        self.prompt_limit = min(self.scheduler_config.max_model_len,
                                self.scheduler_config.max_num_batched_tokens)

        # Instantiate the scheduling policy.
        self.policy = PolicyFactory.get_policy(policy_name="fcfs")
        # Create the block space manager.
        self.block_manager = BlockSpaceManager(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=self.cache_config.num_gpu_blocks,
            num_cpu_blocks=self.cache_config.num_cpu_blocks,
            sliding_window=self.cache_config.sliding_window)

        # TODO(zhuohan): Use deque instead of list for better performance.
        # Sequence groups in the WAITING state.
        self.waiting: List[SequenceGroup] = []
        # Sequence groups in the RUNNING state.
        self.running: List[SequenceGroup] = []
        # Sequence groups in the SWAPPED state.
        self.swapped: List[SequenceGroup] = []

    def add_seq_group(self, seq_group: SequenceGroup) -> None:
        """Append a new sequence group to the back of the waiting queue."""
        self.waiting.append(seq_group)

    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort the sequence group(s) with the given request ID(s).

        Args:
            request_id: A single request ID or an iterable of request IDs.
                IDs that match no queued sequence group are ignored.

        Each matching group is removed from whichever queue holds it. Every
        one of its sequences that is not already finished is marked
        ``FINISHED_ABORTED`` and passed to :meth:`free_seq`.
        """
        if isinstance(request_id, str):
            request_id = (request_id, )
        request_ids = set(request_id)
        for state_queue in [self.waiting, self.running, self.swapped]:
            # Iterate in reverse: removing the current element then only
            # shifts elements that have already been visited, so none are
            # skipped.
            for seq_group in reversed(state_queue):
                if seq_group.request_id in request_ids:
                    # Remove the sequence group from the state queue.
                    state_queue.remove(seq_group)
                    for seq in seq_group.get_seqs():
                        if seq.is_finished():
                            continue
                        seq.status = SequenceStatus.FINISHED_ABORTED
                        self.free_seq(seq)
                    request_ids.remove(seq_group.request_id)
                    # Every requested ID has been handled; stop scanning.
                    if not request_ids:
                        return

    def has_unfinished_seqs(self) -> bool:
        """Return True if any queue still holds a sequence group."""
        # bool() so the declared return type holds; the bare `or` chain would
        # return one of the lists itself.
        return bool(self.waiting or self.running or self.swapped)

    def get_num_unfinished_seq_groups(self) -> int:
        """Return the total number of groups across all three queues."""
        return len(self.waiting) + len(self.running) + len(self.swapped)

    def _schedule(self) -> SchedulerOutputs:
        """Run one scheduling step, mutating the internal queues.

        A prompt step is attempted only when nothing is swapped out. Waiting
        groups are admitted in order as long as all of these hold:

        * blocks can be allocated;
        * the padded token count stays within ``max_num_batched_tokens``;
        * the running sequence count stays within ``max_num_seqs``;
        * padding stays within ``max_paddings``.

        Prompts longer than ``self.prompt_limit`` are marked
        ``FINISHED_IGNORED`` and reported in ``ignored_seq_groups``. If at
        least one group was scheduled or ignored, the prompt step is
        returned.

        Otherwise a generation step is built. Each running group receives a
        new slot per running sequence. When slots run out, the
        lowest-priority running groups are preempted. If nothing was
        preempted, swapped groups are swapped back in while the sequence
        budget allows.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        # Fix the current time.
        now = time.monotonic()

        # Join waiting sequences if possible.
        if not self.swapped:
            ignored_seq_groups: List[SequenceGroup] = []
            scheduled: List[SequenceGroup] = []
            # The total number of sequences on the fly, including the
            # requests in the generation phase.
            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                for seq_group in self.running)
            seq_lens: List[int] = []

            # Optimization: We do not sort the waiting queue since the preempted
            # sequence groups are added to the front and the new sequence groups
            # are added to the back.
            while self.waiting:
                seq_group = self.waiting[0]

                assert seq_group.num_seqs() == 1, (
                    "Waiting sequence group should have only one prompt "
                    "sequence.")
                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
                if num_prompt_tokens > self.prompt_limit:
                    logger.warning(
                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
                        f" and exceeds limit of {self.prompt_limit}")
                    for seq in seq_group.get_seqs():
                        seq.status = SequenceStatus.FINISHED_IGNORED
                    ignored_seq_groups.append(seq_group)
                    self.waiting.pop(0)
                    continue

                # If the sequence group cannot be allocated, stop.
                if not self.block_manager.can_allocate(seq_group):
                    break

                # If the number of batched tokens exceeds the limit, stop.
                # Prompts are padded to the longest one in the batch, so the
                # cost is (number of prompts) * (longest prompt).
                new_seq_lens = seq_lens + [num_prompt_tokens]
                num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
                if (num_batched_tokens >
                        self.scheduler_config.max_num_batched_tokens):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_new_seqs = seq_group.get_max_num_running_seqs()
                if (num_curr_seqs + num_new_seqs >
                        self.scheduler_config.max_num_seqs):
                    break

                num_paddings = num_batched_tokens - sum(new_seq_lens)
                if num_paddings > self.scheduler_config.max_paddings:
                    break
                seq_lens = new_seq_lens

                seq_group = self.waiting.pop(0)
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_curr_seqs += num_new_seqs
                scheduled.append(seq_group)

            if scheduled or ignored_seq_groups:
                # When every candidate prompt was ignored, seq_lens is empty;
                # a bare max() would raise ValueError, so report 0 tokens.
                scheduler_outputs = SchedulerOutputs(
                    scheduled_seq_groups=scheduled,
                    prompt_run=True,
                    num_batched_tokens=len(seq_lens) * max(seq_lens,
                                                           default=0),
                    blocks_to_swap_in=blocks_to_swap_in,
                    blocks_to_swap_out=blocks_to_swap_out,
                    blocks_to_copy=blocks_to_copy,
                    ignored_seq_groups=ignored_seq_groups,
                )
                return scheduler_outputs

        # NOTE(woosuk): Preemption happens only when there is no available slot
        # to keep all the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        self.running = self.policy.sort_by_priority(now, self.running)

        # Reserve new token slots for the running sequence groups.
        running: List[SequenceGroup] = []
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.pop(0)
            while not self.block_manager.can_append_slot(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = self.running.pop(-1)
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq_group, blocks_to_swap_out)
                    preempted.append(seq_group)
                    break
            else:
                # while-else: reached only when the loop above ended without
                # `break`, i.e. a slot can now be appended for this group.
                # Append new slots to the sequence group.
                self._append_slot(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

        # Swap in the sequence groups in the SWAPPED state if possible.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
        # Skipped when anything was preempted in this step. At least in part
        # this is because SchedulerOutputs forbids swapping in and out in the
        # same step.
        if not preempted:
            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                for seq_group in self.running)

            while self.swapped:
                seq_group = self.swapped[0]
                # If the sequence group cannot be swapped in, stop.
                if not self.block_manager.can_swap_in(seq_group):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_new_seqs = seq_group.get_max_num_running_seqs()
                if (num_curr_seqs + num_new_seqs >
                        self.scheduler_config.max_num_seqs):
                    break

                seq_group = self.swapped.pop(0)
                self._swap_in(seq_group, blocks_to_swap_in)
                self._append_slot(seq_group, blocks_to_copy)
                num_curr_seqs += num_new_seqs
                self.running.append(seq_group)

        # Each sequence in the generation phase only takes one token slot.
        # Therefore, the number of batched tokens is equal to the number of
        # sequences in the RUNNING state.
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running)

        scheduler_outputs = SchedulerOutputs(
            scheduled_seq_groups=self.running,
            prompt_run=False,
            num_batched_tokens=num_batched_tokens,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            ignored_seq_groups=[],
        )
        return scheduler_outputs

    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
        """Schedule one step and build the per-group metadata for the model.

        Returns:
            A pair ``(seq_group_metadata_list, scheduler_outputs)``. The list
            has one entry per scheduled group. Each entry covers that group's
            RUNNING sequences, keyed by ``seq_id``, with their data and block
            tables.
        """
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_outputs = self._schedule()

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for seq_group in scheduler_outputs.scheduled_seq_groups:
            # Each value is a single sequence's data, not a list.
            seq_data: Dict[int, SequenceData] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data
                block_tables[seq_id] = self.block_manager.get_block_table(seq)

            seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group.request_id,
                is_prompt=scheduler_outputs.prompt_run,
                seq_data=seq_data,
                sampling_params=seq_group.sampling_params,
                block_tables=block_tables,
            )
            seq_group_metadata_list.append(seq_group_metadata)
        return seq_group_metadata_list, scheduler_outputs

    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Delegate forking ``parent_seq`` into ``child_seq`` to the block manager."""
        self.block_manager.fork(parent_seq, child_seq)

    def free_seq(self, seq: Sequence) -> None:
        """Release the blocks held by ``seq`` in the block manager."""
        self.block_manager.free(seq)

    def free_finished_seq_groups(self) -> None:
        """Drop finished groups from the running queue."""
        self.running = [
            seq_group for seq_group in self.running
            if not seq_group.is_finished()
        ]

    def _allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate blocks for ``seq_group`` and mark all its sequences RUNNING."""
        self.block_manager.allocate(seq_group)
        for seq in seq_group.get_seqs():
            seq.status = SequenceStatus.RUNNING

    def _append_slot(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """Reserve one new token slot for each RUNNING sequence of the group.

        Any ``(src_block, dst_block)`` pair returned by the block manager is
        recorded in ``blocks_to_copy``. One source block may map to several
        destination blocks.
        """
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            ret = self.block_manager.append_slot(seq)
            if ret is not None:
                # NOTE(review): presumably a shared block that must be copied
                # before it can be written; confirm in BlockSpaceManager.
                src_block, dst_block = ret
                blocks_to_copy.setdefault(src_block, []).append(dst_block)

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
        """Preempt ``seq_group`` by recomputation or by swapping.

        Args:
            seq_group: The running group to preempt.
            blocks_to_swap_out: Filled with GPU->CPU block mappings when
                swapping is used.
            preemption_mode: Forces a mode. If None, the mode is RECOMPUTE
                for single-sequence groups and SWAP otherwise.

        Raises:
            AssertionError: If ``preemption_mode`` is not a known mode.
        """
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
        # (e.g., beam search), recomputation is not currently supported. In
        # such a case, we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if preemption_mode is None:
            if seq_group.get_max_num_running_seqs() == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            # Explicit raise instead of `assert False`: asserts are stripped
            # under `python -O`, which would let an invalid mode pass silently.
            raise AssertionError("Invalid preemption mode.")

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Free the group's blocks and put it back at the front of waiting."""
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        assert len(seqs) == 1
        for seq in seqs:
            seq.status = SequenceStatus.WAITING
            self.block_manager.free(seq)
        # NOTE: For FCFS, we insert the preempted sequence group to the front
        # of the waiting queue.
        self.waiting.insert(0, seq_group)

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Swap the group's blocks out and append it to the swapped queue."""
        self._swap_out(seq_group, blocks_to_swap_out)
        self.swapped.append(seq_group)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        """Swap the group's blocks back in and mark its SWAPPED seqs RUNNING."""
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Swap the group's blocks out and mark its RUNNING seqs SWAPPED.

        Raises:
            RuntimeError: If the block manager reports that the group cannot
                be swapped out (insufficient CPU swap space).
        """
        if not self.block_manager.can_swap_out(seq_group):
            # FIXME(woosuk): Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error.")
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED