"vscode:/vscode.git/clone" did not exist on "98c89e16ff834f1c9f1c465e7342c5353e0eb627"
test_scheduler.py 38.1 KB
Newer Older
1
import time
2
from collections import deque
3
from typing import List, Set, Tuple
4
from unittest.mock import MagicMock
5

6
import pytest  # noqa
7
from torch import Use  # noqa
8

9
10
11
12
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
13
from vllm.sequence import SequenceGroup
14

15
16
17
from .utils import (append_new_token, append_new_token_seq,
                    append_new_token_seq_group, create_dummy_prompt,
                    get_sequence_groups, schedule_and_update_computed_tokens)
18
19


20
def test_scheduler_add_seq_group():
    """Check that every added sequence group is counted as unfinished."""
    block_size = 4
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=1,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, cache_dtype="auto")
    cache_cfg.num_cpu_blocks = 4
    cache_cfg.num_gpu_blocks = 4
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Add the groups one at a time; the unfinished count must track the
    # number of groups added so far.
    total_groups = 4
    for expected_count in range(1, total_groups + 1):
        request_id = str(expected_count - 1)
        _, group = create_dummy_prompt(request_id,
                                       block_size,
                                       block_size=block_size)
        scheduler.add_seq_group(group)
        assert scheduler.get_num_unfinished_seq_groups() == expected_count


43
def test_scheduler_abort_seq_group():
    """Check that aborting all added requests leaves none unfinished."""
    block_size = 4
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=1,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 4
    cache_cfg.num_gpu_blocks = 4
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Queue several sequence groups, remembering their request ids.
    total_groups = 4
    request_ids: Set[str] = set()
    for request_id in map(str, range(total_groups)):
        _, group = create_dummy_prompt(request_id, block_size)
        scheduler.add_seq_group(group)
        request_ids.add(request_id)

    # Everything is pending before the abort and nothing afterwards.
    assert scheduler.get_num_unfinished_seq_groups() == total_groups
    scheduler.abort_seq_group(request_ids)
    assert scheduler.get_num_unfinished_seq_groups() == 0


70
def test_scheduler_schedule_simple():
    """Schedule a prompt step and then a generation step for four requests.

    All requests are expected to be scheduled together in both steps, with
    no block copies or swaps.
    """
    block_size = 4
    num_seq_group = 4
    max_model_len = 16
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=num_seq_group,
        max_model_len=max_model_len,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 8
    cache_cfg.num_gpu_blocks = 8
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Queue one prompt of block_size tokens per request.
    expected_groups: List[SequenceGroup] = []
    for request_id in map(str, range(num_seq_group)):
        _, group = create_dummy_prompt(request_id,
                                       prompt_length=block_size,
                                       block_size=block_size)
        scheduler.add_seq_group(group)
        expected_groups.append(group)

    # Prompt step: all prompt tokens of all requests are batched.
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(expected_groups)
    assert out.num_batched_tokens == block_size * num_seq_group
    assert not (out.blocks_to_copy or out.blocks_to_swap_in
                or out.blocks_to_swap_out)
    assert len(metas) == num_seq_group
    append_new_token(out, 1)

    # Generation step: one batched token per request.
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(expected_groups)
    assert out.num_batched_tokens == num_seq_group
    assert not (out.blocks_to_copy or out.blocks_to_swap_in
                or out.blocks_to_swap_out)
    assert len(metas) == num_seq_group
    append_new_token(out, 1)


114
def test_scheduler_prefill_prioritized():
    """Verify running batched tokens are not applied to prefill requests."""
    block_size = 4
    max_model_len = 30
    max_batched_num_tokens = 30
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=max_batched_num_tokens,
        max_num_seqs=2,
        max_model_len=max_model_len,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 16
    cache_cfg.num_gpu_blocks = 16
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Request A is added and scheduled on its own first.
    _, seq_group_a = create_dummy_prompt("1", 1, block_size=block_size)
    scheduler.add_seq_group(seq_group_a)
    _, first_out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(first_out) == [seq_group_a]

    # Add a new prefill request B while A is already running.
    _, seq_group_b = create_dummy_prompt("2", 30, block_size=block_size)
    scheduler.add_seq_group(seq_group_b)

    # Verify prefill requests are prioritized: the next step is expected to
    # contain only B. (B is created with 30 as its second argument, the same
    # value as max_batched_num_tokens.)
    _, second_out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(second_out) == [seq_group_b]
146
147


148
def test_scheduler_schedule_preempt_abort():
    """Exercise preemption followed by an abort with two GPU blocks."""
    block_size = 4
    max_model_len = 16
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=2,
        max_model_len=max_model_len,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 2
    cache_cfg.num_gpu_blocks = 2
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Queue two requests.
    _, seq_group_a = create_dummy_prompt("1",
                                         block_size,
                                         block_size=block_size)
    _, seq_group_b = create_dummy_prompt("2",
                                         block_size,
                                         block_size=block_size)
    for group in (seq_group_a, seq_group_b):
        scheduler.add_seq_group(group)

    # Prompt step: both groups are scheduled, in insertion order.
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [seq_group_a, seq_group_b]
    assert out.num_batched_tokens == block_size * 2  # seq_a and seq_b
    assert not (out.blocks_to_copy or out.blocks_to_swap_in
                or out.blocks_to_swap_out)
    assert len(metas) == 2
    assert scheduler.get_num_unfinished_seq_groups() == 2

    # Append "generated" tokens, allowing the sequence to mark prompt tokens
    # as processed.
    append_new_token(out, 1)

    # Generation step: only group a is scheduled and group b is preempted.
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [seq_group_a]
    assert out.num_batched_tokens == 1
    assert not (out.blocks_to_copy or out.blocks_to_swap_in
                or out.blocks_to_swap_out)
    assert len(metas) == 1
    assert scheduler.get_num_unfinished_seq_groups() == 2
    assert out.preempted == 1

    # Abort group a; group b's prompt is re-scheduled with recomputation.
    scheduler.abort_seq_group("1")
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [seq_group_b]
    assert out.num_batched_tokens == 5  # 4 prompt + 1 generation.
    assert not (out.blocks_to_copy or out.blocks_to_swap_in
                or out.blocks_to_swap_out)
    assert len(metas) == 1
    assert scheduler.get_num_unfinished_seq_groups() == 1


206
def test_scheduler_max_seqs():
    """Check that max_num_seqs caps how many groups are scheduled at once."""
    block_size = 4
    num_seq_group = 4
    max_seq_group = 2
    max_model_len = 16
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=max_seq_group,
        max_model_len=max_model_len,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 8
    cache_cfg.num_gpu_blocks = 8
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Create all groups up front; they are added to the scheduler later.
    all_seq_groups: List[SequenceGroup] = [
        create_dummy_prompt(str(i),
                            prompt_length=block_size,
                            block_size=block_size)[1]
        for i in range(num_seq_group)
    ]
    first, second, third = all_seq_groups[0:3]

    # Add a single group and run its prompt step.
    scheduler.add_seq_group(first)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {first}
    append_new_token(out, 1)

    # Run its generation step.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {first}
    append_new_token(out, 1)

    # Add two more groups.
    scheduler.add_seq_group(second)
    scheduler.add_seq_group(third)

    # Only 1 seq group should be scheduled since max_seq_group is 2
    # and one is already running.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {second}
252
253


254
def test_scheduler_delay_factor():
    """Check that delay_factor postpones scheduling of a new prompt.

    With delay_factor=0.5 the test expects the second prompt to be held back
    on the first scheduling attempt after it is added, and to be scheduled
    once a further 0.6 s has passed.
    """
    block_size = 4
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=16,
        delay_factor=0.5,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # schedule first prompt
    # The first element returned by create_dummy_prompt is unused here, so
    # it is bound to "_" rather than to a misleading "seq_group_meta" name.
    _, seq_group = create_dummy_prompt("0",
                                       prompt_length=block_size,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups > 0
    assert seq_group_meta[0].request_id == '0'
    append_new_token(out, 1)

    # wait for a second before scheduling next prompt
    time.sleep(1)
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=block_size,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)

    # second prompt should *not* be scheduled
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups == 0
    assert seq_group_meta[0].request_id == '0'
    append_new_token(out, 1)

    # wait for more than 0.5 second and try again
    time.sleep(0.6)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups > 0
    assert seq_group_meta[0].request_id == '1'
    append_new_token(out, 1)
297
298


299
300
301
302
303
304
305
306
307
def initialize_scheduler(
    *,
    max_num_seqs=1000,
    max_token_budget=1000,
    max_model_len=1000,
    lora_config=None,
    block_size=4,
    num_cpu_blocks=8,
    num_gpu_blocks=8,
    enable_prefix_caching=False,
    enable_chunked_prefill=False,
) -> Scheduler:
    """Build a Scheduler for tests from keyword-only settings.

    Args:
        max_num_seqs: Passed to SchedulerConfig as ``max_num_seqs``.
        max_token_budget: Passed to SchedulerConfig as
            ``max_num_batched_tokens``.
        max_model_len: Passed to SchedulerConfig as ``max_model_len``.
        lora_config: Passed to Scheduler as its third argument; may be None.
        block_size: First positional argument of CacheConfig.
        num_cpu_blocks: Assigned to ``cache_config.num_cpu_blocks``.
        num_gpu_blocks: Assigned to ``cache_config.num_gpu_blocks``.
        enable_prefix_caching: Passed to CacheConfig.
        enable_chunked_prefill: Passed to SchedulerConfig.

    Returns:
        The newly constructed Scheduler.
    """
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=max_token_budget,
        max_num_seqs=max_num_seqs,
        max_model_len=max_model_len,
        enable_chunked_prefill=enable_chunked_prefill,
    )
    cache_config = CacheConfig(
        block_size,
        1.0,
        1,
        "auto",
        enable_prefix_caching=enable_prefix_caching,
    )
    # The block counts are set on the config object after construction.
    cache_config.num_cpu_blocks = num_cpu_blocks
    cache_config.num_gpu_blocks = num_gpu_blocks
    return Scheduler(scheduler_config, cache_config, lora_config)


332
def create_token_budget(token_budget: int = 10000,
                        max_num_seqs: int = 10000) -> SchedulingBudget:
    """Return a SchedulingBudget with the given token and sequence limits."""
    limits = {
        "token_budget": token_budget,
        "max_num_seqs": max_num_seqs,
    }
    return SchedulingBudget(**limits)


340
341
342
343
344
345
346
347
348
def add_token_budget(budget: SchedulingBudget,
                     num_batched_tokens: int = 0,
                     num_curr_seqs: int = 0):
    """Charge tokens and sequences to ``budget`` under a dummy request id.

    A throwaway sequence group (request id '10') is created only to supply
    the request id under which both amounts are recorded.
    """
    _, mock_seq_group = create_dummy_prompt('10', prompt_length=60)
    request_id = mock_seq_group.request_id
    budget.add_num_batched_tokens(request_id, num_batched_tokens)
    budget.add_num_seqs(request_id, num_curr_seqs)


349
def test_prefill_schedule_max_prompt_len():
    """
    Test that a prompt longer than max_model_len is ignored, not scheduled.
    """
    block_size = 4
    scheduler = initialize_scheduler(max_model_len=30, block_size=block_size)

    # The prompt (60 tokens) is twice the configured max_model_len (30).
    _, oversized_group = create_dummy_prompt("0",
                                             prompt_length=60,
                                             block_size=block_size)
    scheduler.add_seq_group(oversized_group)

    budget = create_token_budget()
    output = scheduler._schedule_prefills(budget, None)

    # The request is reported as ignored, uses no budget and leaves the
    # waiting queue.
    assert len(output.ignored_seq_groups) == 1
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(scheduler.waiting) == 0


369
def test_prefill_schedule_token_budget():
    """
    Test token budget respected.

    Covers an empty budget, a budget fitting exactly one prompt, and budgets
    that already have tokens charged to them before scheduling.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(token_budget=0)
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)

    # 0 token budget == nothing is scheduled.
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 2

    # 60 token budget == 1 request scheduled.
    budget = create_token_budget(token_budget=60)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 60
    assert budget.num_curr_seqs == 1
    assert len(remaining_waiting) == 1

    # Test when current_batched_tokens respected.
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=16,
                                     num_gpu_blocks=16)
    budget = create_token_budget(token_budget=60)
    add_token_budget(budget, 30, 0)
    # Explicit request id instead of reusing the loop variable after the
    # loop has ended; "1" is the value the loop variable would have had.
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       block_size=block_size)
    # Cannot schedule a prompt that doesn't fit the budget.
    scheduler.add_seq_group(seq_group)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 30
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 1

    # With 90 tokens of budget, of which 30 are already charged, the
    # 60-token prompt fits.
    budget = create_token_budget(token_budget=90)
    add_token_budget(budget, 30, 0)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 90
    assert budget.num_curr_seqs == 1
    assert len(remaining_waiting) == 0


431
def test_prefill_schedule_max_seqs():
    """
    Test max seq respected.

    Checks both a fresh budget limited to two sequences and a budget whose
    sequence allowance is already used up.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(max_num_seqs=2)
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 1

    # Verify curr_num_seqs respected.
    scheduler.waiting = deque()
    budget = create_token_budget(max_num_seqs=2)
    add_token_budget(budget, 0, 2)
    # Explicit request id instead of reusing the loop variable after the
    # loop has ended; "2" is the value the loop variable would have had.
    _, seq_group = create_dummy_prompt("2",
                                       prompt_length=60,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 1


470
def test_prefill_schedule_max_lora():
    """
    Test max lora is respected and prioritized.
    """
    block_size = 4
    lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
    scheduler = initialize_scheduler(lora_config=lora_config,
                                     block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(token_budget=120)
    curr_loras: Set[int] = set()

    # Requests 0 and 1 each carry their own LoRA (ids 1 and 2).
    for idx in (0, 1):
        lora_request = LoRARequest(lora_name=str(idx),
                                   lora_int_id=idx + 1,
                                   lora_path="abc")
        _, lora_group = create_dummy_prompt(str(idx),
                                            prompt_length=60,
                                            block_size=block_size,
                                            lora_request=lora_request)
        scheduler.add_seq_group(lora_group)

    # Add two more requests to verify lora is prioritized.
    # 0: Lora, 1: Lora, 2: regular, 3: regular
    # In the first iteration, index 0, 2 is scheduled.
    # If a request is not scheduled because it hits max lora, it is
    # prioritized. Verify that.
    for idx in (2, 3):
        _, plain_group = create_dummy_prompt(str(idx),
                                             prompt_length=60,
                                             block_size=block_size)
        scheduler.add_seq_group(plain_group)

    # Schedule 2 requests (0 and 2)
    output = scheduler._schedule_prefills(budget, curr_loras)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(scheduler.waiting) == 2
    assert len(curr_loras) == 1

    # The second lora request is scheduled next as FCFS policy.
    # Reset curr_loras so that it can be scheduled.
    curr_loras = set()
    budget = create_token_budget(token_budget=60)
    output = scheduler._schedule_prefills(budget, curr_loras)
    assert len(output.seq_groups) == 1
    assert output.seq_groups[0].seq_group.request_id == "1"
    assert len(scheduler.waiting) == 1
    assert len(curr_loras) == 1
    assert budget.num_batched_tokens == 60


523
def test_prefill_schedule_no_block_manager_capacity():
    """
    Test sequence cannot be scheduled due to block manager has no capacity.
    """
    block_size = 4

    # Case 1: can_allocate reports LATER -> requests stay in the waiting
    # queue and nothing is ignored.
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_gpu_blocks=128,
                                     num_cpu_blocks=128)
    budget = create_token_budget()
    for request_id in map(str, range(3)):
        _, group = create_dummy_prompt(request_id,
                                       prompt_length=60,
                                       block_size=block_size)
        scheduler.add_seq_group(group)
    scheduler.block_manager.can_allocate = MagicMock(
        return_value=AllocStatus.LATER)
    output = scheduler._schedule_prefills(budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(scheduler.waiting) == 3

    # Case 2: can_allocate reports NEVER -> every request is ignored and
    # the waiting queue is emptied.
    scheduler = initialize_scheduler()
    budget = create_token_budget()
    for request_id in map(str, range(3)):
        _, group = create_dummy_prompt(request_id,
                                       prompt_length=60,
                                       block_size=block_size)
        scheduler.add_seq_group(group)
    scheduler.block_manager.can_allocate = MagicMock(
        return_value=AllocStatus.NEVER)
    output = scheduler._schedule_prefills(budget, None)
    assert len(output.ignored_seq_groups) == 3
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(scheduler.waiting) == 0


565
def test_decode_schedule_preempted():
    """
    Test decodes cannot be scheduled and preempted.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    curr_loras = None

    # Put three groups straight into the running state, each with one
    # appended token.
    for request_id in map(str, range(3)):
        _, group = create_dummy_prompt(request_id,
                                       prompt_length=60,
                                       block_size=block_size)
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._add_seq_group_to_running(group)

    # can_append_slots is mocked to refuse only request "1". The parameter
    # names are kept in case the mock is called with keyword arguments.
    scheduler.block_manager.can_append_slots = MagicMock(
        side_effect=lambda seq_group, num_lookahead_slots:
        seq_group.request_id != "1")

    # 1 cannot be scheduled, and the lowest priority (request 2)
    # should be preempted. 1 will also be preempted.
    budget = create_token_budget()
    output = scheduler._schedule_running(budget, curr_loras)
    remaining_running = scheduler.running
    assert len(remaining_running) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert output.decode_seq_groups[0].seq_group.request_id == "0"
    assert len(output.preempted) == 2
    # Verify budgets are updated.
    assert budget.num_batched_tokens == 1
    # NOTE: When enable_chunk is False, num_seqs budget is not updated.
    # assert budget.num_curr_seqs == 1
    # Both should be preempted, not swapped.
    assert output.blocks_to_swap_out == []
    # Nothing is copied.
    assert output.blocks_to_copy == []
607
608


609
def test_schedule_decode_blocks_to_copy_update():
    """
    Verify blocks_to_copy is updated.
    """
    block_size = 4
    # Use the local block_size for the scheduler too, so the scheduler and
    # the dummy prompt cannot drift apart if the value is changed.
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=16,
                                     num_gpu_blocks=16)
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       best_of=2,
                                       block_size=block_size)
    curr_loras = None
    scheduler._allocate_and_set_running(seq_group)
    append_new_token_seq_group(60, seq_group, 1)
    scheduler._add_seq_group_to_running(seq_group)

    # Mock append_slots so that it reports one block pair to copy.
    scheduler.block_manager.append_slots = MagicMock()
    scheduler.block_manager.append_slots.return_value = [(2, 3)]

    budget = create_token_budget()
    output = scheduler._schedule_running(budget, curr_loras)
    remaining_running = scheduler.running
    assert len(remaining_running) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    # Nothing is preempted or swapped out.
    assert len(output.preempted) == 0
    assert len(output.swapped_out) == 0
    assert output.blocks_to_swap_out == []
    # Since append_slots returns the source -> dest mapping, it should be
    # applied to blocks_to_copy.
    assert output.blocks_to_copy == [(2, 3)]
643
644


645
def test_schedule_swapped_max_loras():
    """Check that max_loras=1 lets only one swapped LoRA group be scheduled.

    Two swapped-out groups with different LoRA ids are prepared; one is
    expected to be scheduled and one to stay swapped.
    """
    block_size = 4
    lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
    scheduler = initialize_scheduler(lora_config=lora_config,
                                     block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras: Set[int] = set()
    blocks_to_swap_out: List[Tuple[int, int]] = []

    # Allocate each group, then swap it out and register it as swapped.
    for idx in (0, 1):
        lora_request = LoRARequest(lora_name=str(idx),
                                   lora_int_id=idx + 1,
                                   lora_path="abc")
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=block_size,
                                       lora_request=lora_request)
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._swap_out(group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(group)

    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    assert len(scheduler.swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 1
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert len(curr_loras) == 1


678
def test_schedule_swapped_cannot_swap_in():
    """Check that swapped groups stay swapped when can_swap_in says LATER."""
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None
    blocks_to_swap_out: List[Tuple[int, int]] = []

    # Allocate each group, then swap it out and register it as swapped.
    for request_id in map(str, range(2)):
        _, group = create_dummy_prompt(request_id,
                                       prompt_length=60,
                                       best_of=2,
                                       block_size=block_size)
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._swap_out(group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(group)

    # can_swap_in is mocked to always answer LATER.
    scheduler.block_manager.can_swap_in = MagicMock(
        return_value=AllocStatus.LATER)

    # Since we cannot swap in, none of the requests are swapped in.
    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    assert len(scheduler.swapped) == 2
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0


709
def test_infeasible_swap():
    """Check that groups are reported infeasible when can_swap_in is NEVER."""
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None
    blocks_to_swap_out: List[Tuple[int, int]] = []

    # Allocate each group, then swap it out and register it as swapped.
    for request_id in map(str, range(2)):
        _, group = create_dummy_prompt(request_id,
                                       prompt_length=60,
                                       best_of=2,
                                       block_size=block_size)
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._swap_out(group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(group)

    # can_swap_in is mocked to always answer NEVER.
    scheduler.block_manager.can_swap_in = MagicMock(
        return_value=AllocStatus.NEVER)

    # Nothing is swapped in; both groups leave the swapped queue and are
    # listed as infeasible.
    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    assert len(scheduler.swapped) == 0
    assert len(output.infeasible_seq_groups) == 2
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0
739
740


741
def test_schedule_swapped_blocks_to_copy():
    """Block copies returned by ``append_slots`` surface in the swap-in output.

    A single sequence group is allocated, given one decode token, swapped
    out and queued as swapped. ``append_slots`` is mocked to return
    ``[(2, 3)]``; after ``_schedule_swapped`` the group is expected to be
    scheduled as a decode and ``blocks_to_copy`` to equal the mocked value.
    """
    block_size = 4
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=32,
        num_gpu_blocks=32,
    )

    # Build one request and move it into the swapped queue.
    _, group = create_dummy_prompt(
        "1",
        prompt_length=60,
        best_of=2,
        block_size=block_size,
    )
    scheduler._allocate_and_set_running(group)
    append_new_token_seq_group(60, group, 1)
    swap_out_mapping: List[Tuple[int, int]] = []
    scheduler._swap_out(group, swap_out_mapping)
    scheduler._add_seq_group_to_swapped(group)

    # Force the block manager to report a copy from block 2 to block 3
    # whenever slots are appended.
    expected_copies = [(2, 3)]
    scheduler.block_manager.append_slots = MagicMock(
        return_value=expected_copies)

    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, None)

    # The group was swapped back in and scheduled as a decode.
    assert len(scheduler.swapped) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert output.blocks_to_copy == [(2, 3)]
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811


def test_scheduling_budget():
    """Check SchedulingBudget capacity queries and per-request accounting.

    Covers ``can_schedule`` against both limits, and verifies that the
    add/subtract calls for batched tokens and for sequence counts take
    effect only once per request id.
    """
    token_limit = 4
    seq_limit = 4
    budget = SchedulingBudget(token_budget=token_limit, max_num_seqs=seq_limit)

    # With nothing consumed, requests fit up to (and including) both limits
    # and are rejected as soon as either limit is exceeded.
    capacity_cases = (
        (1, 1, True),
        (4, 4, True),
        (1, 5, False),
        (5, 1, False),
        (5, 5, False),
    )
    for new_tokens, new_seqs, should_fit in capacity_cases:
        fits = budget.can_schedule(num_new_tokens=new_tokens,
                                   num_new_seqs=new_seqs)
        assert bool(fits) == should_fit
    assert budget.remaining_token_budget() == token_limit

    # --- Batched-token accounting -------------------------------------
    _, seq_group = create_dummy_prompt("1", 3)
    request_id = seq_group.request_id

    budget.add_num_batched_tokens(request_id, 2)
    assert budget.remaining_token_budget() == 2
    assert budget.num_batched_tokens == 2
    assert budget.can_schedule(num_new_tokens=2, num_new_seqs=1)
    assert not budget.can_schedule(num_new_tokens=3, num_new_seqs=1)

    # Adding tokens again for the same request id must change nothing.
    budget.add_num_batched_tokens(request_id, 2)
    assert budget.remaining_token_budget() == 2
    assert budget.num_batched_tokens == 2

    # The first subtraction releases the tokens; repeating it is a no-op.
    for _ in range(2):
        budget.subtract_num_batched_tokens(request_id, 2)
        assert budget.remaining_token_budget() == 4
        assert budget.num_batched_tokens == 0

    # --- Sequence-count accounting ------------------------------------
    _, seq_group = create_dummy_prompt("1", 3)
    request_id = seq_group.request_id

    budget.add_num_seqs(request_id, 2)
    assert budget.can_schedule(num_new_tokens=1, num_new_seqs=2)
    assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=3)
    assert budget.num_curr_seqs == 2

    # Adding seqs again for the same request id must change nothing.
    budget.add_num_seqs(request_id, 2)
    assert budget.num_curr_seqs == 2

    # The first subtraction releases the seqs; repeating it is a no-op.
    for _ in range(2):
        budget.subtract_num_seqs(request_id, 2)
        assert budget.num_curr_seqs == 0
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973


@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_prefix_caching_aware_prefills(enable_prefix_caching):
    """
    Prefill scheduling should account for tokens already in the prefix cache.

    Three sequences (seqA, seqB, seqC) share their first block as a prefix.

    Expected flow:
    1.  seqA is prefilled and then decoded on its own.
    2.  With prefix caching enabled, seqB and seqC are prefilled together in
    one schedule round, although the token budget would not cover both
    prompts if the shared prefix were not taken into account. With prefix
    caching disabled, only one of them is scheduled in that round.
    """

    block_size = 4
    token_budget = 12
    seq_limit = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=token_budget,
        max_num_seqs=seq_limit,
        max_model_len=token_budget,
        enable_prefix_caching=enable_prefix_caching,
    )

    # seqB and seqC reuse the first `num_shared_tokens` tokens of seqA and
    # then diverge with four unique tokens each.
    num_shared_tokens = 4
    seqA_tokens = list(range(8))
    shared_prefix = seqA_tokens[:num_shared_tokens]
    seqB_tokens = shared_prefix + list(range(12, 16))
    seqC_tokens = shared_prefix + list(range(16, 20))

    _, seqA_group = create_dummy_prompt("0",
                                        prompt_tokens=seqA_tokens,
                                        block_size=block_size)
    _, seqB_group = create_dummy_prompt("1",
                                        prompt_tokens=seqB_tokens,
                                        block_size=block_size)
    _, seqC_group = create_dummy_prompt("2",
                                        prompt_tokens=seqC_tokens,
                                        block_size=block_size)

    # Round 1: seqA's whole prompt is prefilled.
    scheduler.add_seq_group(seqA_group)
    metas, out, _ = scheduler.schedule()
    scheduled = out.scheduled_seq_groups
    assert len(scheduled) == 1
    assert scheduled[0].seq_group == seqA_group
    assert scheduled[0].token_chunk_size == len(seqA_tokens)

    # Round 2: seqA decodes a single token.
    append_new_token_seq_group(len(seqA_tokens), seqA_group, 999)
    metas, out, _ = scheduler.schedule()
    scheduled = out.scheduled_seq_groups
    assert len(scheduled) == 1
    assert scheduled[0].seq_group == seqA_group
    assert scheduled[0].token_chunk_size == 1

    # Round 3: seqB and seqC arrive together.
    scheduler.add_seq_group(seqB_group)
    scheduler.add_seq_group(seqC_group)
    metas, out, _ = scheduler.schedule()
    scheduled = out.scheduled_seq_groups

    if not enable_prefix_caching:
        # Without caching only one of the two prefills is scheduled, and it
        # has no precomputed blocks.
        assert len(scheduled) == 1
        assert len(metas) == 1
        assert metas[0].token_chunk_size == 8
        assert len(metas[0].computed_block_nums) == 0
        return

    # With caching both prefills are scheduled in the same round.
    assert len(scheduled) == 2
    assert {item.seq_group for item in scheduled} == {seqB_group, seqC_group}
    assert len(metas) == 2
    # The shared prefix covers exactly one block (4 tokens // block size 4).
    shared_blocks = num_shared_tokens // block_size
    for meta in metas:
        assert meta.token_chunk_size == 8
        assert len(meta.computed_block_nums) == shared_blocks


def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
):
    """
    This test verifies that we don't schedule new prefills if there's already
    a continuous prefill in progress even though the new prefills with shared
    prefix can fit in the token budget:

    - SeqA is being chunked prefill.
    - SeqB with the same prompt shouldn't be scheduled for prefill even though
    there's enough token budget to prefill the cached tokens.
    - Neither should seqC be scheduled.

    - When seqA is in decoding phase, seqB and seqC can be scheduled.
        - Entire seqB should be prefilled since it's a full prefix cache hit.
        - SeqC would be partially prefilled with the prefix shared, and the
        remaining unique tokens would be prefilled (rounded down to be
        block-size aligned).
    """

    block_size = 2
    max_num_batched_tokens = 4
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=max_num_batched_tokens,
        max_num_seqs=max_seq_group,
        max_model_len=100,
        enable_prefix_caching=True,
        enable_chunked_prefill=True,
    )

    # seqB repeats seqA's prompt exactly; seqC shares only the first
    # `seqC_shared_prefix_len` tokens and then has 8 unique tokens.
    seqA_tokens = list(range(8))
    seqB_tokens = seqA_tokens
    seqC_shared_prefix_len = 4
    seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20))

    seqA, seqA_group = create_dummy_prompt("0",
                                           prompt_tokens=seqA_tokens,
                                           block_size=block_size)
    _, seqB_group = create_dummy_prompt("1",
                                        prompt_tokens=seqB_tokens,
                                        block_size=block_size)

    # Chunked prefill seqA: the first chunk uses the whole 4-token budget.
    scheduler.add_seq_group(seqA_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # seqB should not be scheduled while seqA's prefill is still ongoing;
    # this round only carries seqA's second chunk.
    scheduler.add_seq_group(seqB_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # Both seqB and seqC can be scheduled now that seqA's prefill is over
    # and seqA is in the decoding phase.
    append_new_token_seq(seqA, 999)
    _, seqC_group = create_dummy_prompt("2",
                                        prompt_tokens=seqC_tokens,
                                        block_size=block_size)
    scheduler.add_seq_group(seqC_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 3

    metas_by_request = {meta.request_id: meta for meta in metas}
    # seqA decodes one token.
    assert metas_by_request[seqA_group.request_id].token_chunk_size == 1
    # seqB is a full prefix-cache hit, so its whole prompt is scheduled.
    assert metas_by_request[seqB_group.request_id].token_chunk_size == 8
    # The message is parenthesized so that all string pieces are implicitly
    # concatenated into the assert message; without the parentheses only the
    # first piece is the message and the rest are no-op statements.
    assert metas_by_request[seqC_group.request_id].token_chunk_size == 6, (
        "A partial prefix of C (4 tokens) should be prefilled, with the "
        "remaining tokens fit into 3 token budget (4-1 from the seqA). "
        "It will then be rounded down to 2 tokens on block size, thus 6 "
        "tokens in total.")