test_scheduler.py 39.6 KB
Newer Older
1
import time
2
from collections import deque
3
from typing import List, Set, Tuple
4
from unittest.mock import MagicMock
5

6
import pytest  # noqa
7
from torch import Use  # noqa
8

9
10
11
12
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
13
from vllm.sequence import SequenceGroup, SequenceStatus
14

15
16
17
from .utils import (append_new_token, append_new_token_seq_group,
                    create_dummy_prompt, get_sequence_groups,
                    schedule_and_update_computed_tokens)
18
19


20
def test_scheduler_add_seq_group():
    """Each added sequence group raises the unfinished-group count by one."""
    block_size = 4
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=1,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, cache_dtype="auto")
    cache_cfg.num_cpu_blocks = 4
    cache_cfg.num_gpu_blocks = 4
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Queue the groups one at a time and check the count after every insert.
    total_groups = 4
    for expected_count in range(1, total_groups + 1):
        request_id = str(expected_count - 1)
        _, group = create_dummy_prompt(request_id,
                                       block_size,
                                       block_size=block_size)
        scheduler.add_seq_group(group)
        assert scheduler.get_num_unfinished_seq_groups() == expected_count


43
def test_scheduler_abort_seq_group():
    """Aborting every queued request id leaves no unfinished groups."""
    block_size = 4
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=1,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 4
    cache_cfg.num_gpu_blocks = 4
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Queue several groups, remembering the request id of each one.
    total_groups = 4
    for idx in range(total_groups):
        scheduler.add_seq_group(create_dummy_prompt(str(idx), block_size)[1])
    request_ids: Set[str] = {str(idx) for idx in range(total_groups)}

    # Everything is pending before the abort, nothing afterwards.
    assert scheduler.get_num_unfinished_seq_groups() == total_groups
    scheduler.abort_seq_group(request_ids)
    assert scheduler.get_num_unfinished_seq_groups() == 0


70
def test_scheduler_schedule_simple():
    """Schedule a prompt step and then a generation step for a full batch."""
    block_size = 4
    num_seq_group = 4
    max_model_len = 16
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=num_seq_group,
        max_model_len=max_model_len,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 8
    cache_cfg.num_gpu_blocks = 8
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Queue one block-sized prompt per sequence group.
    expected_groups: List[SequenceGroup] = []
    for idx in range(num_seq_group):
        group = create_dummy_prompt(str(idx),
                                    prompt_length=block_size,
                                    block_size=block_size)[1]
        scheduler.add_seq_group(group)
        expected_groups.append(group)

    # Step 1 schedules the prompts (block_size tokens per group); step 2
    # schedules generation (one token per group). Neither step may copy or
    # swap any blocks.
    for expected_tokens in (block_size * num_seq_group, num_seq_group):
        metadata, out = schedule_and_update_computed_tokens(scheduler)
        assert set(get_sequence_groups(out)) == set(expected_groups)
        assert out.num_batched_tokens == expected_tokens
        assert not (out.blocks_to_copy or out.blocks_to_swap_in
                    or out.blocks_to_swap_out)
        assert len(metadata) == num_seq_group
        append_new_token(out, 1)


114
def test_scheduler_prefill_prioritized():
    """Verify running batched tokens are not applied to prefill requests."""
    block_size = 4
    max_model_len = 30
    max_batched_num_tokens = 30
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=max_batched_num_tokens,
        max_num_seqs=2,
        max_model_len=max_model_len,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 16
    cache_cfg.num_gpu_blocks = 16
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Request A: a one-token prompt, scheduled on its own first.
    group_a = create_dummy_prompt("1", 1, block_size=block_size)[1]
    scheduler.add_seq_group(group_a)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [group_a]

    # Request B: a 30-token prompt, i.e. exactly max_batched_num_tokens.
    group_b = create_dummy_prompt("2", 30, block_size=block_size)[1]
    scheduler.add_seq_group(group_b)

    # Prefill is prioritized: the next step schedules only B, even though A
    # is already running.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [group_b]
146
147


148
def test_scheduler_schedule_preempt_abort():
    """Preempt one of two groups, abort the survivor, reschedule the other."""
    block_size = 4
    max_model_len = 16
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=2,
        max_model_len=max_model_len,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 2
    cache_cfg.num_gpu_blocks = 2
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    def no_block_movement(step_out):
        # True when the step neither copies nor swaps any blocks.
        return not (step_out.blocks_to_copy or step_out.blocks_to_swap_in
                    or step_out.blocks_to_swap_out)

    # Two block-sized prompts; only two GPU blocks are configured.
    _, group_a = create_dummy_prompt("1", block_size, block_size=block_size)
    _, group_b = create_dummy_prompt("2", block_size, block_size=block_size)
    scheduler.add_seq_group(group_a)
    scheduler.add_seq_group(group_b)

    # Prompt step: both groups are scheduled together.
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [group_a, group_b]
    assert out.num_batched_tokens == block_size * 2  # group_a and group_b
    assert no_block_movement(out)
    assert len(metadata) == 2
    assert scheduler.get_num_unfinished_seq_groups() == 2

    # Appending a "generated" token lets each sequence mark its prompt tokens
    # as processed.
    append_new_token(out, 1)

    # Generation step: group_b is preempted, only group_a keeps running.
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [group_a]
    assert out.num_batched_tokens == 1
    assert no_block_movement(out)
    assert len(metadata) == 1
    assert scheduler.get_num_unfinished_seq_groups() == 2
    assert out.preempted == 1

    # Abort group_a; group_b's prompt is then rescheduled with recomputation.
    scheduler.abort_seq_group("1")
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [group_b]
    assert out.num_batched_tokens == 5  # 4 prompt + 1 generation.
    assert no_block_movement(out)
    assert len(metadata) == 1
    assert scheduler.get_num_unfinished_seq_groups() == 1


206
def test_scheduler_max_seqs():
    """With max_num_seqs=2, only one extra group joins a running group."""
    block_size = 4
    num_seq_group = 4
    max_seq_group = 2
    max_model_len = 16
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=max_seq_group,
        max_model_len=max_model_len,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 8
    cache_cfg.num_gpu_blocks = 8
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Build all groups up front; they are queued selectively below.
    groups: List[SequenceGroup] = [
        create_dummy_prompt(str(idx),
                            prompt_length=block_size,
                            block_size=block_size)[1]
        for idx in range(num_seq_group)
    ]
    first, second, third = groups[:3]

    # Queue a single group and run its prompt step, then a generation step.
    scheduler.add_seq_group(first)
    for _ in range(2):
        _, out = schedule_and_update_computed_tokens(scheduler)
        assert set(get_sequence_groups(out)) == {first}
        append_new_token(out, 1)

    # Queue two more groups.
    scheduler.add_seq_group(second)
    scheduler.add_seq_group(third)

    # Only one of them may be scheduled: max_seq_group is 2 and one group is
    # already running.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {second}
252
253


254
def test_scheduler_delay_factor():
    """With delay_factor=0.5, a newly queued prompt is not prefilled at once."""
    block_size = 4
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=16,
        delay_factor=0.5,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 8
    cache_cfg.num_gpu_blocks = 8
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # The first prompt is prefilled right away.
    _, first_group = create_dummy_prompt("0",
                                         prompt_length=block_size,
                                         block_size=block_size)
    scheduler.add_seq_group(first_group)
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups > 0
    assert metadata[0].request_id == '0'
    append_new_token(out, 1)

    # Wait a second, then queue the next prompt.
    time.sleep(1)
    _, second_group = create_dummy_prompt("1",
                                          prompt_length=block_size,
                                          block_size=block_size)
    scheduler.add_seq_group(second_group)

    # The second prompt must *not* be prefilled yet; request '0' keeps going.
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups == 0
    assert metadata[0].request_id == '0'
    append_new_token(out, 1)

    # After waiting more than 0.5 seconds the second prompt is prefilled.
    time.sleep(0.6)
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups > 0
    assert metadata[0].request_id == '1'
    append_new_token(out, 1)
297
298


299
def test_swapped_out_prioritized():
    """Verify a swapped-out group is swapped back in before a new prefill.

    Three best_of=2 groups fill the max_num_seqs=6 limit. The block manager
    is mocked so that request "2" cannot append slots, which swaps it out.
    A fourth request is then queued, and the next step must swap request "2"
    back in rather than prefill the newcomer.
    """
    block_size = 4
    scheduler = initialize_scheduler(max_num_seqs=6,
                                     block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    # best_of=2 * 3 == 6 sequences.
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           best_of=2,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    # prefill scheduled now.
    assert len(out.scheduled_seq_groups) == 3
    append_new_token(out, 1)

    # The last request should be swapped out.
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_last_group(seq_group, num_lookahead_slots):
        # Only request "2" (the last one added) is refused new slots.
        return seq_group.request_id != "2"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_last_group)

    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 2
    assert out.num_batched_tokens == 2
    assert out.blocks_to_swap_out != []
    assert out.blocks_to_swap_in == []
    append_new_token(out, 1)

    # Add 1 more task. Swap should be prioritized over prefill.
    # Give it its own request id: the original reused the leaked loop
    # variable, which duplicated the id of the swapped-out request "2".
    _, seq_group = create_dummy_prompt("3",
                                       prompt_length=60,
                                       best_of=2,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    append_new_token(out, 1)
    assert len(out.scheduled_seq_groups) == 3
    # 3 decodes. It is swapped in.
    assert out.num_batched_tokens == 3
    assert out.blocks_to_swap_in != []
    assert out.blocks_to_swap_out == []
346
347


348
349
350
351
352
353
354
355
356
357
358
359
def initialize_scheduler(
    *,
    max_num_seqs=1000,
    max_token_budget=1000,
    max_model_len=1000,
    lora_config=None,
    block_size=4,
    num_cpu_blocks=8,
    num_gpu_blocks=8,
):
    """Build a ``Scheduler`` for tests from keyword-only settings.

    Args:
        max_num_seqs: Forwarded to ``SchedulerConfig(max_num_seqs=...)``.
        max_token_budget: Forwarded to ``SchedulerConfig`` as
            ``max_num_batched_tokens``.
        max_model_len: Forwarded to ``SchedulerConfig(max_model_len=...)``.
        lora_config: Passed to ``Scheduler`` as its third argument.
        block_size: First positional argument of ``CacheConfig``.
        num_cpu_blocks: Assigned to ``cache_config.num_cpu_blocks``.
        num_gpu_blocks: Assigned to ``cache_config.num_gpu_blocks``.

    Returns:
        The newly constructed ``Scheduler``.
    """
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=max_token_budget,
        max_num_seqs=max_num_seqs,
        max_model_len=max_model_len,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    # The block counts are set on the config after construction.
    cache_config.num_cpu_blocks = num_cpu_blocks
    cache_config.num_gpu_blocks = num_gpu_blocks
    return Scheduler(scheduler_config, cache_config, lora_config)


372
def create_token_budget(token_budget: int = 10000,
                        max_num_seqs: int = 10000) -> SchedulingBudget:
    """Return a fresh ``SchedulingBudget`` with the given limits."""
    limits = {"token_budget": token_budget, "max_num_seqs": max_num_seqs}
    return SchedulingBudget(**limits)


380
381
382
383
384
385
386
387
388
def add_token_budget(budget: SchedulingBudget,
                     num_batched_tokens: int = 0,
                     num_curr_seqs: int = 0):
    """Charge ``budget`` with tokens and sequences under a dummy request.

    A throwaway prompt with request id '10' supplies the id the usage is
    booked against.
    """
    _, placeholder_group = create_dummy_prompt('10', prompt_length=60)
    budget.add_num_batched_tokens(placeholder_group.request_id,
                                  num_batched_tokens)
    budget.add_num_seqs(placeholder_group.request_id, num_curr_seqs)


389
def test_prefill_schedule_max_prompt_len():
    """
    A 60-token prompt is ignored when max_model_len is 30.
    """
    block_size = 4
    scheduler = initialize_scheduler(max_model_len=30, block_size=block_size)
    oversized_group = create_dummy_prompt("0",
                                          prompt_length=60,
                                          block_size=block_size)[1]
    scheduler.add_seq_group(oversized_group)

    budget = create_token_budget()
    output = scheduler._schedule_prefills(budget, None)

    # The request is reported as ignored, consumes no budget, and leaves the
    # waiting queue.
    assert len(output.ignored_seq_groups) == 1
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(scheduler.waiting) == 0


409
def test_prefill_schedule_token_budget():
    """
    Prefill scheduling never exceeds the token budget.
    """
    block_size = 4
    num_requests = 2
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(token_budget=0)
    for idx in range(num_requests):
        scheduler.add_seq_group(
            create_dummy_prompt(str(idx),
                                prompt_length=60,
                                block_size=block_size)[1])

    # A zero-token budget schedules nothing.
    output = scheduler._schedule_prefills(budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(scheduler.waiting) == 2

    # A 60-token budget fits exactly one 60-token prompt.
    budget = create_token_budget(token_budget=60)
    output = scheduler._schedule_prefills(budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 60
    assert budget.num_curr_seqs == 1
    assert len(scheduler.waiting) == 1

    # Tokens already charged to the budget count against it: with 30 of 60
    # tokens used, a 60-token prompt cannot be scheduled.
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=16,
                                     num_gpu_blocks=16)
    budget = create_token_budget(token_budget=60)
    add_token_budget(budget, 30, 0)
    # Same id as the last request queued above ("1").
    pending_group = create_dummy_prompt(str(num_requests - 1),
                                        prompt_length=60,
                                        block_size=block_size)[1]
    scheduler.add_seq_group(pending_group)
    output = scheduler._schedule_prefills(budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 30
    assert budget.num_curr_seqs == 0
    assert len(scheduler.waiting) == 1

    # With a 90-token budget the same prompt fits on top of the 30 used.
    budget = create_token_budget(token_budget=90)
    add_token_budget(budget, 30, 0)
    output = scheduler._schedule_prefills(budget, None)
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 90
    assert budget.num_curr_seqs == 1
    assert len(scheduler.waiting) == 0


471
def test_prefill_schedule_max_seqs():
    """
    Prefill scheduling never exceeds the budget's max_num_seqs.
    """
    block_size = 4
    num_requests = 3
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(max_num_seqs=2)
    for idx in range(num_requests):
        scheduler.add_seq_group(
            create_dummy_prompt(str(idx),
                                prompt_length=60,
                                block_size=block_size)[1])

    # Only two of the three 60-token prompts fit under max_num_seqs=2.
    output = scheduler._schedule_prefills(budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(scheduler.waiting) == 1

    # Sequences already charged to the budget count too: with both slots
    # taken, a newly queued prompt stays in the waiting queue.
    scheduler.waiting = deque()
    budget = create_token_budget(max_num_seqs=2)
    add_token_budget(budget, 0, 2)
    # Same id as the last request queued above ("2").
    pending_group = create_dummy_prompt(str(num_requests - 1),
                                        prompt_length=60,
                                        block_size=block_size)[1]
    scheduler.add_seq_group(pending_group)
    output = scheduler._schedule_prefills(budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 2
    assert len(scheduler.waiting) == 1


510
def test_prefill_schedule_max_lora():
    """
    With max_loras=1, a second LoRA request waits and is scheduled next.
    """
    block_size = 4
    lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
    scheduler = initialize_scheduler(lora_config=lora_config,
                                     block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(token_budget=120)
    active_loras: Set[int] = set()

    # Requests 0 and 1 each carry their own LoRA adapter.
    for idx in range(2):
        adapter = LoRARequest(lora_name=str(idx),
                              lora_int_id=idx + 1,
                              lora_path="abc")
        _, lora_group = create_dummy_prompt(str(idx),
                                            prompt_length=60,
                                            block_size=block_size,
                                            lora_request=adapter)
        scheduler.add_seq_group(lora_group)

    # Requests 2 and 3 are regular (no LoRA), so the queue is:
    #   0: LoRA, 1: LoRA, 2: regular, 3: regular
    # The first pass should pick 0 and 2. Request 1 is held back by the LoRA
    # limit and must be the one scheduled on the next pass.
    for idx in range(2, 4):
        _, plain_group = create_dummy_prompt(str(idx),
                                             prompt_length=60,
                                             block_size=block_size)
        scheduler.add_seq_group(plain_group)

    # First pass: two requests (0 and 2) use the whole 120-token budget.
    output = scheduler._schedule_prefills(budget, active_loras)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(scheduler.waiting) == 2
    assert len(active_loras) == 1

    # Second pass: clear the active LoRA set so request 1, next in FCFS
    # order, can be scheduled.
    active_loras = set()
    budget = create_token_budget(token_budget=60)
    output = scheduler._schedule_prefills(budget, active_loras)
    assert len(output.seq_groups) == 1
    assert output.seq_groups[0].seq_group.request_id == "1"
    assert len(scheduler.waiting) == 1
    assert len(active_loras) == 1
    assert budget.num_batched_tokens == 60


563
def test_prefill_schedule_no_block_manager_capacity():
    """
    Prefills are deferred on AllocStatus.LATER and ignored on NEVER.
    """
    block_size = 4

    def queue_three_prompts(target):
        # Queue three 60-token prompts with request ids "0".."2".
        for idx in range(3):
            _, group = create_dummy_prompt(str(idx),
                                           prompt_length=60,
                                           block_size=block_size)
            target.add_seq_group(group)

    # LATER: nothing is scheduled or ignored; every request keeps waiting.
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_gpu_blocks=128,
                                     num_cpu_blocks=128)
    budget = create_token_budget()
    queue_three_prompts(scheduler)
    scheduler.block_manager.can_allocate = MagicMock(
        return_value=AllocStatus.LATER)
    output = scheduler._schedule_prefills(budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(scheduler.waiting) == 3

    # NEVER: every request is ignored and removed from the waiting queue.
    scheduler = initialize_scheduler()
    budget = create_token_budget()
    queue_three_prompts(scheduler)
    scheduler.block_manager.can_allocate = MagicMock(
        return_value=AllocStatus.NEVER)
    output = scheduler._schedule_prefills(budget, None)
    assert len(output.ignored_seq_groups) == 3
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(scheduler.waiting) == 0


605
def test_decode_schedule_preempted():
    """
    Decodes that cannot append slots are preempted, not swapped.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    curr_loras = None

    # Put three groups straight into the running state, one token past their
    # 60-token prompt.
    for idx in range(3):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=block_size)
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._add_seq_group_to_running(group)

    def can_append_unless_request_1(seq_group, num_lookahead_slots):
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots = MagicMock(
        side_effect=can_append_unless_request_1)

    # Request 1 cannot be scheduled, so the lowest-priority request (2) is
    # preempted, and request 1 ends up preempted as well.
    budget = create_token_budget()
    output = scheduler._schedule_running(budget, curr_loras)
    remaining_running = scheduler.running
    assert len(remaining_running) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert output.decode_seq_groups[0].seq_group.request_id == "0"
    assert len(output.preempted) == 2
    # The budget reflects only the single scheduled decode.
    assert budget.num_batched_tokens == 1
    # NOTE: When enable_chunk is False, num_seqs budget is not updated, so
    # budget.num_curr_seqs is deliberately not checked here.
    # Both victims were preempted, so no blocks are swapped out...
    assert output.blocks_to_swap_out == []
    # ...and nothing is copied.
    assert output.blocks_to_copy == []
647
648


649
def test_decode_swap_beam_search():
    """
    A best_of > 1 group that cannot append slots is swapped out.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_gpu_blocks=64,
                                     num_cpu_blocks=64)
    curr_loras = None
    budget = create_token_budget()

    # Put three best_of=2 groups into the running state and charge each one
    # to the budget.
    for idx in range(3):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       best_of=2,
                                       block_size=block_size)
        scheduler._allocate_and_set_running(group)
        scheduler._add_seq_group_to_running(group)
        append_new_token_seq_group(60, group, 1)
        budget.add_num_seqs(group.request_id,
                            group.get_max_num_running_seqs())
        budget.add_num_batched_tokens(
            group.request_id, group.num_seqs(SequenceStatus.RUNNING))

    def can_append_unless_request_2(seq_group, num_lookahead_slots):
        return seq_group.request_id != "2"

    # The last request (id "2") is refused new slots, and swap_out is mocked
    # to return a known block mapping.
    scheduler.block_manager.can_append_slots = MagicMock(
        side_effect=can_append_unless_request_2)
    expected_swap_mapping = [("5", "7")]
    scheduler.block_manager.swap_out = MagicMock(
        return_value=expected_swap_mapping)

    output = scheduler._schedule_running(budget, curr_loras)
    remaining_running = scheduler.running
    assert len(remaining_running) == 0
    assert len(output.decode_seq_groups) == 2
    assert len(output.prefill_seq_groups) == 0
    assert output.decode_seq_groups[0].seq_group.request_id == "0"
    assert output.decode_seq_groups[1].seq_group.request_id == "1"
    assert len(output.preempted) == 0
    assert len(output.swapped_out) == 1
    # The budget no longer includes the swapped-out request.
    assert budget.num_batched_tokens == 2
    # The swapped-out group had 2 sequences, so 2 are subtracted.
    assert budget.num_curr_seqs == 4
    # The swap mapping returned by the block manager is passed through.
    assert output.blocks_to_swap_out == expected_swap_mapping
    # Nothing is copied.
    assert output.blocks_to_copy == []
701
702


703
def test_schedule_decode_blocks_to_copy_update():
    """
    Verify blocks_to_copy is updated.

    ``append_slots`` on the block manager is mocked to return one
    (source, destination) pair; scheduling the running decode must pass that
    pair through in ``output.blocks_to_copy``.
    """
    block_size = 4
    # Use the local block_size for the scheduler as well as the prompt, so
    # the two cannot drift apart.
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=16,
                                     num_gpu_blocks=16)
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       best_of=2,
                                       block_size=block_size)
    curr_loras = None
    scheduler._allocate_and_set_running(seq_group)
    append_new_token_seq_group(60, seq_group, 1)
    scheduler._add_seq_group_to_running(seq_group)

    # Make append_slots report that block 2 must be copied to block 3.
    scheduler.block_manager.append_slots = MagicMock()
    scheduler.block_manager.append_slots.return_value = [(2, 3)]

    budget = create_token_budget()
    output = scheduler._schedule_running(budget, curr_loras)
    remaining_running = scheduler.running
    assert len(remaining_running) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    # Nothing is preempted or swapped out.
    assert len(output.preempted) == 0
    assert len(output.swapped_out) == 0
    assert output.blocks_to_swap_out == []
    # Since append_slots returns the source -> destination mapping, it
    # should be applied.
    assert output.blocks_to_copy == [(2, 3)]
737
738


739
def test_schedule_swapped_simple():
    """Swap a single sequence group out, then schedule it back in.

    Checks that the swapped queue is drained, that the budget records the
    scheduled tokens and sequences, and that the swap-in mapping is the
    swap-out mapping with each pair reversed.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size)
    curr_loras = None
    blocks_to_swap_out: List[Tuple[int, int]] = []
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=4,
                                       best_of=2,
                                       block_size=block_size)
    scheduler._allocate_and_set_running(seq_group)
    append_new_token_seq_group(4, seq_group, 1)
    scheduler._swap_out(seq_group, blocks_to_swap_out)
    scheduler._add_seq_group_to_swapped(seq_group)

    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 0
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 2
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    # Swap in is the reverse of swap out: flip every (src, dst) pair of the
    # swap-in mapping and compare it with what was swapped out.
    blocks_to_swap_in_reverse = [(dst, src)
                                 for src, dst in output.blocks_to_swap_in]
    assert blocks_to_swap_out == blocks_to_swap_in_reverse


768
def test_schedule_swapped_max_token_budget():
    """Verify _schedule_swapped honours the token budget.

    Two sequence groups are swapped out. With a token budget of 1 only one
    of them is scheduled back in; with a budget that is already used up,
    none is.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None
    blocks_to_swap_out: List[Tuple[int, int]] = []
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(seq_group)

    # A budget of one token admits exactly one of the two swapped groups.
    budget = create_token_budget(token_budget=1)
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 2
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0

    # Verify num_batched_tokens are respected: the fresh budget is filled
    # up front, so the remaining group cannot be scheduled.
    budget = create_token_budget(token_budget=1)
    add_token_budget(budget, 1, 0)
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 0
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0
801
802


803
def test_schedule_swapped_max_seqs():
    """Verify _schedule_swapped honours the budget's max_num_seqs.

    Four sequence groups are swapped out. With max_num_seqs=2 only two are
    scheduled back in, and a second call with the same (now full) budget
    schedules nothing more.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    curr_loras = None
    blocks_to_swap_out: List[Tuple[int, int]] = []
    for i in range(4):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(seq_group)

    budget = create_token_budget(max_num_seqs=2)
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 2
    assert budget.num_batched_tokens == 2
    assert budget.num_curr_seqs == 2
    assert len(output.decode_seq_groups) == 2
    assert len(output.prefill_seq_groups) == 0

    # Verify num_curr_seqs are respected: the budget is reused and already
    # holds two sequences, so nothing else may be swapped in.
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 2
    assert budget.num_batched_tokens == 2
    assert budget.num_curr_seqs == 2
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0
836
837


838
def test_schedule_swapped_max_loras():
    """Verify _schedule_swapped honours the LoRA limit.

    The scheduler is configured with max_loras=1 and two swapped-out
    sequence groups that use different LoRA ids; only one of them is
    scheduled back in, and curr_loras ends up holding a single entry.
    """
    block_size = 4
    lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
    scheduler = initialize_scheduler(lora_config=lora_config,
                                     block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras: Set[int] = set()
    blocks_to_swap_out: List[Tuple[int, int]] = []
    for i in range(2):
        # Each group gets its own LoRA id (i + 1).
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size,
                                           lora_request=LoRARequest(
                                               lora_name=str(i),
                                               lora_int_id=i + 1,
                                               lora_path="abc"))
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(seq_group)

    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 1
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert len(curr_loras) == 1


871
def test_schedule_swapped_cannot_swap_in():
    """Verify nothing is swapped in when can_swap_in reports LATER.

    Two sequence groups are swapped out and the block manager's
    can_swap_in is stubbed to return AllocStatus.LATER; both groups must
    stay in the swapped queue and the budget must remain untouched.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None
    blocks_to_swap_out: List[Tuple[int, int]] = []
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           best_of=2,
                                           block_size=block_size)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(seq_group)

    # Stub can_swap_in so every swap-in attempt is answered with LATER.
    scheduler.block_manager.can_swap_in = MagicMock()
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
    # Since we cannot swap in, none of the requests are swapped in.
    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 2
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0


902
def test_infeasible_swap():
    """Verify groups are reported infeasible when can_swap_in is NEVER.

    Two sequence groups are swapped out and the block manager's
    can_swap_in is stubbed to return AllocStatus.NEVER; both groups must
    leave the swapped queue and show up in infeasible_seq_groups, with
    nothing scheduled and the budget untouched.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None
    blocks_to_swap_out: List[Tuple[int, int]] = []
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           best_of=2,
                                           block_size=block_size)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(seq_group)

    # Stub can_swap_in so every swap-in attempt is answered with NEVER.
    scheduler.block_manager.can_swap_in = MagicMock()
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER
    # Since swap-in can never succeed, both requests are removed from the
    # swapped queue and reported as infeasible instead of being scheduled.
    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 0
    assert len(output.infeasible_seq_groups) == 2
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0
932
933


934
def test_schedule_swapped_blocks_to_copy():
    """Verify _schedule_swapped reports the copy mapping from append_slots.

    One sequence group is swapped out, append_slots is stubbed to return
    a copy mapping, and the group is scheduled back in; the stubbed
    mapping must appear in the output's blocks_to_copy.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       best_of=2,
                                       block_size=block_size)
    scheduler._allocate_and_set_running(seq_group)
    append_new_token_seq_group(60, seq_group, 1)
    blocks_to_swap_out: List[Tuple[int, int]] = []
    scheduler._swap_out(seq_group, blocks_to_swap_out)
    scheduler._add_seq_group_to_swapped(seq_group)

    # Stub append_slots so that it reports a single (2, 3) copy mapping.
    scheduler.block_manager.append_slots = MagicMock()
    scheduler.block_manager.append_slots.return_value = [(2, 3)]

    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert output.blocks_to_copy == [(2, 3)]
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004


def test_scheduling_budget():
    """Exercise SchedulingBudget capacity checks and bookkeeping.

    Covers can_schedule against the token and sequence limits, and checks
    that adding or subtracting tokens/sequences twice for the same request
    id only takes effect once.
    """
    token_limit = 4
    seq_limit = 4
    budget = SchedulingBudget(token_budget=token_limit,
                              max_num_seqs=seq_limit)

    # (num_new_tokens, num_new_seqs, expected can_schedule result)
    capacity_cases = (
        (1, 1, True),
        (4, 4, True),
        (1, 5, False),
        (5, 1, False),
        (5, 5, False),
    )
    for new_tokens, new_seqs, fits in capacity_cases:
        schedulable = budget.can_schedule(num_new_tokens=new_tokens,
                                          num_new_seqs=new_seqs)
        assert bool(schedulable) == fits
    assert budget.remaining_token_budget() == token_limit

    # Verify add/subtract num batched tokens.
    _, seq_group = create_dummy_prompt("1", 3)
    request_id = seq_group.request_id
    budget.add_num_batched_tokens(request_id, 2)
    assert budget.remaining_token_budget() == 2
    assert budget.num_batched_tokens == 2
    assert budget.can_schedule(num_new_tokens=2, num_new_seqs=1)
    assert not budget.can_schedule(num_new_tokens=3, num_new_seqs=1)
    # Adding tokens again for the same request id is a no-op.
    budget.add_num_batched_tokens(request_id, 2)
    assert budget.remaining_token_budget() == 2
    assert budget.num_batched_tokens == 2
    # The first subtraction restores the budget; the repeat changes nothing.
    for _ in range(2):
        budget.subtract_num_batched_tokens(request_id, 2)
        assert budget.remaining_token_budget() == 4
        assert budget.num_batched_tokens == 0

    # Verify add/subtract max seqs.
    _, seq_group = create_dummy_prompt("1", 3)
    request_id = seq_group.request_id
    budget.add_num_seqs(request_id, 2)
    assert budget.can_schedule(num_new_tokens=1, num_new_seqs=2)
    assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=3)
    assert budget.num_curr_seqs == 2
    # Adding sequences again for the same request id is a no-op.
    budget.add_num_seqs(request_id, 2)
    assert budget.num_curr_seqs == 2
    # The first subtraction clears the count; the repeat changes nothing.
    for _ in range(2):
        budget.subtract_num_seqs(request_id, 2)
        assert budget.num_curr_seqs == 0