scheduler.py 17 KB
Newer Older
1
2
import enum
import time
Antoni Baum's avatar
Antoni Baum committed
3
from typing import Dict, Iterable, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
4

Woosuk Kwon's avatar
Woosuk Kwon committed
5
6
7
8
9
10
11
from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.block_manager import BlockSpaceManager
from vllm.core.policy import PolicyFactory
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                           SequenceGroupMetadata, SequenceOutputs,
                           SequenceStatus)
Woosuk Kwon's avatar
Woosuk Kwon committed
12

Woosuk Kwon's avatar
Woosuk Kwon committed
13
logger = init_logger(__name__)
14

Woosuk Kwon's avatar
Woosuk Kwon committed
15

16
17
18
19
20
21
22
23
24
25
26
27
28
class PreemptionMode(enum.Enum):
    """How a preempted sequence group is evicted from the GPU.

    SWAP: move the group's KV-cache blocks out to CPU memory and bring
        them back when the group is rescheduled.
    RECOMPUTE: discard the group's KV-cache blocks entirely and re-run
        the group as a fresh prompt when it is rescheduled.
    """
    # Explicit values match what enum.auto() would assign in this order.
    SWAP = 1
    RECOMPUTE = 2


29
30
31
32
class SchedulerOutputs:
    """Result of a single scheduling step.

    Bundles the sequence groups selected for execution together with the
    block-level cache operations (swap-in, swap-out, copy) that must be
    performed before the model runs.
    """

    def __init__(
        self,
        scheduled_seq_groups: List[SequenceGroup],
        prompt_run: bool,
        num_batched_tokens: int,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        ignored_seq_groups: List[SequenceGroup],
    ) -> None:
        self.scheduled_seq_groups = scheduled_seq_groups
        self.prompt_run = prompt_run
        self.num_batched_tokens = num_batched_tokens
        self.blocks_to_swap_in = blocks_to_swap_in
        self.blocks_to_swap_out = blocks_to_swap_out
        self.blocks_to_copy = blocks_to_copy
        # Swap in and swap out should never happen at the same time.
        assert not (blocks_to_swap_in and blocks_to_swap_out)
        self.ignored_seq_groups = ignored_seq_groups

    def is_empty(self) -> bool:
        """Return True when there is nothing to run and no cache work.

        NOTE: ignored sequence groups are deliberately not considered.
        """
        has_work = (self.scheduled_seq_groups or self.blocks_to_swap_in
                    or self.blocks_to_swap_out or self.blocks_to_copy)
        return not has_work
55
56


Woosuk Kwon's avatar
Woosuk Kwon committed
57
58
class Scheduler:

Woosuk Kwon's avatar
Woosuk Kwon committed
59
    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
    ) -> None:
        """Build a scheduler from the scheduling and KV-cache configs."""
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config

        # A prompt longer than this can never fit in a batch, so such
        # requests are rejected up front during scheduling.
        self.prompt_limit = min(self.scheduler_config.max_model_len,
                                self.scheduler_config.max_num_batched_tokens)

        # Scheduling policy: first-come, first-served.
        self.policy = PolicyFactory.get_policy(policy_name="fcfs")
        # Block space manager tracks GPU/CPU KV-cache block ownership.
        self.block_manager = BlockSpaceManager(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=self.cache_config.num_gpu_blocks,
            num_cpu_blocks=self.cache_config.num_cpu_blocks,
        )

        # Per-state queues of sequence groups.
        self.waiting: List[SequenceGroup] = []  # WAITING: not yet scheduled
        self.running: List[SequenceGroup] = []  # RUNNING: active on GPU
        self.swapped: List[SequenceGroup] = []  # SWAPPED: preempted to CPU
Woosuk Kwon's avatar
Woosuk Kwon committed
85

86
    def add_seq_group(self, seq_group: SequenceGroup) -> None:
87
        # Add sequence groups to the waiting queue.
88
        self.waiting.append(seq_group)
Woosuk Kwon's avatar
Woosuk Kwon committed
89

Antoni Baum's avatar
Antoni Baum committed
90
91
92
93
    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
        """Abort every sequence group whose request id is in *request_id*.

        Args:
            request_id: A single request id, or an iterable of request ids.

        Aborted groups are removed from their state queue, and each of
        their unfinished sequences is freed with FINISHED_ABORTED status.
        """
        if isinstance(request_id, str):
            request_id = (request_id, )
        request_ids = set(request_id)
        for state_queue in [self.waiting, self.running, self.swapped]:
            # Iterate over a snapshot: removing from the queue while
            # iterating it directly would skip the element following each
            # removal, leaving some aborted groups untouched.
            for seq_group in list(state_queue):
                if seq_group.request_id in request_ids:
                    # Remove the sequence group from the state queue.
                    state_queue.remove(seq_group)
                    for seq in seq_group.seqs:
                        if seq.is_finished():
                            continue
                        self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
                    request_ids.remove(seq_group.request_id)
                    if not request_ids:
                        # All requested ids handled; stop early.
                        return
106

107
108
109
    def has_unfinished_seqs(self) -> bool:
        return self.waiting or self.running or self.swapped

110
111
112
    def get_num_unfinished_seq_groups(self) -> int:
        return len(self.waiting) + len(self.running) + len(self.swapped)

Woosuk Kwon's avatar
Woosuk Kwon committed
113
    def _schedule(self) -> SchedulerOutputs:
        """Select which sequence groups run in the next model step.

        Priority order: (1) if nothing is swapped out, admit as many
        WAITING groups as fit (a prompt run); otherwise (2) keep RUNNING
        groups going, preempting the lowest-priority ones when GPU blocks
        run out, then (3) swap SWAPPED groups back in while space allows
        (a generation run).

        Mutates self.waiting/self.running/self.swapped in place.
        """
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        # Fix the current time so all priority comparisons use one clock.
        now = time.time()

        # Join waiting sequences if possible.
        if not self.swapped:
            ignored_seq_groups: List[SequenceGroup] = []
            scheduled: List[SequenceGroup] = []
            num_batched_tokens = 0
            # Optimization: We do not sort the waiting queue since the preempted
            # sequence groups are added to the front and the new sequence groups
            # are added to the back.
            while self.waiting:
                seq_group = self.waiting[0]

                # NOTE(review): assumes all sequences in a waiting group share
                # one prompt, so the first sequence's length suffices.
                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
                if num_prompt_tokens > self.prompt_limit:
                    logger.warning(
                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
                        f" and exceeds limit of {self.prompt_limit}")
                    # Over-long prompts can never be scheduled: mark them
                    # ignored and drop them from the queue.
                    for seq in seq_group.get_seqs():
                        seq.status = SequenceStatus.FINISHED_IGNORED
                    ignored_seq_groups.append(seq_group)
                    self.waiting.pop(0)
                    continue

                # If the sequence group cannot be allocated, stop.
                if not self.block_manager.can_allocate(seq_group):
                    break

                # If the number of batched tokens exceeds the limit, stop.
                if (num_batched_tokens + num_prompt_tokens >
                        self.scheduler_config.max_num_batched_tokens):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_new_seqs = seq_group.num_seqs(
                    status=SequenceStatus.WAITING)
                num_curr_seqs = sum(
                    seq_group.num_seqs(status=SequenceStatus.RUNNING)
                    for seq_group in self.running)
                if (num_curr_seqs + num_new_seqs >
                        self.scheduler_config.max_num_seqs):
                    break

                # Admit the group: allocate its blocks and move it to RUNNING.
                seq_group = self.waiting.pop(0)
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_batched_tokens += num_prompt_tokens
                scheduled.append(seq_group)

            if scheduled:
                # At least one prompt was admitted: this step is a prompt run.
                scheduler_outputs = SchedulerOutputs(
                    scheduled_seq_groups=scheduled,
                    prompt_run=True,
                    num_batched_tokens=num_batched_tokens,
                    blocks_to_swap_in=blocks_to_swap_in,
                    blocks_to_swap_out=blocks_to_swap_out,
                    blocks_to_copy=blocks_to_copy,
                    ignored_seq_groups=ignored_seq_groups,
                )
                return scheduler_outputs

        # NOTE(woosuk): Preemption happens only when there is no available slot
        # to keep all the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        self.running = self.policy.sort_by_priority(now, self.running)

        # Reserve new token slots for the running sequence groups.
        running: List[SequenceGroup] = []
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.pop(0)
            while not self.block_manager.can_append_slot(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = self.running.pop(-1)
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq_group, blocks_to_swap_out)
                    preempted.append(seq_group)
                    break
            else:
                # Append new slots to the sequence group.
                self._append_slot(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

        # Swap in the sequence groups in the SWAPPED state if possible.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
        # Never swap in while this step is already swapping blocks out.
        while self.swapped and not blocks_to_swap_out:
            seq_group = self.swapped[0]
            # If the sequence group has been preempted in this step, stop.
            if seq_group in preempted:
                break
            # If the sequence group cannot be swapped in, stop.
            if not self.block_manager.can_swap_in(seq_group):
                break

            # The total number of sequences in the RUNNING state should not
            # exceed the maximum number of sequences.
            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
            num_curr_seqs = sum(
                seq_group.num_seqs(status=SequenceStatus.RUNNING)
                for seq_group in self.running)
            if (num_curr_seqs + num_new_seqs >
                    self.scheduler_config.max_num_seqs):
                break

            seq_group = self.swapped.pop(0)
            self._swap_in(seq_group, blocks_to_swap_in)
            self._append_slot(seq_group, blocks_to_copy)
            self.running.append(seq_group)

        # Generation run: one token per running sequence this step.
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running)

        scheduler_outputs = SchedulerOutputs(
            scheduled_seq_groups=self.running,
            prompt_run=False,
            num_batched_tokens=num_batched_tokens,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            ignored_seq_groups=[],
        )
        return scheduler_outputs
Woosuk Kwon's avatar
Woosuk Kwon committed
251

Woosuk Kwon's avatar
Woosuk Kwon committed
252
    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
        """Run one scheduling step and build per-group model inputs.

        Returns:
            A pair of (metadata for each scheduled sequence group, the
            raw SchedulerOutputs describing the step).
        """
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_outputs = self._schedule()

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for seq_group in scheduler_outputs.scheduled_seq_groups:
            # Each sequence id maps to a single SequenceData object
            # (the annotation previously claimed List[SequenceData],
            # but a single seq.data is stored per id).
            seq_data: Dict[int, SequenceData] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data
                block_tables[seq_id] = self.block_manager.get_block_table(seq)

            seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group.request_id,
                is_prompt=scheduler_outputs.prompt_run,
                seq_data=seq_data,
                sampling_params=seq_group.sampling_params,
                block_tables=block_tables,
            )
            seq_group_metadata_list.append(seq_group_metadata)
        return seq_group_metadata_list, scheduler_outputs
277

278
    def update(
        self,
        seq_outputs: Dict[int, SequenceOutputs],
    ) -> List[SequenceGroup]:
        """Apply the model's outputs to the running sequences.

        Args:
            seq_outputs: Mapping from sequence id to its sampled output.

        Returns:
            The running sequence groups that received at least one output
            this step.
        """
        # Collect the groups that actually have outputs this step.
        scheduled: List[SequenceGroup] = []
        for seq_group in self.running:
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                if seq.seq_id in seq_outputs:
                    scheduled.append(seq_group)
                    break

        # Update the scheduled sequences and free blocks.
        for seq_group in scheduled:
            # Process beam search results before processing the new tokens.
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                output = seq_outputs[seq.seq_id]
                if seq.seq_id != output.parent_seq_id:
                    # The sequence is a fork of the parent sequence (beam
                    # search). Free the current sequence.
                    self.block_manager.free(seq)
                    # Fork the parent sequence.
                    parent_seq = seq_group.find(output.parent_seq_id)
                    parent_seq.fork(seq)
                    self.block_manager.fork(parent_seq, seq)

            # Process the new tokens.
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                # Append a new token to the sequence.
                output = seq_outputs[seq.seq_id]
                seq.append_token_id(output.output_token, output.logprobs)
        return scheduled
Woosuk Kwon's avatar
Woosuk Kwon committed
309

Zhuohan Li's avatar
Zhuohan Li committed
310
311
    def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
        seq.status = finish_status
312
        self.block_manager.free(seq)
Woosuk Kwon's avatar
Woosuk Kwon committed
313

314
315
316
317
318
    def free_finished_seq_groups(self) -> None:
        self.running = [
            seq_group for seq_group in self.running
            if not seq_group.is_finished()
        ]
Woosuk Kwon's avatar
Woosuk Kwon committed
319

320
321
    def _allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate GPU blocks for *seq_group* and mark its sequences RUNNING."""
        self.block_manager.allocate(seq_group)
        for sequence in seq_group.get_seqs():
            sequence.status = SequenceStatus.RUNNING

325
    def _append_slot(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        """Reserve one new token slot for each running sequence in the group.

        When the block manager reports a copy-on-write, the (src, dst)
        block pair is recorded in *blocks_to_copy* so the cache can be
        copied before the next model step.
        """
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            ret = self.block_manager.append_slot(seq)
            if ret is not None:
                # Copy-on-write: schedule src_block -> dst_block.
                # setdefault replaces the manual key-present branch.
                src_block, dst_block = ret
                blocks_to_copy.setdefault(src_block, []).append(dst_block)

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
        """Preempt *seq_group*, releasing its GPU blocks.

        If *preemption_mode* is None, the mode is chosen automatically:
        RECOMPUTE for single-sequence groups, SWAP otherwise.
        """
        # If preemption mode is not specified, we determine the mode as follows:
        # We use recomputation by default since it incurs lower overhead than
        # swapping. However, when the sequence group has multiple sequences
        # (e.g., beam search), recomputation is not supported. In such a case,
        # we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if preemption_mode is None:
            seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
            if len(seqs) == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            # Raise explicitly instead of `assert False`: asserts are
            # stripped under `python -O`, which would silently fall through.
            raise AssertionError("Invalid preemption mode.")
368
369
370
371
372
373
374
375
376
377

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        """Preempt by discarding blocks; the group restarts as a new prompt."""
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        # Recomputation is only supported for single-sequence groups.
        assert len(seqs) == 1
        for running_seq in seqs:
            running_seq.status = SequenceStatus.WAITING
            self.block_manager.free(running_seq)
        # NOTE: For FCFS, we insert the preempted sequence group to the front
        # of the waiting queue.
        self.waiting.insert(0, seq_group)
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Preempt by swapping the group's blocks out to CPU memory."""
        self._swap_out(seq_group, blocks_to_swap_out)
        self.swapped.append(seq_group)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        """Bring a swapped-out group back to the GPU and mark it RUNNING."""
        blocks_to_swap_in.update(self.block_manager.swap_in(seq_group))
        for swapped_seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            swapped_seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        """Move the group's GPU blocks to CPU swap space and mark it SWAPPED.

        Raises:
            RuntimeError: if the CPU swap space cannot hold the blocks.
        """
        if not self.block_manager.can_swap_out(seq_group):
            # FIXME(woosuk): Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error.")
        blocks_to_swap_out.update(self.block_manager.swap_out(seq_group))
        for running_seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            running_seq.status = SequenceStatus.SWAPPED