# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from collections import deque
from typing import Optional
from unittest.mock import MagicMock

import pytest  # noqa
import torch
from torch import Use  # noqa

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroup, SequenceStatus

from .utils import (append_new_token, append_new_token_seq,
                    append_new_token_seq_group, create_dummy_prompt,
                    get_sequence_groups, schedule_and_update_computed_tokens)


def test_scheduler_add_seq_group():
    """Each added sequence group raises the unfinished-group count by one."""
    block_size = 4
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=1,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, cache_dtype="auto")
    cache_cfg.num_cpu_blocks = 4
    cache_cfg.num_gpu_blocks = 4
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Groups are only added here; no scheduling step is run.
    for expected_count, request_id in enumerate(map(str, range(4)), start=1):
        _, group = create_dummy_prompt(request_id,
                                       block_size,
                                       block_size=block_size)
        scheduler.add_seq_group(group)
        assert scheduler.get_num_unfinished_seq_groups() == expected_count


def test_scheduler_abort_seq_group():
    """Aborting every added request id leaves no unfinished groups."""
    block_size = 4
    num_groups = 4
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=1,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 4
    cache_cfg.num_gpu_blocks = 4
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Queue the groups in id order and remember every id for the abort.
    ids_to_abort: set[str] = {str(n) for n in range(num_groups)}
    for request_id in map(str, range(num_groups)):
        _, group = create_dummy_prompt(request_id, block_size)
        scheduler.add_seq_group(group)

    assert scheduler.get_num_unfinished_seq_groups() == num_groups
    scheduler.abort_seq_group(ids_to_abort)
    assert scheduler.get_num_unfinished_seq_groups() == 0


def test_scheduler_schedule_simple():
    """All prompts prefill together, then every group decodes one token."""
    block_size = 4
    num_seq_group = 4
    max_model_len = 16
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=num_seq_group,
        max_model_len=max_model_len,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 8
    cache_cfg.num_gpu_blocks = 8
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    added: list[SequenceGroup] = []
    for idx in range(num_seq_group):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=block_size,
                                       block_size=block_size)
        scheduler.add_seq_group(group)
        added.append(group)

    def step_and_check(expected_tokens: int) -> None:
        # One scheduling step: every group must be included, no blocks are
        # copied or swapped, then one "generated" token is appended.
        metas, outputs = schedule_and_update_computed_tokens(scheduler)
        assert set(get_sequence_groups(outputs)) == set(added)
        assert outputs.num_batched_tokens == expected_tokens
        assert not (outputs.blocks_to_copy or outputs.blocks_to_swap_in
                    or outputs.blocks_to_swap_out)
        assert len(metas) == num_seq_group
        append_new_token(outputs, 1)

    # Prefill step: each prompt contributes its full block_size tokens.
    step_and_check(block_size * num_seq_group)
    # Decode step: one token per sequence group.
    step_and_check(num_seq_group)


def test_scheduler_prefill_prioritized():
    """Verify running batched tokens are not applied to prefill requests."""
    block_size = 4
    max_model_len = 30
    max_batched_num_tokens = 30
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=max_batched_num_tokens,
        max_num_seqs=2,
        max_model_len=max_model_len,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 16
    cache_config.num_gpu_blocks = 16
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # Add a 1-token prompt A and run its prefill.
    _, seq_group_a = create_dummy_prompt("1", 1, block_size=block_size)
    scheduler.add_seq_group(seq_group_a)

    _, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [seq_group_a]

    # Add a new prefill request B whose prompt alone fills the whole budget.
    _, seq_group_b = create_dummy_prompt("2",
                                         max_batched_num_tokens,
                                         block_size=block_size)
    scheduler.add_seq_group(seq_group_b)

    # Verify prefill requests are prioritized. max_batched_num_tokens is 30
    # and B's prompt is 30 tokens. B can only be scheduled if A's running
    # tokens are not charged against the prefill budget. So B is scheduled
    # first, on its own.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [seq_group_b]


def test_scheduler_schedule_preempt_abort():
    """B is preempted when blocks run out; aborting A lets B be recomputed."""
    block_size = 4
    max_model_len = 16
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=2,
        max_model_len=max_model_len,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    # Two blocks: one per block_size-token prompt. The decode step below is
    # expected to preempt B.
    cache_cfg.num_cpu_blocks = 2
    cache_cfg.num_gpu_blocks = 2
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    def moves_no_blocks(outputs) -> bool:
        # True when the step neither copies nor swaps any block.
        return not (outputs.blocks_to_copy or outputs.blocks_to_swap_in
                    or outputs.blocks_to_swap_out)

    _, group_a = create_dummy_prompt("1", block_size, block_size=block_size)
    _, group_b = create_dummy_prompt("2", block_size, block_size=block_size)
    for group in (group_a, group_b):
        scheduler.add_seq_group(group)

    # Step 1: both prompts are prefilled together.
    metas, outputs = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(outputs) == [group_a, group_b]
    assert outputs.num_batched_tokens == 2 * block_size  # A and B prompts
    assert moves_no_blocks(outputs)
    assert len(metas) == 2
    assert scheduler.get_num_unfinished_seq_groups() == 2

    # Append "generated" tokens so the prompt tokens count as processed.
    append_new_token(outputs, 1)

    # Step 2: only A decodes; B is preempted but remains unfinished.
    metas, outputs = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(outputs) == [group_a]
    assert outputs.num_batched_tokens == 1
    assert moves_no_blocks(outputs)
    assert len(metas) == 1
    assert scheduler.get_num_unfinished_seq_groups() == 2
    assert outputs.preempted == 1

    # Step 3: abort A; B's prompt is re-scheduled with recomputation.
    scheduler.abort_seq_group("1")
    metas, outputs = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(outputs) == [group_b]
    assert outputs.num_batched_tokens == 5  # 4 prompt + 1 generation.
    assert moves_no_blocks(outputs)
    assert len(metas) == 1
    assert scheduler.get_num_unfinished_seq_groups() == 1


def test_scheduler_max_seqs():
    """Verify ``max_num_seqs`` limits how many new prompts are admitted."""
    block_size = 4
    num_seq_group = 4
    max_seq_group = 2
    max_model_len = 16
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=max_seq_group,
        max_model_len=max_model_len,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # Create the seq groups up front; they are added to the scheduler
    # incrementally below.
    all_seq_groups: list[SequenceGroup] = []
    for i in range(num_seq_group):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=block_size,
                                           block_size=block_size)
        all_seq_groups.append(seq_group)

    # Add 1 seq group and schedule its prompt.
    scheduler.add_seq_group(all_seq_groups[0])
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {all_seq_groups[0]}
    append_new_token(out, 1)

    # Schedule its first generation step.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {all_seq_groups[0]}
    append_new_token(out, 1)

    # Add 2 more seq groups.
    scheduler.add_seq_group(all_seq_groups[1])
    scheduler.add_seq_group(all_seq_groups[2])

    # Only 1 new prompt can be scheduled: max_seq_group is 2 and seq group 0
    # is already running, which leaves room for exactly one more.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {all_seq_groups[1]}


def test_scheduler_delay_factor():
    """A new prompt is held back until the ``delay_factor`` wait has passed."""
    block_size = 4
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=16,
        delay_factor=0.5,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    # Schedule the first prompt.
    _, seq_group = create_dummy_prompt("0",
                                       prompt_length=block_size,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups > 0
    assert seq_group_meta[0].request_id == '0'
    append_new_token(out, 1)

    # Wait for a second before adding the next prompt.
    time.sleep(1)
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=block_size,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)

    # The second prompt should *not* be scheduled yet; only request 0 runs.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups == 0
    assert seq_group_meta[0].request_id == '0'
    append_new_token(out, 1)

    # Wait for more than 0.5 second and try again; now the prompt runs.
    # NOTE(review): the 0.5 s threshold presumably comes from
    # delay_factor=0.5 times the ~1 s prompt gap above; confirm in Scheduler.
    time.sleep(0.6)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups > 0
    assert seq_group_meta[0].request_id == '1'
    append_new_token(out, 1)


def initialize_scheduler(
    *,
    max_num_seqs=1000,
    max_token_budget=1000,
    max_model_len=1000,
    lora_config=None,
    block_size=4,
    num_cpu_blocks=8,
    num_gpu_blocks=8,
    enable_prefix_caching=False,
    enable_chunked_prefill=False,
):
    """Build a ``Scheduler`` for the "generate" runner with test defaults.

    Args:
        max_num_seqs: Forwarded to ``SchedulerConfig(max_num_seqs=...)``.
        max_token_budget: Forwarded as
            ``SchedulerConfig(max_num_batched_tokens=...)``.
        max_model_len: Forwarded to ``SchedulerConfig(max_model_len=...)``.
        lora_config: Optional LoRA config passed to ``Scheduler``.
        block_size: First positional argument of ``CacheConfig``.
        num_cpu_blocks: Assigned to ``cache_config.num_cpu_blocks``.
        num_gpu_blocks: Assigned to ``cache_config.num_gpu_blocks``.
        enable_prefix_caching: Forwarded to ``CacheConfig``.
        enable_chunked_prefill: Forwarded to ``SchedulerConfig``.

    Returns:
        The newly constructed ``Scheduler``.
    """
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=max_token_budget,
        max_num_seqs=max_num_seqs,
        max_model_len=max_model_len,
        enable_chunked_prefill=enable_chunked_prefill,
    )
    cache_config = CacheConfig(
        block_size,
        1.0,
        1,
        "auto",
        enable_prefix_caching=enable_prefix_caching,
    )
    # Fix the block counts explicitly so tests control cache capacity.
    cache_config.num_cpu_blocks = num_cpu_blocks
    cache_config.num_gpu_blocks = num_gpu_blocks
    return Scheduler(scheduler_config, cache_config, lora_config)


def create_token_budget(token_budget: int = 10000,
                        max_num_seqs: int = 10000) -> SchedulingBudget:
    """Return a fresh ``SchedulingBudget`` with generous default limits."""
    limits = dict(token_budget=token_budget, max_num_seqs=max_num_seqs)
    return SchedulingBudget(**limits)


def add_token_budget(budget: SchedulingBudget,
                     num_batched_tokens: int = 0,
                     num_curr_seqs: int = 0):
    """Pre-charge ``budget`` as if another request had already used it.

    The usage is recorded under the request id of a throwaway dummy prompt
    (id ``'10'``), so it does not overlap with the tests' own request ids.
    """
    owner_id = create_dummy_prompt('10', prompt_length=60)[1].request_id
    budget.add_num_batched_tokens(owner_id, num_batched_tokens)
    budget.add_num_seqs(owner_id, num_curr_seqs)


def test_prefill_schedule_max_prompt_len():
    """A prompt longer than ``max_model_len`` is ignored, not scheduled.

    The ignored group leaves the waiting queue and nothing is charged to the
    token or sequence budget.
    """
    block_size = 4
    scheduler = initialize_scheduler(max_model_len=30, block_size=block_size)
    # 60 prompt tokens against a max_model_len of 30.
    _, oversized = create_dummy_prompt("0",
                                       prompt_length=60,
                                       block_size=block_size)
    scheduler.add_seq_group(oversized)

    budget = create_token_budget()
    result = scheduler._schedule_prefills(budget, None)

    assert (len(result.ignored_seq_groups), len(result.seq_groups)) == (1, 0)
    assert (budget.num_batched_tokens, budget.num_curr_seqs) == (0, 0)
    assert len(scheduler.waiting) == 0


def test_prefill_schedule_token_budget():
    """Verify prefill scheduling respects the token budget.

    Covers an empty budget, a budget that fits exactly one prompt, and a
    budget that was already partially consumed before the scheduling pass.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(token_budget=0)
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)

    # 0 token budget == nothing is scheduled.
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 2

    # 60 token budget == 1 request scheduled.
    budget = create_token_budget(token_budget=60)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 60
    assert budget.num_curr_seqs == 1
    assert len(remaining_waiting) == 1

    # Tokens already charged to the budget are respected. This uses a fresh
    # scheduler, so the explicit request id "1" cannot collide with the
    # groups added above.
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=16,
                                     num_gpu_blocks=16)
    budget = create_token_budget(token_budget=60)
    add_token_budget(budget, 30, 0)
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)

    # 30 of 60 tokens are already used, so the 60-token prompt doesn't fit.
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 30
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 1

    # With a 90-token budget (30 already used) the prompt fits exactly.
    budget = create_token_budget(token_budget=90)
    add_token_budget(budget, 30, 0)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 90
    assert budget.num_curr_seqs == 1
    assert len(remaining_waiting) == 0


def test_prefill_schedule_max_seqs():
    """Verify the budget's ``max_num_seqs`` limit is respected.

    Covers both a fresh budget and one whose sequence slots were already
    consumed before the scheduling pass.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(max_num_seqs=2)
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 1

    # Verify already-consumed sequence slots (num_curr_seqs) are respected.
    scheduler.waiting = deque()
    budget = create_token_budget(max_num_seqs=2)
    add_token_budget(budget, 0, 2)
    # Re-queue request "2", which was dropped when `waiting` was cleared.
    _, seq_group = create_dummy_prompt("2",
                                       prompt_length=60,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 1


def test_prefill_schedule_max_lora():
    """``max_loras`` is respected and a held-back LoRA request runs next.

    Requests 0 and 1 use distinct LoRA adapters; 2 and 3 use none. With
    ``max_loras=1`` the first pass schedules 0 and 2. Request 1, held back
    by the LoRA limit, is then scheduled next per FCFS order.
    """
    block_size = 4
    lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
    scheduler = initialize_scheduler(lora_config=lora_config,
                                     block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(token_budget=120)
    active_loras: set[int] = set()

    # Two LoRA requests (adapter ids 1 and 2)...
    for idx in (0, 1):
        adapter = LoRARequest(lora_name=str(idx),
                              lora_int_id=idx + 1,
                              lora_path="abc")
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=block_size,
                                       lora_request=adapter)
        scheduler.add_seq_group(group)
    # ...followed by two regular requests.
    for idx in (2, 3):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=block_size)
        scheduler.add_seq_group(group)

    # First pass: requests 0 and 2 fit (120 tokens, one LoRA slot).
    result = scheduler._schedule_prefills(budget, active_loras)
    assert len(result.ignored_seq_groups) == 0
    assert len(result.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(scheduler.waiting) == 2
    assert len(active_loras) == 1

    # Second pass, with an emptied LoRA set so request 1 can be admitted.
    active_loras = set()
    budget = create_token_budget(token_budget=60)
    result = scheduler._schedule_prefills(budget, active_loras)
    assert len(result.seq_groups) == 1
    assert result.seq_groups[0].seq_group.request_id == "1"
    assert len(scheduler.waiting) == 1
    assert len(active_loras) == 1
    assert budget.num_batched_tokens == 60


def test_prefill_schedule_no_block_manager_capacity():
    """Prefills wait on ``AllocStatus.LATER`` and are ignored on ``NEVER``."""
    block_size = 4

    def schedule_three_prompts(scheduler, alloc_status):
        # Queue three 60-token prompts, force block_manager.can_allocate to
        # return `alloc_status`, then run one prefill scheduling pass.
        budget = create_token_budget()
        for idx in range(3):
            _, group = create_dummy_prompt(str(idx),
                                           prompt_length=60,
                                           block_size=block_size)
            scheduler.add_seq_group(group)
        scheduler.block_manager.can_allocate = MagicMock(
            return_value=alloc_status)
        return scheduler._schedule_prefills(budget, None), budget

    # LATER: nothing is scheduled and every prompt keeps waiting.
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_gpu_blocks=128,
                                     num_cpu_blocks=128)
    result, budget = schedule_three_prompts(scheduler, AllocStatus.LATER)
    assert len(result.ignored_seq_groups) == 0
    assert len(result.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(scheduler.waiting) == 3

    # NEVER: every prompt is ignored and leaves the waiting queue.
    scheduler = initialize_scheduler()
    result, budget = schedule_three_prompts(scheduler, AllocStatus.NEVER)
    assert len(result.ignored_seq_groups) == 3
    assert len(result.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(scheduler.waiting) == 0


def test_decode_schedule_preempted():
    """Decodes that cannot get new slots are preempted rather than swapped.

    Request "1" is refused new slots. Request "0" still decodes, while "1"
    and the lowest-priority request "2" both end up preempted.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    for request_id in ("0", "1", "2"):
        _, group = create_dummy_prompt(request_id,
                                       prompt_length=60,
                                       block_size=block_size)
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._add_seq_group_to_running(group)

    def refuse_request_1(seq_group, num_lookahead_slots):
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots = MagicMock(
        side_effect=refuse_request_1)

    budget = create_token_budget()
    result = scheduler._schedule_running(budget, None)

    assert len(scheduler.running) == 0
    assert len(result.decode_seq_groups) == 1
    assert len(result.prefill_seq_groups) == 0
    assert result.decode_seq_groups[0].seq_group.request_id == "0"
    assert len(result.preempted) == 2
    # Only request "0"'s single decode token is charged to the budget.
    assert budget.num_batched_tokens == 1
    # NOTE: with chunked prefill disabled the num_seqs budget is not
    # updated, so budget.num_curr_seqs is intentionally not checked.
    # Both groups are preempted, not swapped out, and nothing is copied.
    assert result.blocks_to_swap_out == []
    assert result.blocks_to_copy == []


def test_schedule_decode_blocks_to_copy_update():
    """Verify blocks_to_copy is filled from block_manager.append_slots."""
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=16,
                                     num_gpu_blocks=16)
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       block_size=block_size)
    curr_loras = None
    scheduler._allocate_and_set_running(seq_group)
    append_new_token_seq_group(60, seq_group, 1)
    scheduler._add_seq_group_to_running(seq_group)

    # Make append_slots report a (source, destination) block copy.
    scheduler.block_manager.append_slots = MagicMock()
    scheduler.block_manager.append_slots.return_value = [(2, 3)]

    budget = create_token_budget()
    output = scheduler._schedule_running(budget, curr_loras)
    remaining_running = scheduler.running
    assert len(remaining_running) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    # Nothing is preempted or swapped out.
    assert len(output.preempted) == 0
    assert len(output.swapped_out) == 0
    assert output.blocks_to_swap_out == []
    # append_slots returns the source -> destination mapping, which should
    # be applied to blocks_to_copy.
    assert output.blocks_to_copy == [(2, 3)]


def test_schedule_swapped_max_loras():
    """Only one of two swapped LoRA requests is brought back per step.

    ``max_loras=1`` limits the pass to one adapter, so one group is
    scheduled for decode and the other remains in the swapped queue.
    """
    block_size = 4
    lora_config = LoRAConfig(max_lora_rank=8, max_loras=1)
    scheduler = initialize_scheduler(lora_config=lora_config,
                                     block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    active_loras: set[int] = set()
    swap_out_map: list[tuple[int, int]] = []
    for idx in range(2):
        adapter = LoRARequest(lora_name=str(idx),
                              lora_int_id=idx + 1,
                              lora_path="abc")
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=block_size,
                                       lora_request=adapter)
        # Allocate, append one token, then move the group to swapped.
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._swap_out(group, swap_out_map)
        scheduler._add_seq_group_to_swapped(group)

    budget = create_token_budget()
    result = scheduler._schedule_swapped(budget, active_loras)

    assert len(scheduler.swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 1
    assert len(result.decode_seq_groups) == 1
    assert len(result.prefill_seq_groups) == 0
    assert len(active_loras) == 1


def test_schedule_swapped_cannot_swap_in():
    """Swapped groups stay swapped while can_swap_in returns LATER."""
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None
    blocks_to_swap_out: list[tuple[int, int]] = []
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(seq_group)

    # Make the block manager refuse swap-in for now (AllocStatus.LATER).
    scheduler.block_manager.can_swap_in = MagicMock()
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER

    # Since we cannot swap in, none of the requests are swapped in and
    # nothing is charged to the budget.
    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 2
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0


def test_infeasible_swap():
    """Swapped groups that can NEVER be swapped in are reported infeasible."""
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None
    blocks_to_swap_out: list[tuple[int, int]] = []
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler._allocate_and_set_running(seq_group)
        append_new_token_seq_group(60, seq_group, 1)
        scheduler._swap_out(seq_group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(seq_group)

    # Make the block manager report that swap-in is never possible.
    scheduler.block_manager.can_swap_in = MagicMock()
    scheduler.block_manager.can_swap_in.return_value = AllocStatus.NEVER

    # Nothing is swapped in; both groups leave the swapped queue as
    # infeasible and nothing is charged to the budget.
    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
    assert len(remaining_swapped) == 0
    assert len(output.infeasible_seq_groups) == 2
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0


def test_schedule_swapped_blocks_to_copy():
743
    block_size = 4
744
    scheduler = initialize_scheduler(block_size=block_size,
745
746
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
747
    curr_loras = None
748
749
750
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       block_size=block_size)
751
    scheduler._allocate_and_set_running(seq_group)
752
    append_new_token_seq_group(60, seq_group, 1)
753
    blocks_to_swap_out: list[tuple[int, int]] = []
754
    scheduler._swap_out(seq_group, blocks_to_swap_out)
755
    scheduler._add_seq_group_to_swapped(seq_group)
756
757
758

    # The last request should be swapped out.
    scheduler.block_manager.append_slots = MagicMock()
759
    scheduler.block_manager.append_slots.return_value = [(2, 3)]
760
761

    budget = create_token_budget()
762
763
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped
764
    assert len(remaining_swapped) == 0
765
766
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
767
    assert output.blocks_to_copy == [(2, 3)]
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811


def test_scheduling_budget():
    """Exercise SchedulingBudget capacity checks and its per-request
    add/subtract bookkeeping for batched tokens and sequence counts."""
    TOKEN_BUDGET = 4
    MAX_SEQS = 4
    budget = SchedulingBudget(token_budget=TOKEN_BUDGET, max_num_seqs=MAX_SEQS)
    # Both limits must hold: exceeding either tokens or seqs is rejected.
    assert budget.can_schedule(num_new_tokens=1, num_new_seqs=1)
    assert budget.can_schedule(num_new_tokens=4, num_new_seqs=4)
    assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=5)
    assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=1)
    assert not budget.can_schedule(num_new_tokens=5, num_new_seqs=5)
    assert budget.remaining_token_budget() == TOKEN_BUDGET

    # Verify add/subtract num batched tokens.
    _, seq_group = create_dummy_prompt("1", 3)
    budget.add_num_batched_tokens(seq_group.request_id, 2)
    assert budget.remaining_token_budget() == 2
    assert budget.num_batched_tokens == 2
    assert budget.can_schedule(num_new_tokens=2, num_new_seqs=1)
    assert not budget.can_schedule(num_new_tokens=3, num_new_seqs=1)
    # Re-adding tokens for the same request id is a no-op.
    budget.add_num_batched_tokens(seq_group.request_id, 2)
    assert budget.remaining_token_budget() == 2
    assert budget.num_batched_tokens == 2
    budget.subtract_num_batched_tokens(seq_group.request_id, 2)
    assert budget.remaining_token_budget() == 4
    assert budget.num_batched_tokens == 0
    # Subtracting again for an already-removed request id is a no-op.
    budget.subtract_num_batched_tokens(seq_group.request_id, 2)
    assert budget.remaining_token_budget() == 4
    assert budget.num_batched_tokens == 0

    # Verify add/subtract max seqs.
    _, seq_group = create_dummy_prompt("1", 3)
    budget.add_num_seqs(seq_group.request_id, 2)
    assert budget.can_schedule(num_new_tokens=1, num_new_seqs=2)
    assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=3)
    assert budget.num_curr_seqs == 2
    # Re-adding seqs for the same request id is a no-op.
    budget.add_num_seqs(seq_group.request_id, 2)
    assert budget.num_curr_seqs == 2
    budget.subtract_num_seqs(seq_group.request_id, 2)
    assert budget.num_curr_seqs == 0
    # Subtracting again for an already-removed request id is a no-op.
    budget.subtract_num_seqs(seq_group.request_id, 2)
    assert budget.num_curr_seqs == 0


@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_prefix_caching_aware_prefills(enable_prefix_caching):
    """
    Test the below scenario:

    Three sequences, seqA, seqB and seqC, share their first block as a
    common prefix.

    The test verifies the below scenarios:
    1.  SeqA is scheduled first.
    2.  SeqB and SeqC can be prefilled together in a single schedule round
    even though there is not enough token budget to prefill both without
    considering prefix caching.
    """

    block_size = 4
    max_num_batched_tokens = 12
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=max_num_batched_tokens,
        max_num_seqs=max_seq_group,
        max_model_len=max_num_batched_tokens,
        enable_prefix_caching=enable_prefix_caching,
    )

    seqA_tokens = list(range(8))
    num_shared_tokens = 4
    # seqB and seqC reuse the first 4 tokens (one full block) of seqA.
    seqB_tokens = seqA_tokens[:num_shared_tokens] + list(range(12, 16))
    seqC_tokens = seqA_tokens[:num_shared_tokens] + list(range(16, 20))

    _, seqA_group = create_dummy_prompt("0",
                                        prompt_tokens=seqA_tokens,
                                        block_size=block_size)
    _, seqB_group = create_dummy_prompt("1",
                                        prompt_tokens=seqB_tokens,
                                        block_size=block_size)
    _, seqC_group = create_dummy_prompt("2",
                                        prompt_tokens=seqC_tokens,
                                        block_size=block_size)

    # Schedule seqA prefill.
    scheduler.add_seq_group(seqA_group)
    metas, out, _ = scheduler.schedule()
    assert (len(out.scheduled_seq_groups) == 1
            and out.scheduled_seq_groups[0].seq_group == seqA_group)
    assert out.scheduled_seq_groups[0].token_chunk_size == len(seqA_tokens)

    # Schedule seqA decode.
    append_new_token_seq_group(len(seqA_tokens), seqA_group, 999)
    metas, out, _ = scheduler.schedule()

    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 1

    # Schedule seqB and seqC prefills should work with prefix caching.
    # Without caching the two prefills need 16 tokens, which exceeds the
    # 12-token budget alongside seqA's decode.
    scheduler.add_seq_group(seqB_group)
    scheduler.add_seq_group(seqC_group)
    metas, out, _ = scheduler.schedule()

    if enable_prefix_caching:
        assert len(out.scheduled_seq_groups) == 2
        assert {
            out.scheduled_seq_groups[0].seq_group,
            out.scheduled_seq_groups[1].seq_group,
        } == {seqB_group, seqC_group}
        assert len(metas) == 2
        for meta in metas:
            assert meta.token_chunk_size == 8
            # One computed block: the 4 shared prefix tokens.
            assert (len(meta.computed_block_nums) == num_shared_tokens //
                    block_size)
    else:
        assert len(out.scheduled_seq_groups) == 1
        assert len(metas) == 1
        assert metas[0].token_chunk_size == 8
        assert len(metas[0].computed_block_nums) == 0  # No blocks computed.


def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
):
    """
    This test verifies that we don't schedule new prefills while a chunked
    prefill is already in progress, even though the new prefills with a
    shared prefix could fit in the token budget:

    - SeqA is being chunk-prefilled.
    - SeqB with the same prompt shouldn't be scheduled for prefill even though
    there's enough token budget to prefill the cached tokens.
    - SeqC is only added once seqA has finished its prefill.

    - When seqA is in decoding phase, seqB and seqC can be scheduled.
        - Entire seqB should be prefilled since it's a full prefix cache hit.
        - SeqC would be partially prefilled with the prefix shared, and the
        remaining unique tokens would be prefilled (rounded down to be
        block-size aligned).
    """

    block_size = 2
    max_num_batched_tokens = 4
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=max_num_batched_tokens,
        max_num_seqs=max_seq_group,
        max_model_len=100,
        enable_prefix_caching=True,
        enable_chunked_prefill=True,
    )

    seqA_tokens = list(range(8))
    seqB_tokens = seqA_tokens
    seqC_shared_prefix_len = 4
    seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20))

    seqA, seqA_group = create_dummy_prompt("0",
                                           prompt_tokens=seqA_tokens,
                                           block_size=block_size)
    _, seqB_group = create_dummy_prompt("1",
                                        prompt_tokens=seqB_tokens,
                                        block_size=block_size)

    # Chunked prefill seqA: the 4-token budget only fits half its prompt.
    scheduler.add_seq_group(seqA_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # seqB should not be scheduled while seqA's prefill is still ongoing.
    scheduler.add_seq_group(seqB_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # seqA's prefill is over and it is now decoding, so both seqB and seqC
    # can be scheduled.
    append_new_token_seq(seqA, 999)
    _, seqC_group = create_dummy_prompt("2",
                                        prompt_tokens=seqC_tokens,
                                        block_size=block_size)
    scheduler.add_seq_group(seqC_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 3

    metas = {meta.request_id: meta for meta in metas}
    assert metas[seqA_group.request_id].token_chunk_size == 1  # Decode
    # Fully cached prefill.
    assert metas[seqB_group.request_id].token_chunk_size == 8
    # The message is parenthesized so that all three string fragments are
    # concatenated into the assertion message. Without the parentheses the
    # assert statement ends after the first fragment.
    assert metas[seqC_group.request_id].token_chunk_size == 6, (
        "A partial prefix of C (4 tokens) should be prefilled, with the "
        "remaining tokens fit into 3 token budget (4-1 from the seqA). "
        "It will then be rounded down to 2 tokens on block size, thus 6 "
        "tokens in total.")


def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds():
    """
    Test that the scheduler never mixes prompt-token and prompt-embedding
    sequence groups in the same batch.

    Each scheduling round must contain only one kind of input. It must also
    be as large as possible for that kind, capped at ``max_num_seqs``.
    """
    block_size = 2
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_num_seqs=max_seq_group,
        max_model_len=100,
        enable_prefix_caching=True,
    )

    # The odd indexed inputs are passed in via token_ids (no embeddings),
    # the evens via prompt embeddings (with placeholder zero token ids).
    seq_length = 7
    embedding_size = 5
    num_seqs = 11
    seq_tokens: list[list[int]] = []
    seq_embeds: list[Optional[torch.Tensor]] = []
    for i in range(num_seqs):
        if i % 2:
            seq_tokens.append(list(range(seq_length)))
            seq_embeds.append(None)
        else:
            seq_tokens.append([0] * seq_length)
            seq_embeds.append(torch.rand(embedding_size))

    seq_and_seq_groups = [
        create_dummy_prompt(f"{i}",
                            prompt_tokens=seq_tokens[i],
                            prompt_embeds=seq_embeds[i],
                            block_size=block_size)
        for i in range(len(seq_tokens))
    ]

    for _, seq_group in seq_and_seq_groups:
        scheduler.add_seq_group(seq_group)

    # Keep scheduling rounds until every sequence has been finished below.
    while not all(seq.is_finished() for seq, _ in seq_and_seq_groups):
        unfinished_seq_groups = [
            seq_group for _, seq_group in seq_and_seq_groups
            if not seq_group.is_finished()
        ]
        _, out = schedule_and_update_computed_tokens(scheduler)
        assert len(out.scheduled_seq_groups) > 0
        # The first scheduled group determines which kind this batch holds.
        batch_is_prompt_embeds = out.scheduled_seq_groups[
            0].seq_group.uses_prompt_embeds()
        expected_scheduled_seq_groups = [
            seq_group for seq_group in unfinished_seq_groups
            if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds
        ]

        # We should have as many scheduled groups as possible, without mixing
        assert len(out.scheduled_seq_groups) == min(
            max_seq_group, len(expected_scheduled_seq_groups))
        assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() ==
                   batch_is_prompt_embeds
                   for scheduled_seq_group in out.scheduled_seq_groups)

        # Finish the scheduled groups so the next round picks up the rest.
        for scheduled_seq_group in out.scheduled_seq_groups:
            for seq in scheduled_seq_group.seq_group.seqs:
                seq.status = SequenceStatus.FINISHED_STOPPED
        scheduler.free_finished_seq_groups()