test_scheduler.py 38.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import time
4
from collections import deque
5
from typing import List, Set, Tuple
6
from unittest.mock import MagicMock
7

8
import pytest  # noqa
9
from torch import Use  # noqa
10

11
12
13
14
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
15
from vllm.sequence import SequenceGroup
16

17
18
19
from .utils import (append_new_token, append_new_token_seq,
                    append_new_token_seq_group, create_dummy_prompt,
                    get_sequence_groups, schedule_and_update_computed_tokens)
20
21


22
def test_scheduler_add_seq_group():
    """Each added sequence group raises the unfinished-group count by one."""
    blk_size = 4
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=1,
    )
    kv_cfg = CacheConfig(blk_size, 1.0, 1, cache_dtype="auto")
    kv_cfg.num_cpu_blocks = 4
    kv_cfg.num_gpu_blocks = 4
    scheduler = Scheduler(sched_cfg, kv_cfg, None)

    # Register the groups one at a time; after every registration the
    # scheduler must report exactly one more unfinished group.
    total_groups = 4
    for expected_count in range(1, total_groups + 1):
        request_id = str(expected_count - 1)
        _, group = create_dummy_prompt(request_id,
                                       blk_size,
                                       block_size=blk_size)
        scheduler.add_seq_group(group)
        assert scheduler.get_num_unfinished_seq_groups() == expected_count


45
def test_scheduler_abort_seq_group():
    """Aborting every added request leaves no unfinished sequence groups."""
    blk_size = 4
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=1,
    )
    kv_cfg = CacheConfig(blk_size, 1.0, 1, "auto")
    kv_cfg.num_cpu_blocks = 4
    kv_cfg.num_gpu_blocks = 4
    scheduler = Scheduler(sched_cfg, kv_cfg, None)

    # Queue several groups, collecting their request ids for the abort call.
    total_groups = 4
    queued_ids: Set[str] = set()
    for request_id in map(str, range(total_groups)):
        _, group = create_dummy_prompt(request_id, blk_size)
        scheduler.add_seq_group(group)
        queued_ids.add(request_id)

    # All groups are pending before the abort and none afterwards.
    assert scheduler.get_num_unfinished_seq_groups() == total_groups
    scheduler.abort_seq_group(queued_ids)
    assert scheduler.get_num_unfinished_seq_groups() == 0


72
def test_scheduler_schedule_simple():
    """Run one prompt step and one generation step over four groups."""
    blk_size = 4
    group_count = 4
    model_len_limit = 16
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=group_count,
        max_model_len=model_len_limit,
    )
    kv_cfg = CacheConfig(blk_size, 1.0, 1, "auto")
    kv_cfg.num_cpu_blocks = 8
    kv_cfg.num_gpu_blocks = 8
    scheduler = Scheduler(sched_cfg, kv_cfg, None)

    # Queue every group and remember it for the membership checks below.
    expected_groups: List[SequenceGroup] = []
    for idx in range(group_count):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=blk_size,
                                       block_size=blk_size)
        scheduler.add_seq_group(group)
        expected_groups.append(group)

    # Prompt step: every group is in the batch, with blk_size tokens each.
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(expected_groups)
    assert out.num_batched_tokens == blk_size * group_count
    assert not (out.blocks_to_copy or out.blocks_to_swap_in
                or out.blocks_to_swap_out)
    assert len(metadata) == group_count
    append_new_token(out, 1)

    # Generation step: the same groups, now one token per group.
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(expected_groups)
    assert out.num_batched_tokens == group_count
    assert not (out.blocks_to_copy or out.blocks_to_swap_in
                or out.blocks_to_swap_out)
    assert len(metadata) == group_count
    append_new_token(out, 1)


116
def test_scheduler_prefill_prioritized():
    """Check that a newly added prefill request is scheduled on its own,
    ahead of the request that is already running."""
    blk_size = 4
    model_len_limit = 30
    token_limit = 30
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=token_limit,
        max_num_seqs=2,
        max_model_len=model_len_limit,
    )
    kv_cfg = CacheConfig(blk_size, 1.0, 1, "auto")
    kv_cfg.num_cpu_blocks = 16
    kv_cfg.num_gpu_blocks = 16
    scheduler = Scheduler(sched_cfg, kv_cfg, None)

    # Request A is queued and scheduled first.
    _, group_a = create_dummy_prompt("1", 1, block_size=blk_size)
    scheduler.add_seq_group(group_a)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [group_a]

    # Request B arrives while A is already running.
    _, group_b = create_dummy_prompt("2", 30, block_size=blk_size)
    scheduler.add_seq_group(group_b)

    # Prefill is prioritized: with a batched-token limit of 30, the next
    # step is expected to contain only the new prefill request B.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [group_b]
148
149


150
def test_scheduler_schedule_preempt_abort():
    """Preempt one of two groups, abort the other, then re-run the
    preempted group's prompt."""
    blk_size = 4
    model_len_limit = 16
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=2,
        max_model_len=model_len_limit,
    )
    kv_cfg = CacheConfig(blk_size, 1.0, 1, "auto")
    kv_cfg.num_cpu_blocks = 2
    kv_cfg.num_gpu_blocks = 2
    scheduler = Scheduler(sched_cfg, kv_cfg, None)

    # Queue two groups.
    _, group_a = create_dummy_prompt("1", blk_size, block_size=blk_size)
    _, group_b = create_dummy_prompt("2", blk_size, block_size=blk_size)
    for group in (group_a, group_b):
        scheduler.add_seq_group(group)

    # Prompt step: both groups are scheduled together.
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [group_a, group_b]
    assert out.num_batched_tokens == blk_size * 2  # group_a and group_b
    assert not (out.blocks_to_copy or out.blocks_to_swap_in
                or out.blocks_to_swap_out)
    assert len(metadata) == 2
    assert scheduler.get_num_unfinished_seq_groups() == 2

    # Append "generated" tokens so the sequences can mark their prompt
    # tokens as processed.
    append_new_token(out, 1)

    # Generation step: only group a runs and one preemption is reported.
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [group_a]
    assert out.num_batched_tokens == 1
    assert not (out.blocks_to_copy or out.blocks_to_swap_in
                or out.blocks_to_swap_out)
    assert len(metadata) == 1
    assert scheduler.get_num_unfinished_seq_groups() == 2
    assert out.preempted == 1

    # Abort group a; group b's prompt is then re-scheduled (recomputation).
    scheduler.abort_seq_group("1")
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [group_b]
    assert out.num_batched_tokens == 5  # 4 prompt + 1 generation.
    assert not (out.blocks_to_copy or out.blocks_to_swap_in
                or out.blocks_to_swap_out)
    assert len(metadata) == 1
    assert scheduler.get_num_unfinished_seq_groups() == 1


208
def test_scheduler_max_seqs():
    """With max_num_seqs=2 and one group running, only one of two newly
    added groups is scheduled in the next step."""
    blk_size = 4
    group_count = 4
    seq_limit = 2
    model_len_limit = 16
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=seq_limit,
        max_model_len=model_len_limit,
    )
    kv_cfg = CacheConfig(blk_size, 1.0, 1, "auto")
    kv_cfg.num_cpu_blocks = 8
    kv_cfg.num_gpu_blocks = 8
    scheduler = Scheduler(sched_cfg, kv_cfg, None)

    # Build the groups up front; they are handed to the scheduler later.
    groups: List[SequenceGroup] = [
        create_dummy_prompt(str(idx),
                            prompt_length=blk_size,
                            block_size=blk_size)[1]
        for idx in range(group_count)
    ]

    # Start with a single group.
    scheduler.add_seq_group(groups[0])

    # Prompt step for the first group.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {groups[0]}
    append_new_token(out, 1)

    # Generation step for the first group.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {groups[0]}
    append_new_token(out, 1)

    # Queue two more groups.
    scheduler.add_seq_group(groups[1])
    scheduler.add_seq_group(groups[2])

    # Prompt step: only one of the new groups should be scheduled, since
    # the sequence limit is 2 and one group is already running.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {groups[1]}
254
255


256
def test_scheduler_delay_factor():
    """With delay_factor=0.5, a second prompt is not prefilled right after
    it is added but is prefilled once enough time has passed."""
    blk_size = 4
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=16,
        delay_factor=0.5,
    )
    kv_cfg = CacheConfig(blk_size, 1.0, 1, "auto")
    kv_cfg.num_cpu_blocks = 8
    kv_cfg.num_gpu_blocks = 8
    scheduler = Scheduler(sched_cfg, kv_cfg, None)

    # The first prompt is scheduled immediately.
    _, first_group = create_dummy_prompt("0",
                                         prompt_length=blk_size,
                                         block_size=blk_size)
    scheduler.add_seq_group(first_group)
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups > 0
    assert metadata[0].request_id == '0'
    append_new_token(out, 1)

    # Let one second pass before the next prompt arrives.
    time.sleep(1)
    _, second_group = create_dummy_prompt("1",
                                          prompt_length=blk_size,
                                          block_size=blk_size)
    scheduler.add_seq_group(second_group)

    # The second prompt must *not* be scheduled yet; request '0' runs again.
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups == 0
    assert metadata[0].request_id == '0'
    append_new_token(out, 1)

    # After waiting more than 0.5 second the second prompt is prefilled.
    time.sleep(0.6)
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups > 0
    assert metadata[0].request_id == '1'
    append_new_token(out, 1)
299
300


301
302
303
304
305
306
307
308
309
def initialize_scheduler(
    *,
    max_num_seqs=1000,
    max_token_budget=1000,
    max_model_len=1000,
    lora_config=None,
    block_size=4,
    num_cpu_blocks=8,
    num_gpu_blocks=8,
    enable_prefix_caching=False,
    enable_chunked_prefill=False,
):
    """Build a ``Scheduler`` for tests from keyword-only settings.

    Args:
        max_num_seqs: Forwarded to ``SchedulerConfig`` as ``max_num_seqs``.
        max_token_budget: Forwarded to ``SchedulerConfig`` as
            ``max_num_batched_tokens``.
        max_model_len: Forwarded to ``SchedulerConfig`` as ``max_model_len``.
        lora_config: Passed unchanged as the third ``Scheduler`` argument.
        block_size: First positional argument of ``CacheConfig``.
        num_cpu_blocks: Assigned to ``cache_config.num_cpu_blocks``.
        num_gpu_blocks: Assigned to ``cache_config.num_gpu_blocks``.
        enable_prefix_caching: Forwarded to ``CacheConfig``.
        enable_chunked_prefill: Forwarded to ``SchedulerConfig``.

    Returns:
        The newly constructed ``Scheduler``.
    """
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=max_token_budget,
        max_num_seqs=max_num_seqs,
        max_model_len=max_model_len,
        enable_chunked_prefill=enable_chunked_prefill,
    )
    cache_config = CacheConfig(
        block_size,
        1.0,
        1,
        "auto",
        enable_prefix_caching=enable_prefix_caching,
    )
    # The block counts are set on the config object after construction.
    cache_config.num_cpu_blocks = num_cpu_blocks
    cache_config.num_gpu_blocks = num_gpu_blocks
    return Scheduler(scheduler_config, cache_config, lora_config)


334
def create_token_budget(token_budget: int = 10000,
                        max_num_seqs: int = 10000) -> SchedulingBudget:
    """Return a new ``SchedulingBudget`` with the given token and sequence
    limits (both default to 10000)."""
    limits = dict(token_budget=token_budget, max_num_seqs=max_num_seqs)
    return SchedulingBudget(**limits)


342
343
344
345
346
347
348
349
350
def add_token_budget(budget: SchedulingBudget,
                     num_batched_tokens: int = 0,
                     num_curr_seqs: int = 0):
    """Record batched tokens and sequences on ``budget`` under the request
    id of a throwaway dummy prompt."""
    _, placeholder_group = create_dummy_prompt('10', prompt_length=60)
    request_id = placeholder_group.request_id
    budget.add_num_batched_tokens(request_id, num_batched_tokens)
    budget.add_num_seqs(request_id, num_curr_seqs)


351
def test_prefill_schedule_max_prompt_len():
    """
    A prompt of length 60 with max_model_len=30 is ignored, not scheduled.
    """
    blk_size = 4
    scheduler = initialize_scheduler(max_model_len=30, block_size=blk_size)
    _, group = create_dummy_prompt("0", prompt_length=60, block_size=blk_size)
    scheduler.add_seq_group(group)

    budget = create_token_budget()
    result = scheduler._schedule_prefills(budget, None)

    # The request is reported as ignored and consumes no budget.
    assert len(result.ignored_seq_groups) == 1
    assert len(result.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    # Nothing is left in the waiting queue.
    assert len(scheduler.waiting) == 0


371
def test_prefill_schedule_token_budget():
    """
    Test token budget respected.

    Covers three situations: a zero budget, a budget that fits exactly one
    60-token prompt, and a budget that already has 30 tokens recorded.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(token_budget=0)
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)

    # 0 token budget == nothing is scheduled.
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 2

    # 60 token budget == 1 request scheduled.
    budget = create_token_budget(token_budget=60)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 60
    assert budget.num_curr_seqs == 1
    assert len(remaining_waiting) == 1

    # Test when current_batched_tokens respected.
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=16,
                                     num_gpu_blocks=16)
    budget = create_token_budget(token_budget=60)
    add_token_budget(budget, 30, 0)
    # Use an explicit request id instead of relying on the loop variable
    # leaking out of the ``for`` loop above ("1" is the value it had).
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       block_size=block_size)
    # Cannot schedule a prompt that doesn't fit the budget.
    scheduler.add_seq_group(seq_group)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 30
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 1

    # With a 90 token budget the 60-token prompt fits next to the 30 tokens
    # already recorded.
    budget = create_token_budget(token_budget=90)
    add_token_budget(budget, 30, 0)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 90
    assert budget.num_curr_seqs == 1
    assert len(remaining_waiting) == 0


433
def test_prefill_schedule_max_seqs():
    """
    Test max seq respected.

    First with an empty budget limited to two sequences, then with a budget
    that already has two sequences recorded.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(max_num_seqs=2)
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 1

    # Verify curr_num_seqs respected.
    scheduler.waiting = deque()
    budget = create_token_budget(max_num_seqs=2)
    add_token_budget(budget, 0, 2)
    # Use an explicit request id instead of relying on the loop variable
    # leaking out of the ``for`` loop above ("2" is the value it had).
    _, seq_group = create_dummy_prompt("2",
                                       prompt_length=60,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 1


472
def test_prefill_schedule_max_lora():
    """
    Test max lora is respected and prioritized.
    """
    blk_size = 4
    lora_cfg = LoRAConfig(max_lora_rank=8, max_loras=1)
    scheduler = initialize_scheduler(lora_config=lora_cfg,
                                     block_size=blk_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(token_budget=120)
    active_loras: Set[int] = set()

    # Requests "0" and "1" each carry their own LoRA request.
    for idx in (0, 1):
        lora = LoRARequest(lora_name=str(idx),
                           lora_int_id=idx + 1,
                           lora_path="abc")
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=blk_size,
                                       lora_request=lora)
        scheduler.add_seq_group(group)

    # Add two more requests to verify lora is prioritized.
    # 0: Lora, 1: Lora, 2: regular, 3: regular
    # In the first iteration, requests 0 and 2 are scheduled.
    # If a request is not scheduled because it hits max lora, it is
    # prioritized. Verify that.
    for idx in (2, 3):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=blk_size)
        scheduler.add_seq_group(group)

    # First pass: two requests (0 and 2) are scheduled.
    result = scheduler._schedule_prefills(budget, active_loras)
    assert len(result.ignored_seq_groups) == 0
    assert len(result.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(scheduler.waiting) == 2
    assert len(active_loras) == 1

    # The second lora request is scheduled next as FCFS policy.
    # Start from an empty lora set so that it can be scheduled.
    active_loras = set()
    budget = create_token_budget(token_budget=60)
    result = scheduler._schedule_prefills(budget, active_loras)
    assert len(result.seq_groups) == 1
    assert result.seq_groups[0].seq_group.request_id == "1"
    assert len(scheduler.waiting) == 1
    assert len(active_loras) == 1
    assert budget.num_batched_tokens == 60


525
def test_prefill_schedule_no_block_manager_capacity():
    """
    Test sequence cannot be scheduled due to block manager has no capacity.

    ``can_allocate`` is mocked to return LATER (requests stay waiting) and
    then NEVER (requests are ignored).
    """
    blk_size = 4

    # Case 1: AllocStatus.LATER keeps all three requests in the queue.
    scheduler = initialize_scheduler(block_size=blk_size,
                                     num_gpu_blocks=128,
                                     num_cpu_blocks=128)
    budget = create_token_budget()
    for idx in range(3):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=blk_size)
        scheduler.add_seq_group(group)
    scheduler.block_manager.can_allocate = MagicMock()
    scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER
    result = scheduler._schedule_prefills(budget, None)
    assert len(result.ignored_seq_groups) == 0
    assert len(result.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(scheduler.waiting) == 3

    # Case 2: AllocStatus.NEVER moves all three requests to the ignored list.
    scheduler = initialize_scheduler()
    budget = create_token_budget()
    for idx in range(3):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=blk_size)
        scheduler.add_seq_group(group)
    scheduler.block_manager.can_allocate = MagicMock()
    scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER
    result = scheduler._schedule_prefills(budget, None)
    assert len(result.ignored_seq_groups) == 3
    assert len(result.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(scheduler.waiting) == 0


567
def test_decode_schedule_preempted():
    """
    Test decodes cannot be scheduled and preempted.
    """
    blk_size = 4
    scheduler = initialize_scheduler(block_size=blk_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    active_loras = None
    for idx in range(3):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=blk_size)
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._add_seq_group_to_running(group)

    # Make the block manager refuse to append slots for request "1" only.
    # The parameter names are kept as-is because the mock forwards whatever
    # positional or keyword arguments it is called with.
    scheduler.block_manager.can_append_slots = MagicMock()
    scheduler.block_manager.can_append_slots.side_effect = (
        lambda seq_group, num_lookahead_slots: seq_group.request_id != "1")

    # 1 cannot be scheduled, and the lowest priority (request 2)
    # should be preempted. 1 will also be preempted.
    budget = create_token_budget()
    result = scheduler._schedule_running(budget, active_loras)
    still_running = scheduler.running
    assert len(still_running) == 0
    assert len(result.decode_seq_groups) == 1
    assert len(result.prefill_seq_groups) == 0
    assert result.decode_seq_groups[0].seq_group.request_id == "0"
    assert len(result.preempted) == 2
    # Verify budgets are updated.
    assert budget.num_batched_tokens == 1
    # NOTE: When enable_chunk is False, num_seqs budget is not updated.
    # assert budget.num_curr_seqs == 1
    # Both should be preempted, not swapped.
    assert result.blocks_to_swap_out == []
    # Nothing is copied.
    assert result.blocks_to_copy == []
609
610


611
def test_schedule_decode_blocks_to_copy_update():
    """
    Verify blocks_to_copy is updated.
    """
    blk_size = 4
    scheduler = initialize_scheduler(block_size=blk_size,
                                     num_cpu_blocks=16,
                                     num_gpu_blocks=16)
    _, group = create_dummy_prompt("1",
                                   prompt_length=60,
                                   best_of=2,
                                   block_size=blk_size)
    active_loras = None
    scheduler._allocate_and_set_running(group)
    append_new_token_seq_group(60, group, 1)
    scheduler._add_seq_group_to_running(group)

    # Make append_slots report a single (2, 3) block pair.
    scheduler.block_manager.append_slots = MagicMock()
    scheduler.block_manager.append_slots.return_value = [(2, 3)]

    budget = create_token_budget()
    result = scheduler._schedule_running(budget, active_loras)
    assert len(scheduler.running) == 0
    assert len(result.decode_seq_groups) == 1
    assert len(result.prefill_seq_groups) == 0
    # Nothing is preempted or swapped out.
    assert len(result.preempted) == 0
    assert len(result.swapped_out) == 0
    assert result.blocks_to_swap_out == []
    # The mapping returned by append_slots should show up in blocks_to_copy.
    assert result.blocks_to_copy == [(2, 3)]
645
646


647
def test_schedule_swapped_max_loras():
    """With max_loras=1, only one of two swapped LoRA groups is scheduled."""
    blk_size = 4
    lora_cfg = LoRAConfig(max_lora_rank=8, max_loras=1)
    scheduler = initialize_scheduler(lora_config=lora_cfg,
                                     block_size=blk_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    active_loras: Set[int] = set()
    swap_out_pairs: List[Tuple[int, int]] = []

    # Put two groups, each with its own LoRA request, into the swapped queue.
    for idx in (0, 1):
        lora = LoRARequest(lora_name=str(idx),
                           lora_int_id=idx + 1,
                           lora_path="abc")
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=blk_size,
                                       lora_request=lora)
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._swap_out(group, swap_out_pairs)
        scheduler._add_seq_group_to_swapped(group)

    budget = create_token_budget()
    result = scheduler._schedule_swapped(budget, active_loras)

    # One group is scheduled for decode; the other stays swapped.
    assert len(scheduler.swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 1
    assert len(result.decode_seq_groups) == 1
    assert len(result.prefill_seq_groups) == 0
    assert len(active_loras) == 1


680
def test_schedule_swapped_cannot_swap_in():
    """Swapped groups stay swapped while can_swap_in reports LATER."""
    blk_size = 4
    scheduler = initialize_scheduler(block_size=blk_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    active_loras = None
    swap_out_pairs: List[Tuple[int, int]] = []

    # Put two groups into the swapped queue.
    for idx in (0, 1):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       best_of=2,
                                       block_size=blk_size)
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._swap_out(group, swap_out_pairs)
        scheduler._add_seq_group_to_swapped(group)

    # Every swap-in query is answered with AllocStatus.LATER.
    scheduler.block_manager.can_swap_in = MagicMock()
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER

    # Since we cannot swap in, none of the requests are swapped in.
    budget = create_token_budget()
    result = scheduler._schedule_swapped(budget, active_loras)
    assert len(scheduler.swapped) == 2
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(result.decode_seq_groups) == 0
    assert len(result.prefill_seq_groups) == 0


711
def test_infeasible_swap():
    """Swapped groups are reported as infeasible when can_swap_in reports
    NEVER."""
    blk_size = 4
    scheduler = initialize_scheduler(block_size=blk_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    active_loras = None
    swap_out_pairs: List[Tuple[int, int]] = []

    # Put two groups into the swapped queue.
    for idx in (0, 1):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       best_of=2,
                                       block_size=blk_size)
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._swap_out(group, swap_out_pairs)
        scheduler._add_seq_group_to_swapped(group)

    # Every swap-in query is answered with AllocStatus.NEVER.
    scheduler.block_manager.can_swap_in = MagicMock()
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER

    # Nothing is swapped in; both groups leave the swapped queue and are
    # listed as infeasible.
    budget = create_token_budget()
    result = scheduler._schedule_swapped(budget, active_loras)
    assert len(scheduler.swapped) == 0
    assert len(result.infeasible_seq_groups) == 2
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(result.decode_seq_groups) == 0
    assert len(result.prefill_seq_groups) == 0
741
742


743
def test_schedule_swapped_blocks_to_copy():
    """Block copies returned by ``append_slots`` surface in the output.

    One sequence group is allocated, swapped out and placed on the swapped
    queue. ``block_manager.append_slots`` is mocked to return ``[(2, 3)]``.
    After ``_schedule_swapped`` the group must be scheduled as a decode and
    ``output.blocks_to_copy`` must equal the mocked value.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None

    # Put a single sequence group on the swapped queue.
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       best_of=2,
                                       block_size=block_size)
    scheduler._allocate_and_set_running(seq_group)
    append_new_token_seq_group(60, seq_group, 1)
    blocks_to_swap_out: List[Tuple[int, int]] = []
    scheduler._swap_out(seq_group, blocks_to_swap_out)
    scheduler._add_seq_group_to_swapped(seq_group)

    # Stub slot appending so that it reports one block copy: (2, 3).
    scheduler.block_manager.append_slots = MagicMock()
    scheduler.block_manager.append_slots.return_value = [(2, 3)]

    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped

    # The group is swapped back in and scheduled as a decode.
    assert len(remaining_swapped) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    # The mocked copy pair is passed through unchanged.
    assert output.blocks_to_copy == [(2, 3)]
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813


def test_scheduling_budget():
    """Exercise SchedulingBudget capacity checks and per-request accounting.

    Covers three things: ``can_schedule`` against the token and sequence
    limits, adding/subtracting batched tokens, and adding/subtracting
    sequence counts. Repeating an add or subtract for the same request id
    is expected to leave the totals unchanged.
    """
    token_limit = 4
    seq_limit = 4
    budget = SchedulingBudget(token_budget=token_limit,
                              max_num_seqs=seq_limit)

    # (num_new_tokens, num_new_seqs, expected to fit)
    capacity_cases = (
        (1, 1, True),
        (4, 4, True),
        (1, 5, False),
        (5, 1, False),
        (5, 5, False),
    )
    for new_tokens, new_seqs, should_fit in capacity_cases:
        fits = budget.can_schedule(num_new_tokens=new_tokens,
                                   num_new_seqs=new_seqs)
        assert bool(fits) == should_fit
    assert budget.remaining_token_budget() == token_limit

    # --- Batched-token accounting -------------------------------------
    _, seq_group = create_dummy_prompt("1", 3)
    request_id = seq_group.request_id

    budget.add_num_batched_tokens(request_id, 2)
    assert budget.remaining_token_budget() == 2
    assert budget.num_batched_tokens == 2
    assert budget.can_schedule(num_new_tokens=2, num_new_seqs=1)
    assert not budget.can_schedule(num_new_tokens=3, num_new_seqs=1)

    # A second add for the same request id changes nothing.
    budget.add_num_batched_tokens(request_id, 2)
    assert budget.remaining_token_budget() == 2
    assert budget.num_batched_tokens == 2

    # The first subtract restores the budget; the second changes nothing.
    for _ in range(2):
        budget.subtract_num_batched_tokens(request_id, 2)
        assert budget.remaining_token_budget() == 4
        assert budget.num_batched_tokens == 0

    # --- Sequence-count accounting ------------------------------------
    _, seq_group = create_dummy_prompt("1", 3)
    request_id = seq_group.request_id

    budget.add_num_seqs(request_id, 2)
    assert budget.can_schedule(num_new_tokens=1, num_new_seqs=2)
    assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=3)
    assert budget.num_curr_seqs == 2

    # A second add for the same request id changes nothing.
    budget.add_num_seqs(request_id, 2)
    assert budget.num_curr_seqs == 2

    # The first subtract clears the count; the second changes nothing.
    for _ in range(2):
        budget.subtract_num_seqs(request_id, 2)
        assert budget.num_curr_seqs == 0
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975


@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_prefix_caching_aware_prefills(enable_prefix_caching):
    """
    Check that prefill scheduling takes prefix-cache hits into account.

    Three 8-token prompts (A, B, C) share their first block (4 tokens) and
    the token budget is 12.

    Steps:
    1.  A is prefilled on its own, then scheduled for one decode step.
    2.  B and C are added together. With prefix caching enabled, both are
    expected to be prefilled in the same round, each reporting one computed
    block. With prefix caching disabled, only one of them is expected to be
    scheduled, with no computed blocks.
    """

    block_size = 4
    token_budget = 12
    seq_limit = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=token_budget,
        max_num_seqs=seq_limit,
        max_model_len=token_budget,
        enable_prefix_caching=enable_prefix_caching,
    )

    # B and C reuse A's first `num_shared_tokens` tokens, then diverge.
    num_shared_tokens = 4
    seqA_tokens = list(range(8))
    shared_prefix = seqA_tokens[:num_shared_tokens]
    seqB_tokens = shared_prefix + list(range(12, 16))
    seqC_tokens = shared_prefix + list(range(16, 20))

    _, seqA_group = create_dummy_prompt("0",
                                        prompt_tokens=seqA_tokens,
                                        block_size=block_size)
    _, seqB_group = create_dummy_prompt("1",
                                        prompt_tokens=seqB_tokens,
                                        block_size=block_size)
    _, seqC_group = create_dummy_prompt("2",
                                        prompt_tokens=seqC_tokens,
                                        block_size=block_size)

    # Round 1: A is prefilled in full.
    scheduler.add_seq_group(seqA_group)
    metas, out, _ = scheduler.schedule()
    scheduled = out.scheduled_seq_groups
    assert len(scheduled) == 1
    assert scheduled[0].seq_group == seqA_group
    assert scheduled[0].token_chunk_size == len(seqA_tokens)

    # Round 2: A takes a single decode step.
    append_new_token_seq_group(len(seqA_tokens), seqA_group, 999)
    metas, out, _ = scheduler.schedule()
    scheduled = out.scheduled_seq_groups
    assert len(scheduled) == 1
    assert scheduled[0].seq_group == seqA_group
    assert scheduled[0].token_chunk_size == 1

    # Round 3: B and C arrive together.
    scheduler.add_seq_group(seqB_group)
    scheduler.add_seq_group(seqC_group)
    metas, out, _ = scheduler.schedule()
    scheduled = out.scheduled_seq_groups

    if not enable_prefix_caching:
        # Only one of the two prefills is scheduled this round.
        assert len(scheduled) == 1
        assert len(metas) == 1
        assert metas[0].token_chunk_size == 8
        assert len(metas[0].computed_block_nums) == 0  # No blocks computed.
        return

    # With prefix caching, both prefills are scheduled in this round.
    assert len(scheduled) == 2
    assert {entry.seq_group
            for entry in scheduled} == {seqB_group, seqC_group}
    assert len(metas) == 2
    expected_computed_blocks = num_shared_tokens // block_size  # 1 block.
    for meta in metas:
        assert meta.token_chunk_size == 8
        assert len(meta.computed_block_nums) == expected_computed_blocks


def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
):
    """
    This test verifies that we don't schedule new prefills if there's already
    a continuous prefill in progress even though the new prefills with shared
    prefix can fit in the token budget:

    - SeqA is being chunked prefill.
    - SeqB with the same prompt shouldn't be scheduled for prefill even though
    there's enough token budget to prefill the cached tokens.
    - Neither should seqC be scheduled.

    - When seqA is in decoding phase, seqB and seqC can be scheduled.
        - Entire seqB should be prefilled since it's a full prefix cache hit.
        - SeqC would be partially prefilled with the prefix shared, and the
        remaining unique tokens would be prefilled (rounded down to be
        block-size aligned).
    """

    block_size = 2
    max_num_batched_tokens = 4
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=max_num_batched_tokens,
        max_num_seqs=max_seq_group,
        max_model_len=100,
        enable_prefix_caching=True,
        enable_chunked_prefill=True,
    )

    # seqB has exactly seqA's prompt; seqC shares only the first 4 tokens.
    seqA_tokens = list(range(8))
    seqB_tokens = seqA_tokens
    seqC_shared_prefix_len = 4
    seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20))

    seqA, seqA_group = create_dummy_prompt("0",
                                           prompt_tokens=seqA_tokens,
                                           block_size=block_size)
    _, seqB_group = create_dummy_prompt("1",
                                        prompt_tokens=seqB_tokens,
                                        block_size=block_size)

    # Chunked prefill seqA: the first chunk uses the whole 4-token budget.
    scheduler.add_seq_group(seqA_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # seqB should not be scheduled while seqA's prefill is still ongoing.
    scheduler.add_seq_group(seqB_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # Both seqB and seqC can be scheduled now that seqA's prefill is over
    # and seqA is in the decoding phase.
    append_new_token_seq(seqA, 999)
    _, seqC_group = create_dummy_prompt("2",
                                        prompt_tokens=seqC_tokens,
                                        block_size=block_size)
    scheduler.add_seq_group(seqC_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 3

    # Index the metadata by request id so each group can be checked by name.
    metas = {meta.request_id: meta for meta in metas}
    assert metas[seqA_group.request_id].token_chunk_size == 1  # Decode
    assert (metas[seqB_group.request_id].token_chunk_size == 8
            )  # Fully cached prefill
    # The message is parenthesized so that all the string literals are
    # concatenated into the assert message; without the parentheses only the
    # first literal is the message and the rest are no-op statements.
    assert metas[seqC_group.request_id].token_chunk_size == 6, (
        "A partial prefix of C (4 tokens) should be prefilled, with the "
        "remaining tokens fit into 3 token budget (4-1 from the seqA). "
        "It will then be rounded down to 2 tokens on block size, thus 6 "
        "tokens in total.")