test_scheduler.py 37.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import time
4
5
from collections import deque
from unittest.mock import MagicMock
6

7
import pytest  # noqa
8
from torch import Use  # noqa
9

10
11
12
13
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
14
from vllm.sequence import SequenceGroup
15

16
17
18
from .utils import (append_new_token, append_new_token_seq,
                    append_new_token_seq_group, create_dummy_prompt,
                    get_sequence_groups, schedule_and_update_computed_tokens)
19
20


21
def test_scheduler_add_seq_group():
    """Each added sequence group bumps the unfinished-group count by one."""
    block_size = 4
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=1,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, cache_dtype="auto")
    cache_cfg.num_gpu_blocks = 4
    cache_cfg.num_cpu_blocks = 4
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Register four groups ("0".."3") and check the tally after every insert.
    for expected_count, request_id in enumerate(map(str, range(4)), start=1):
        _, group = create_dummy_prompt(request_id,
                                       block_size,
                                       block_size=block_size)
        scheduler.add_seq_group(group)
        assert scheduler.get_num_unfinished_seq_groups() == expected_count


44
def test_scheduler_abort_seq_group():
    """Aborting every added request id leaves no unfinished groups."""
    block_size = 4
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=1,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 4
    cache_config.num_gpu_blocks = 4
    scheduler = Scheduler(scheduler_config, cache_config, None)

    num_seq_group = 4
    ids_in_order = [str(n) for n in range(num_seq_group)]
    for request_id in ids_in_order:
        _, group = create_dummy_prompt(request_id, block_size)
        scheduler.add_seq_group(group)

    # Abort all of them in a single call, passing the ids as a set.
    assert scheduler.get_num_unfinished_seq_groups() == num_seq_group
    request_ids: set[str] = set(ids_in_order)
    scheduler.abort_seq_group(request_ids)
    assert scheduler.get_num_unfinished_seq_groups() == 0


71
def test_scheduler_schedule_simple():
    """Run one prefill step and one decode step over four small groups."""
    block_size = 4
    num_seq_group = 4
    max_model_len = 16
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=num_seq_group,
        max_model_len=max_model_len,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    running: list[SequenceGroup] = []
    for idx in range(num_seq_group):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=block_size,
                                       block_size=block_size)
        scheduler.add_seq_group(group)
        running.append(group)

    # First iteration: prefill, every group contributes its whole prompt.
    # Second iteration: decode, every group contributes a single token.
    for expected_tokens in (block_size * num_seq_group, num_seq_group):
        seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
        assert set(get_sequence_groups(out)) == set(running)
        assert out.num_batched_tokens == expected_tokens
        assert (not out.blocks_to_copy and not out.blocks_to_swap_in
                and not out.blocks_to_swap_out)
        assert len(seq_group_meta) == num_seq_group
        append_new_token(out, 1)


115
def test_scheduler_prefill_prioritized():
    """Verify running batched tokens are not applied to prefill requests.

    Request A (1-token prompt) is scheduled first. Request B then arrives
    with a 30-token prompt, which by itself fills the 30-token batch budget.
    The next step must schedule B alone.
    """
    block_size = 4
    max_model_len = 30
    max_batched_num_tokens = 30
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=max_batched_num_tokens,
        max_num_seqs=2,
        max_model_len=max_model_len,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 16
    cache_config.num_gpu_blocks = 16
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # Add seq groups to scheduler.
    _, seq_group_a = create_dummy_prompt("1", 1, block_size=block_size)
    scheduler.add_seq_group(seq_group_a)

    # Schedule seq groups prompts.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [seq_group_a]

    # Add a new prefill request B.
    _, seq_group_b = create_dummy_prompt("2", 30, block_size=block_size)
    scheduler.add_seq_group(seq_group_b)

    # Verify prefill requests are prioritized. B's 30-token prompt uses the
    # entire max_batched_num_tokens budget (30), so B is scheduled on its
    # own, ahead of A.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [seq_group_b]
147
148


149
def test_scheduler_schedule_preempt_abort():
    """Preempt one group, abort the other, then recompute the preempted one.

    With only two GPU blocks, the decode step schedules group "1" alone and
    preempts group "2". After "1" is aborted, "2" is rescheduled and its
    prompt is recomputed together with its generated token.
    """
    block_size = 4
    max_model_len = 16
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=2,
        max_model_len=max_model_len,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 2
    cache_config.num_gpu_blocks = 2
    scheduler = Scheduler(scheduler_config, cache_config, None)

    _, seq_group_a = create_dummy_prompt("1",
                                         block_size,
                                         block_size=block_size)
    _, seq_group_b = create_dummy_prompt("2",
                                         block_size,
                                         block_size=block_size)
    for group in (seq_group_a, seq_group_b):
        scheduler.add_seq_group(group)

    def assert_no_block_moves(sched_out):
        # No copies, swap-ins or swap-outs are expected at any step.
        assert not (sched_out.blocks_to_copy or sched_out.blocks_to_swap_in
                    or sched_out.blocks_to_swap_out)

    # Step 1: both prompts are prefilled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [seq_group_a, seq_group_b]
    assert out.num_batched_tokens == block_size * 2  # seq_a and seq_b
    assert_no_block_moves(out)
    assert len(seq_group_meta) == 2
    assert scheduler.get_num_unfinished_seq_groups() == 2
    # A "generated" token lets each sequence mark its prompt as processed.
    append_new_token(out, 1)

    # Step 2: decode; only group a fits, group b is preempted.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [seq_group_a]
    assert out.num_batched_tokens == 1
    assert_no_block_moves(out)
    assert len(seq_group_meta) == 1
    assert scheduler.get_num_unfinished_seq_groups() == 2
    assert out.preempted == 1

    # Step 3: abort group a; group b is recomputed from its prompt.
    scheduler.abort_seq_group("1")
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [seq_group_b]
    assert out.num_batched_tokens == 5  # 4 prompt + 1 generation.
    assert_no_block_moves(out)
    assert len(seq_group_meta) == 1
    assert scheduler.get_num_unfinished_seq_groups() == 1


207
def test_scheduler_max_seqs():
    """Verify max_num_seqs caps how many groups are scheduled at once."""
    block_size = 4
    num_seq_group = 4
    max_seq_group = 2
    max_model_len = 16
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=max_seq_group,
        max_model_len=max_model_len,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # Create every group up front; they are added to the scheduler in stages.
    all_seq_groups: list[SequenceGroup] = []
    for i in range(num_seq_group):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=block_size,
                                           block_size=block_size)
        all_seq_groups.append(seq_group)

    # Append 1 seq group.
    scheduler.add_seq_group(all_seq_groups[0])

    # Schedule seq groups prompts.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {all_seq_groups[0]}
    append_new_token(out, 1)

    # Schedule seq groups generation.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {all_seq_groups[0]}
    append_new_token(out, 1)

    # Append 2 more seq groups.
    scheduler.add_seq_group(all_seq_groups[1])
    scheduler.add_seq_group(all_seq_groups[2])

    # Schedule seq groups prompts. Only 1 new group can be scheduled:
    # max_seq_group is 2 and group 0 is already running, so it holds one
    # of the two slots.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {all_seq_groups[1]}
253
254


255
def test_scheduler_delay_factor():
    """Verify delay_factor holds back a newly arrived prefill.

    With ``delay_factor=0.5``, a prompt added one second after the first
    prompt was scheduled is not prefilled on the next step. After a further
    0.6s wait it is. NOTE: this test sleeps for ~1.6s of wall-clock time.
    """
    block_size = 4
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=16,
        delay_factor=0.5,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # Schedule the first prompt. The returned Sequence is not needed.
    _, seq_group = create_dummy_prompt("0",
                                       prompt_length=block_size,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups > 0
    assert seq_group_meta[0].request_id == '0'
    append_new_token(out, 1)

    # Wait for a second before adding the next prompt.
    time.sleep(1)
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=block_size,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)

    # The second prompt should *not* be scheduled yet; only request 0 runs.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups == 0
    assert seq_group_meta[0].request_id == '0'
    append_new_token(out, 1)

    # Wait for more than 0.5 second and try again.
    time.sleep(0.6)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups > 0
    assert seq_group_meta[0].request_id == '1'
    append_new_token(out, 1)
298
299


300
301
302
303
304
305
306
307
308
def initialize_scheduler(
    *,
    max_num_seqs=1000,
    max_token_budget=1000,
    max_model_len=1000,
    lora_config=None,
    block_size=4,
    num_cpu_blocks=8,
    num_gpu_blocks=8,
    enable_prefix_caching=False,
    enable_chunked_prefill=False,
):
    """Build a ``Scheduler`` for the unit tests below.

    Args:
        max_num_seqs: Forwarded to ``SchedulerConfig(max_num_seqs=...)``.
        max_token_budget: Forwarded as ``max_num_batched_tokens``.
        max_model_len: Forwarded to ``SchedulerConfig(max_model_len=...)``.
        lora_config: Optional LoRA config passed straight to ``Scheduler``.
        block_size: Block size passed as the first ``CacheConfig`` argument.
        num_cpu_blocks: Assigned to ``cache_config.num_cpu_blocks``.
        num_gpu_blocks: Assigned to ``cache_config.num_gpu_blocks``.
        enable_prefix_caching: Forwarded to ``CacheConfig``.
        enable_chunked_prefill: Forwarded to ``SchedulerConfig``.

    Returns:
        The constructed ``Scheduler``.
    """
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=max_token_budget,
        max_num_seqs=max_num_seqs,
        max_model_len=max_model_len,
        enable_chunked_prefill=enable_chunked_prefill,
    )
    cache_config = CacheConfig(
        block_size,
        1.0,
        1,
        "auto",
        enable_prefix_caching=enable_prefix_caching,
    )
    # Block counts are set after construction, as in the tests above.
    cache_config.num_cpu_blocks = num_cpu_blocks
    cache_config.num_gpu_blocks = num_gpu_blocks
    return Scheduler(scheduler_config, cache_config, lora_config)


333
def create_token_budget(token_budget: int = 10000,
                        max_num_seqs: int = 10000) -> SchedulingBudget:
    """Return a fresh ``SchedulingBudget`` with the given limits."""
    fresh_budget = SchedulingBudget(token_budget=token_budget,
                                    max_num_seqs=max_num_seqs)
    return fresh_budget


341
342
343
344
345
346
347
348
349
def add_token_budget(budget: SchedulingBudget,
                     num_batched_tokens: int = 0,
                     num_curr_seqs: int = 0):
    """Pre-charge ``budget`` as if another request were already scheduled.

    Both amounts are attributed to the placeholder request id ``"10"``.
    NOTE(review): the budget presumably tracks charges per request id, so
    tests should avoid scheduling a real request with this id. Confirm
    against ``SchedulingBudget``.
    """
    placeholder_request_id = '10'
    budget.add_num_batched_tokens(placeholder_request_id, num_batched_tokens)
    budget.add_num_seqs(placeholder_request_id, num_curr_seqs)


350
def test_prefill_schedule_max_prompt_len():
    """
    A prompt longer than max_model_len is ignored instead of scheduled.
    """
    block_size = 4
    scheduler = initialize_scheduler(max_model_len=30, block_size=block_size)
    _, too_long = create_dummy_prompt("0",
                                      prompt_length=60,
                                      block_size=block_size)
    scheduler.add_seq_group(too_long)
    budget = create_token_budget()

    output = scheduler._schedule_prefills(budget, None)
    waiting_after = scheduler.waiting

    # The 60-token prompt exceeds max_model_len=30: it lands in the ignored
    # list, consumes no budget and is removed from the waiting queue.
    assert len(output.ignored_seq_groups) == 1
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(waiting_after) == 0


370
def test_prefill_schedule_token_budget():
    """
    Test that the prefill token budget is respected, including tokens
    already charged to the budget before scheduling.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(token_budget=0)
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)

    # 0 token budget == nothing is scheduled.
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 2

    # 60 token budget == 1 request scheduled.
    budget = create_token_budget(token_budget=60)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 60
    assert budget.num_curr_seqs == 1
    assert len(remaining_waiting) == 1

    # Test that already-batched tokens (current_batched_tokens) count
    # against the budget.
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=16,
                                     num_gpu_blocks=16)
    budget = create_token_budget(token_budget=60)
    add_token_budget(budget, 30, 0)
    # Explicit id for the new request instead of reusing the loop variable
    # after its loop has ended.
    _, seq_group = create_dummy_prompt("2",
                                       prompt_length=60,
                                       block_size=block_size)
    # Cannot schedule a prompt that doesn't fit the budget.
    scheduler.add_seq_group(seq_group)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 30
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 1

    # 30 already-charged + 60 new tokens fits exactly in a 90-token budget.
    budget = create_token_budget(token_budget=90)
    add_token_budget(budget, 30, 0)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 90
    assert budget.num_curr_seqs == 1
    assert len(remaining_waiting) == 0


432
def test_prefill_schedule_max_seqs():
    """
    Test that the max-seqs limit is respected, including sequences already
    charged to the budget before scheduling.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(max_num_seqs=2)
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 1

    # Verify curr_num_seqs respected.
    scheduler.waiting = deque()
    budget = create_token_budget(max_num_seqs=2)
    add_token_budget(budget, 0, 2)
    # Explicit, distinct id for the new request instead of reusing the loop
    # variable after its loop has ended.
    _, seq_group = create_dummy_prompt("3",
                                       prompt_length=60,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 1


471
def test_prefill_schedule_max_lora():
    """
    With max_loras=1, only one LoRA request is prefilled per step, and the
    LoRA request that was held back is scheduled first on the next step.
    """
    block_size = 4
    lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
    scheduler = initialize_scheduler(lora_config=lora_config,
                                     block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(token_budget=120)
    curr_loras: set[int] = set()

    # Queue layout: 0: LoRA, 1: LoRA, 2: regular, 3: regular.
    for idx in range(4):
        extra_kwargs = {
            "lora_request":
            LoRARequest(lora_name=str(idx),
                        lora_int_id=idx + 1,
                        lora_path="abc")
        } if idx < 2 else {}
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=block_size,
                                       **extra_kwargs)
        scheduler.add_seq_group(group)

    # First pass: requests 0 and 2 are scheduled. Request 1 is blocked by
    # the LoRA limit and stays ahead of request 3.
    output = scheduler._schedule_prefills(budget, curr_loras)
    waiting_after = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(waiting_after) == 2
    assert len(curr_loras) == 1

    # Second pass with a cleared LoRA set: request 1 goes next (FCFS).
    curr_loras = set()
    budget = create_token_budget(token_budget=60)
    output = scheduler._schedule_prefills(budget, curr_loras)
    waiting_after = scheduler.waiting
    assert len(output.seq_groups) == 1
    assert output.seq_groups[0].seq_group.request_id == "1"
    assert len(waiting_after) == 1
    assert len(curr_loras) == 1
    assert budget.num_batched_tokens == 60


524
def test_prefill_schedule_no_block_manager_capacity():
    """
    Test that sequences are not scheduled when the block manager has no
    capacity: ``AllocStatus.LATER`` keeps them waiting, while
    ``AllocStatus.NEVER`` moves them to the ignored list.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_gpu_blocks=128,
                                     num_cpu_blocks=128)
    budget = create_token_budget()
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)
    scheduler.block_manager.can_allocate = MagicMock()
    scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 3

    # Pass block_size explicitly so the scheduler matches the prompts below
    # instead of relying on initialize_scheduler's default.
    scheduler = initialize_scheduler(block_size=block_size)
    budget = create_token_budget()
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)
    scheduler.block_manager.can_allocate = MagicMock()
    scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 3
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 0


566
def test_decode_schedule_preempted():
    """
    Test that running decodes which cannot append slots are preempted.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    curr_loras = None
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._add_seq_group_to_running(seq_group)
    scheduler.block_manager.can_append_slots = MagicMock()

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        # Only request "1" is refused new slots.
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group)

    # 1 cannot be scheduled, and the lowest priority (request 2)
    # should be preempted. 1 will also be preempted.
    budget = create_token_budget()
    output = scheduler._schedule_running(budget, curr_loras)
    remaining_running = scheduler.running
    assert len(remaining_running) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert output.decode_seq_groups[0].seq_group.request_id == "0"
    assert len(output.preempted) == 2
    # Verify budgets are updated.
    assert budget.num_batched_tokens == 1
    # NOTE: When chunked prefill is disabled, the num_seqs budget is not
    # updated here, so budget.num_curr_seqs is deliberately not asserted.
    # Both should be preempted, not swapped.
    assert output.blocks_to_swap_out == []
    # Nothing is copied.
    assert output.blocks_to_copy == []
608
609


610
def test_schedule_decode_blocks_to_copy_update():
    """
    Verify blocks_to_copy is updated from the mapping append_slots returns.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=16,
                                     num_gpu_blocks=16)
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       block_size=block_size)
    curr_loras = None
    scheduler._allocate_and_set_running(seq_group)
    append_new_token_seq_group(60, seq_group, 1)
    scheduler._add_seq_group_to_running(seq_group)

    # Make append_slots report a block copy mapping of (2, 3).
    scheduler.block_manager.append_slots = MagicMock()
    scheduler.block_manager.append_slots.return_value = [(2, 3)]

    budget = create_token_budget()
    output = scheduler._schedule_running(budget, curr_loras)
    remaining_running = scheduler.running
    assert len(remaining_running) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    # Nothing is preempted or swapped out.
    assert len(output.preempted) == 0
    assert len(output.swapped_out) == 0
    assert output.blocks_to_swap_out == []
    # Since append_slots returns the source -> dst mapping, it should be
    # applied.
    assert output.blocks_to_copy == [(2, 3)]
643
644


645
def test_schedule_swapped_max_loras():
    """With max_loras=1, only one of two swapped LoRA groups is swapped in."""
    block_size = 4
    lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
    scheduler = initialize_scheduler(lora_config=lora_config,
                                     block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras: set[int] = set()
    blocks_to_swap_out: list[tuple[int, int]] = []

    # Prefill two groups, each with its own LoRA, then swap both out.
    for idx in range(2):
        adapter = LoRARequest(lora_name=str(idx),
                              lora_int_id=idx + 1,
                              lora_path="abc")
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=block_size,
                                       lora_request=adapter)
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._swap_out(group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(group)

    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    still_swapped = scheduler.swapped

    assert len(still_swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 1
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert len(curr_loras) == 1


678
def test_schedule_swapped_cannot_swap_in():
    """Swapped groups stay swapped while swap-in is reported as LATER."""
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None
    blocks_to_swap_out: list[tuple[int, int]] = []
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(seq_group)

    # Make the block manager report that swap-in is not possible yet.
    scheduler.block_manager.can_swap_in = MagicMock()
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER
    # Since we cannot swap in, none of the requests are swapped in.
    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 2
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0


708
def test_infeasible_swap():
    """Swapped groups that can NEVER be swapped in are marked infeasible."""
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None
    blocks_to_swap_out: list[tuple[int, int]] = []
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(seq_group)

    # Make the block manager report that swap-in can never succeed.
    scheduler.block_manager.can_swap_in = MagicMock()
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER
    # Neither request is swapped in; both are reported as infeasible and
    # removed from the swapped queue.
    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 0
    assert len(output.infeasible_seq_groups) == 2
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0
737
738


def test_schedule_swapped_blocks_to_copy():
    """Block pairs returned by ``append_slots`` while rescheduling a swapped
    group must be propagated to ``output.blocks_to_copy``.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None

    # Run one 60-token prompt, decode one token, then swap it out.
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       block_size=block_size)
    scheduler._allocate_and_set_running(seq_group)
    append_new_token_seq_group(60, seq_group, 1)
    blocks_to_swap_out: list[tuple[int, int]] = []
    scheduler._swap_out(seq_group, blocks_to_swap_out)
    scheduler._add_seq_group_to_swapped(seq_group)

    # Make append_slots report a single block pair (2, 3) that needs copying.
    scheduler.block_manager.append_slots = MagicMock()
    scheduler.block_manager.append_slots.return_value = [(2, 3)]

    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped

    # The group is swapped back in as a decode and the copy pair is surfaced.
    assert len(remaining_swapped) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert output.blocks_to_copy == [(2, 3)]


def test_scheduling_budget():
    """SchedulingBudget enforces its token and sequence limits, and repeated
    add/subtract calls for the same request id are no-ops."""
    TOKEN_BUDGET = 4
    MAX_SEQS = 4
    budget = SchedulingBudget(token_budget=TOKEN_BUDGET, max_num_seqs=MAX_SEQS)
    assert budget.can_schedule(num_new_tokens=1, num_new_seqs=1)
    assert budget.can_schedule(num_new_tokens=4, num_new_seqs=4)
    assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=5)
    assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=1)
    assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=5)
    assert budget.remaining_token_budget() == TOKEN_BUDGET

    # Verify add/subtract num batched tokens.
    _, seq_group = create_dummy_prompt("1", 3)
    budget.add_num_batched_tokens(seq_group.request_id, 2)
    assert budget.remaining_token_budget() == 2
    assert budget.num_batched_tokens == 2
    assert budget.can_schedule(num_new_tokens=2, num_new_seqs=1)
    assert not budget.can_schedule(num_new_tokens=3, num_new_seqs=1)
    # Adding tokens again for the same request id is a no-op.
    budget.add_num_batched_tokens(seq_group.request_id, 2)
    assert budget.remaining_token_budget() == 2
    assert budget.num_batched_tokens == 2
    # Subtracting releases the tokens; subtracting again is a no-op.
    budget.subtract_num_batched_tokens(seq_group.request_id, 2)
    assert budget.remaining_token_budget() == 4
    assert budget.num_batched_tokens == 0
    budget.subtract_num_batched_tokens(seq_group.request_id, 2)
    assert budget.remaining_token_budget() == 4
    assert budget.num_batched_tokens == 0

    # Verify add/subtract max seqs.
    _, seq_group = create_dummy_prompt("1", 3)
    budget.add_num_seqs(seq_group.request_id, 2)
    assert budget.can_schedule(num_new_tokens=1, num_new_seqs=2)
    assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=3)
    assert budget.num_curr_seqs == 2
    # Adding seqs again for the same request id is a no-op.
    budget.add_num_seqs(seq_group.request_id, 2)
    assert budget.num_curr_seqs == 2
    # Subtracting releases the seqs; subtracting again is a no-op.
    budget.subtract_num_seqs(seq_group.request_id, 2)
    assert budget.num_curr_seqs == 0
    budget.subtract_num_seqs(seq_group.request_id, 2)
    assert budget.num_curr_seqs == 0


@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_prefix_caching_aware_prefills(enable_prefix_caching):
    """
    Test the scenario below:

    Three sequences, seqA, seqB and seqC, share their first block as a prefix.

    The test verifies that:
    1.  SeqA is scheduled first (prefill, then decode).
    2.  With prefix caching enabled, SeqB and SeqC are prefilled together in
        a single scheduling round, even though the token budget is too small
        to prefill both if the cached prefix is not taken into account.
        Without prefix caching, only one of them is scheduled.
    """

    block_size = 4
    max_num_batched_tokens = 12
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=max_num_batched_tokens,
        max_num_seqs=max_seq_group,
        max_model_len=max_num_batched_tokens,
        enable_prefix_caching=enable_prefix_caching,
    )

    seqA_tokens = list(range(8))
    num_shared_tokens = 4
    # seqB and seqC share the first 4 tokens (one block) with seqA.
    seqB_tokens = seqA_tokens[:num_shared_tokens] + list(range(12, 16))
    seqC_tokens = seqA_tokens[:num_shared_tokens] + list(range(16, 20))

    _, seqA_group = create_dummy_prompt("0",
                                        prompt_tokens=seqA_tokens,
                                        block_size=block_size)
    _, seqB_group = create_dummy_prompt("1",
                                        prompt_tokens=seqB_tokens,
                                        block_size=block_size)
    _, seqC_group = create_dummy_prompt("2",
                                        prompt_tokens=seqC_tokens,
                                        block_size=block_size)

    # Schedule seqA prefill.
    scheduler.add_seq_group(seqA_group)
    _, out, _ = scheduler.schedule()
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == len(seqA_tokens)

    # Schedule seqA decode.
    append_new_token_seq_group(len(seqA_tokens), seqA_group, 999)
    _, out, _ = scheduler.schedule()

    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 1

    # Scheduling seqB and seqC prefills together only works with prefix
    # caching.
    scheduler.add_seq_group(seqB_group)
    scheduler.add_seq_group(seqC_group)
    metas, out, _ = scheduler.schedule()

    if enable_prefix_caching:
        assert len(out.scheduled_seq_groups) == 2
        assert {
            out.scheduled_seq_groups[0].seq_group,
            out.scheduled_seq_groups[1].seq_group,
        } == {seqB_group, seqC_group}
        assert len(metas) == 2
        for meta in metas:
            assert meta.token_chunk_size == 8
            # Only the shared prefix block is already computed.
            assert (len(meta.computed_block_nums) == num_shared_tokens //
                    block_size)
    else:
        assert len(out.scheduled_seq_groups) == 1
        assert len(metas) == 1
        assert metas[0].token_chunk_size == 8
        assert len(metas[0].computed_block_nums) == 0  # No blocks computed.


def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
):
    """
    This test verifies that we don't schedule new prefills if there's already
    a chunked prefill in progress, even though the new prefills with a shared
    prefix could fit in the token budget:

    - SeqA is being chunked-prefilled.
    - SeqB with the same prompt shouldn't be scheduled for prefill even though
    there's enough token budget to prefill the cached tokens.
    - Neither should seqC be scheduled.

    - When seqA is in the decoding phase, seqB and seqC can be scheduled.
        - Entire seqB should be prefilled since it's a full prefix cache hit.
        - SeqC would be partially prefilled with the prefix shared, and the
        remaining unique tokens would be prefilled (rounded down to be
        block-size aligned).
    """

    block_size = 2
    max_num_batched_tokens = 4
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=max_num_batched_tokens,
        max_num_seqs=max_seq_group,
        max_model_len=100,
        enable_prefix_caching=True,
        enable_chunked_prefill=True,
    )

    seqA_tokens = list(range(8))
    seqB_tokens = seqA_tokens
    seqC_shared_prefix_len = 4
    seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20))

    seqA, seqA_group = create_dummy_prompt("0",
                                           prompt_tokens=seqA_tokens,
                                           block_size=block_size)
    _, seqB_group = create_dummy_prompt("1",
                                        prompt_tokens=seqB_tokens,
                                        block_size=block_size)

    # Chunked prefill seqA.
    scheduler.add_seq_group(seqA_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # seqB should not be scheduled while seqA's prefill is still ongoing.
    scheduler.add_seq_group(seqB_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # Now that seqA's prefill is finished and seqA is decoding, both seqB
    # and seqC can be scheduled.
    append_new_token_seq(seqA, 999)
    _, seqC_group = create_dummy_prompt("2",
                                        prompt_tokens=seqC_tokens,
                                        block_size=block_size)
    scheduler.add_seq_group(seqC_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 3

    metas = {meta.request_id: meta for meta in metas}
    assert metas[seqA_group.request_id].token_chunk_size == 1  # Decode
    # Fully cached prefill.
    assert metas[seqB_group.request_id].token_chunk_size == 8
    # The message pieces must be inside parentheses; otherwise the trailing
    # string literals become separate no-op statements and the message is
    # truncated.
    assert metas[seqC_group.request_id].token_chunk_size == 6, (
        "A partial prefix of C (4 tokens) should be prefilled, with the "
        "remaining tokens fit into 3 token budget (4-1 from the seqA). "
        "It will then be rounded down to 2 tokens on block size, thus 6 "
        "tokens in total.")