from collections import deque
import enum
import time
from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.block_manager import AllocStatus, BlockSpaceManager
from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                           SequenceGroupMetadata, SequenceStatus)

# Module-level logger used for scheduling warnings (e.g. prompts that are
# too long to ever be scheduled).
logger = init_logger(__name__)


class PreemptionMode(enum.Enum):
    """Strategy used to evict a running sequence group.

    SWAP:
        The blocks held by the preempted sequences are moved out to CPU
        memory, then moved back in when those sequences resume.
    RECOMPUTE:
        The blocks held by the preempted sequences are thrown away; when the
        sequences resume they are recomputed from scratch, exactly as if
        they were brand-new prompts.
    """
    # Explicit values; identical to what enum.auto() assigns (1, 2, ...).
    SWAP = 1
    RECOMPUTE = 2


class SchedulerOutputs:
    """The result of one scheduling step.

    Bundles the sequence groups chosen for the step together with the block
    operations that must be carried out before the model runs.
    """

    def __init__(
        self,
        scheduled_seq_groups: Iterable[SequenceGroup],
        prompt_run: bool,
        num_batched_tokens: int,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        ignored_seq_groups: List[SequenceGroup],
    ) -> None:
        """Stores the scheduling decision.

        Args:
            scheduled_seq_groups: Sequence groups to execute in this step.
                Emptiness is tested by truthiness in ``is_empty``, so pass a
                sized container (list/deque), not a generator.
            prompt_run: True when the step processes prompts, False when it
                generates tokens for already-running groups.
            num_batched_tokens: Number of tokens in the batch.
            blocks_to_swap_in: Block-number mapping for blocks to swap in
                (as produced by the block manager; presumably src -> dst).
            blocks_to_swap_out: Block-number mapping for blocks to swap out
                (as produced by the block manager; presumably src -> dst).
            blocks_to_copy: Source block number -> list of destination
                block numbers to copy into.
            ignored_seq_groups: Sequence groups that were rejected (e.g. a
                prompt that is too long) and will not be executed.
        """
        self.scheduled_seq_groups = scheduled_seq_groups
        self.prompt_run = prompt_run
        self.num_batched_tokens = num_batched_tokens
        self.blocks_to_swap_in = blocks_to_swap_in
        self.blocks_to_swap_out = blocks_to_swap_out
        self.blocks_to_copy = blocks_to_copy
        # Swap in and swap out should never happen at the same time.
        assert not (blocks_to_swap_in and blocks_to_swap_out)
        self.ignored_seq_groups = ignored_seq_groups

    def is_empty(self) -> bool:
        """Return True if there is nothing to execute in this step."""
        # NOTE: We do not consider the ignored sequence groups.
        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                and not self.blocks_to_swap_out and not self.blocks_to_copy)


class Scheduler:
    """Decides which sequence groups are executed in each engine step.

    Sequence groups live in exactly one of three queues: ``waiting``,
    ``running`` or ``swapped``. Each call to :meth:`schedule` moves groups
    between these queues and reports the block operations (swap in, swap out,
    copy) that must be performed before the model executes.
    """

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
    ) -> None:
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config

        # A single prompt may be neither longer than the model length nor
        # larger than the per-step token budget.
        self.prompt_limit = min(self.scheduler_config.max_model_len,
                                self.scheduler_config.max_num_batched_tokens)

        # Instantiate the scheduling policy.
        self.policy = PolicyFactory.get_policy(policy_name="fcfs")
        # Create the block space manager.
        self.block_manager = BlockSpaceManager(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=self.cache_config.num_gpu_blocks,
            num_cpu_blocks=self.cache_config.num_cpu_blocks,
            sliding_window=self.cache_config.sliding_window)

        # Sequence groups in the WAITING state.
        self.waiting: Deque[SequenceGroup] = deque()
        # Sequence groups in the RUNNING state.
        self.running: Deque[SequenceGroup] = deque()
        # Sequence groups in the SWAPPED state.
        self.swapped: Deque[SequenceGroup] = deque()

    def add_seq_group(self, seq_group: SequenceGroup) -> None:
        """Append a new sequence group to the back of the waiting queue."""
        self.waiting.append(seq_group)

    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts the sequence group(s) with the given request ID(s).

        The waiting, running and swapped queues are searched in turn. Each
        matching sequence group is removed from its queue, and each of its
        sequences that is not already finished is marked
        ``FINISHED_ABORTED`` and freed. IDs that match nothing are ignored.

        Args:
            request_id: A single request ID or an iterable of request IDs.
        """
        if isinstance(request_id, str):
            request_id = (request_id, )
        request_ids = set(request_id)
        for state_queue in [self.waiting, self.running, self.swapped]:
            aborted_groups: List[SequenceGroup] = []
            for seq_group in state_queue:
                if not request_ids:
                    # Every requested ID has been found; stop scanning.
                    break
                if seq_group.request_id in request_ids:
                    # Collect now, remove after the scan: a deque must not be
                    # mutated while it is being iterated.
                    aborted_groups.append(seq_group)
                    request_ids.remove(seq_group.request_id)
            for aborted_group in aborted_groups:
                # Remove the sequence group from the state queue.
                state_queue.remove(aborted_group)
                # Free the sequences of the group being aborted (previously
                # this used the stale `seq_group` scan variable by mistake).
                for seq in aborted_group.get_seqs():
                    if seq.is_finished():
                        continue
                    seq.status = SequenceStatus.FINISHED_ABORTED
                    self.free_seq(seq)

    def has_unfinished_seqs(self) -> bool:
        """Return True if any of the three queues is non-empty."""
        return bool(self.waiting or self.running or self.swapped)

    def get_num_unfinished_seq_groups(self) -> int:
        """Return the total number of queued sequence groups."""
        return len(self.waiting) + len(self.running) + len(self.swapped)

    def _schedule(self) -> SchedulerOutputs:
        """Runs one scheduling step, mutating the internal queues.

        Prompt step (``prompt_run=True``): only attempted when no group is
        swapped out. Waiting groups are admitted front-to-back until one does
        not fit (block allocation, token budget, sequence count, or padding
        limit). Prompts that can never fit are marked ``FINISHED_IGNORED``.
        If at least one group was admitted or ignored, that result is
        returned.

        Generation step (``prompt_run=False``): otherwise, each running group
        is given a new token slot, preempting the lowest-priority running
        groups when no slot can be appended. If nothing was preempted,
        swapped-out groups are then swapped back in while they fit.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        # Fix the current time.
        now = time.monotonic()

        # Join waiting sequences if possible.
        if not self.swapped:
            ignored_seq_groups: List[SequenceGroup] = []
            scheduled: List[SequenceGroup] = []
            # The total number of sequences on the fly, including the
            # requests in the generation phase.
            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                for seq_group in self.running)
            # Prompt lengths of the groups admitted so far in this step.
            seq_lens: List[int] = []

            # Optimization: We do not sort the waiting queue since the preempted
            # sequence groups are added to the front and the new sequence groups
            # are added to the back.
            while self.waiting:
                seq_group = self.waiting[0]

                waiting_seqs = seq_group.get_seqs(
                    status=SequenceStatus.WAITING)
                assert len(waiting_seqs) == 1, (
                    "Waiting sequence group should have only one prompt "
                    "sequence.")
                num_prompt_tokens = waiting_seqs[0].get_len()
                if num_prompt_tokens > self.prompt_limit:
                    logger.warning(
                        "Input prompt (%d tokens) is too long"
                        " and exceeds limit of %d", num_prompt_tokens,
                        self.prompt_limit)
                    for seq in waiting_seqs:
                        seq.status = SequenceStatus.FINISHED_IGNORED
                    ignored_seq_groups.append(seq_group)
                    self.waiting.popleft()
                    continue

                # If the sequence group cannot be allocated, stop.
                can_allocate = self.block_manager.can_allocate(seq_group)
                if can_allocate == AllocStatus.LATER:
                    break
                elif can_allocate == AllocStatus.NEVER:
                    logger.warning(
                        "Input prompt (%d tokens) is too long"
                        " and exceeds the capacity of block_manager",
                        num_prompt_tokens)
                    for seq in waiting_seqs:
                        seq.status = SequenceStatus.FINISHED_IGNORED
                    ignored_seq_groups.append(seq_group)
                    self.waiting.popleft()
                    continue

                # If the number of batched tokens exceeds the limit, stop.
                # Prompts are counted as padded to the longest one in the
                # batch, hence len * max rather than sum.
                new_seq_lens = seq_lens + [num_prompt_tokens]
                num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
                if (num_batched_tokens >
                        self.scheduler_config.max_num_batched_tokens):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_new_seqs = seq_group.get_max_num_running_seqs()
                if (num_curr_seqs + num_new_seqs >
                        self.scheduler_config.max_num_seqs):
                    break

                # Limit the amount of padding wasted by mixing prompt lengths.
                num_paddings = num_batched_tokens - sum(new_seq_lens)
                if num_paddings > self.scheduler_config.max_paddings:
                    break
                seq_lens = new_seq_lens

                seq_group = self.waiting.popleft()
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_curr_seqs += num_new_seqs
                scheduled.append(seq_group)

            if scheduled or ignored_seq_groups:
                scheduler_outputs = SchedulerOutputs(
                    scheduled_seq_groups=scheduled,
                    prompt_run=True,
                    num_batched_tokens=(len(seq_lens) * max(seq_lens)
                                        if seq_lens else 0),
                    blocks_to_swap_in=blocks_to_swap_in,
                    blocks_to_swap_out=blocks_to_swap_out,
                    blocks_to_copy=blocks_to_copy,
                    ignored_seq_groups=ignored_seq_groups,
                )
                return scheduler_outputs

        # NOTE(woosuk): Preemption happens only when there is no available slot
        # to keep all the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        self.running = self.policy.sort_by_priority(now, self.running)

        # Reserve new token slots for the running sequence groups.
        running: Deque[SequenceGroup] = deque()
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.popleft()
            while not self.block_manager.can_append_slot(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = self.running.pop()
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq_group, blocks_to_swap_out)
                    preempted.append(seq_group)
                    break
            else:
                # Append new slots to the sequence group.
                self._append_slot(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

        # Swap in the sequence groups in the SWAPPED state if possible.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
        if not preempted:
            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                for seq_group in self.running)

            while self.swapped:
                seq_group = self.swapped[0]
                # If the sequence group cannot be swapped in, stop.
                if not self.block_manager.can_swap_in(seq_group):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_new_seqs = seq_group.get_max_num_running_seqs()
                if (num_curr_seqs + num_new_seqs >
                        self.scheduler_config.max_num_seqs):
                    break

                seq_group = self.swapped.popleft()
                self._swap_in(seq_group, blocks_to_swap_in)
                self._append_slot(seq_group, blocks_to_copy)
                num_curr_seqs += num_new_seqs
                self.running.append(seq_group)

        # Each sequence in the generation phase only takes one token slot.
        # Therefore, the number of batched tokens is equal to the number of
        # sequences in the RUNNING state.
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running)

        scheduler_outputs = SchedulerOutputs(
            scheduled_seq_groups=self.running,
            prompt_run=False,
            num_batched_tokens=num_batched_tokens,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            ignored_seq_groups=[],
        )
        return scheduler_outputs

    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
        """Runs one scheduling step and builds per-group model inputs.

        Returns:
            A list with one ``SequenceGroupMetadata`` per scheduled group
            (covering only its RUNNING sequences), and the raw
            ``SchedulerOutputs`` for the step.
        """
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_outputs = self._schedule()

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for seq_group in scheduler_outputs.scheduled_seq_groups:
            seq_data: Dict[int, SequenceData] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data
                block_tables[seq_id] = self.block_manager.get_block_table(seq)

            seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group.request_id,
                is_prompt=scheduler_outputs.prompt_run,
                seq_data=seq_data,
                sampling_params=seq_group.sampling_params,
                block_tables=block_tables,
            )
            seq_group_metadata_list.append(seq_group_metadata)
        return seq_group_metadata_list, scheduler_outputs

    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Delegate forking ``parent_seq`` into ``child_seq`` to the block manager."""
        self.block_manager.fork(parent_seq, child_seq)

    def free_seq(self, seq: Sequence) -> None:
        """Release the blocks held by ``seq`` via the block manager."""
        self.block_manager.free(seq)

    def free_finished_seq_groups(self) -> None:
        """Drop finished sequence groups from the running queue.

        The result is rebuilt as a ``deque`` so ``self.running`` keeps the
        ``Deque`` type declared in ``__init__`` (it was previously turned
        into a plain ``list`` here).
        """
        self.running = deque(seq_group for seq_group in self.running
                             if not seq_group.is_finished())

    def _allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate blocks for ``seq_group`` and mark its waiting seqs RUNNING."""
        self.block_manager.allocate(seq_group)
        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
            seq.status = SequenceStatus.RUNNING

    def _append_slot(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """Reserve one more token slot for each running sequence in the group.

        When the block manager returns a ``(src_block, dst_block)`` pair, the
        pair is recorded in ``blocks_to_copy`` (src -> list of dst) so the
        copy can be executed before the model runs.
        """
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            ret = self.block_manager.append_slot(seq)
            if ret is not None:
                src_block, dst_block = ret
                blocks_to_copy.setdefault(src_block, []).append(dst_block)

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
        """Evict ``seq_group`` by recomputation or by swapping.

        Raises:
            AssertionError: if ``preemption_mode`` is not a known mode.
        """
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
        # (e.g., beam search), recomputation is not currently supported. In
        # such a case, we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if preemption_mode is None:
            if seq_group.get_max_num_running_seqs() == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            raise AssertionError("Invalid preemption mode.")

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Free the group's blocks and requeue it at the front of ``waiting``.

        Only groups with exactly one running sequence are supported.
        """
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        assert len(seqs) == 1
        for seq in seqs:
            seq.status = SequenceStatus.WAITING
            self.block_manager.free(seq)
        # NOTE: For FCFS, we insert the preempted sequence group to the front
        # of the waiting queue.
        self.waiting.appendleft(seq_group)

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Swap the group out and append it to the swapped queue."""
        self._swap_out(seq_group, blocks_to_swap_out)
        self.swapped.append(seq_group)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        """Swap the group's blocks in and mark its swapped seqs RUNNING."""
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Swap the group's blocks out and mark its running seqs SWAPPED.

        Raises:
            RuntimeError: if the block manager reports that the group cannot
                be swapped out (insufficient CPU swap space).
        """
        if not self.block_manager.can_swap_out(seq_group):
            # FIXME(woosuk): Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error.")
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED