test_scheduler.py 40.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import time
4
from collections import deque
5
from typing import Optional
6
from unittest.mock import MagicMock
7

8
import pytest  # noqa
9
import torch
10
from torch import Use  # noqa
11

12
13
14
15
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
16
from vllm.sequence import SequenceGroup, SequenceStatus
17

18
19
20
from .utils import (append_new_token, append_new_token_seq,
                    append_new_token_seq_group, create_dummy_prompt,
                    get_sequence_groups, schedule_and_update_computed_tokens)
21
22


23
def test_scheduler_add_seq_group():
    """Adding sequence groups increments the unfinished-group count."""
    block_size = 4
    total_groups = 4

    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=1,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, cache_dtype="auto")
    cache_cfg.num_cpu_blocks = 4
    cache_cfg.num_gpu_blocks = 4
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # After the n-th insertion exactly n groups must be reported unfinished.
    for expected_unfinished in range(1, total_groups + 1):
        request_id = str(expected_unfinished - 1)
        _, group = create_dummy_prompt(request_id,
                                       block_size,
                                       block_size=block_size)
        scheduler.add_seq_group(group)
        assert scheduler.get_num_unfinished_seq_groups() == expected_unfinished


46
def test_scheduler_abort_seq_group():
    """Aborting every added request leaves no unfinished sequence groups."""
    block_size = 4
    total_groups = 4

    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=1,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 4
    cache_cfg.num_gpu_blocks = 4
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Register one sequence group per request id.
    ordered_ids = [str(idx) for idx in range(total_groups)]
    for request_id in ordered_ids:
        _, group = create_dummy_prompt(request_id, block_size)
        scheduler.add_seq_group(group)
    request_ids: set[str] = set(ordered_ids)

    # All groups are pending until the whole id set is aborted at once.
    assert scheduler.get_num_unfinished_seq_groups() == total_groups
    scheduler.abort_seq_group(request_ids)
    assert scheduler.get_num_unfinished_seq_groups() == 0


73
def test_scheduler_schedule_simple():
    """Schedule one prefill step and one decode step for a small batch.

    All groups are expected to be scheduled in both steps, with no block
    copies or swaps requested.
    """
    block_size = 4
    num_seq_group = 4
    max_model_len = 16

    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=num_seq_group,
        max_model_len=max_model_len,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 8
    cache_cfg.num_gpu_blocks = 8
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Queue the requests, remembering every group we expect to see scheduled.
    expected_groups: list[SequenceGroup] = []
    for idx in range(num_seq_group):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=block_size,
                                       block_size=block_size)
        scheduler.add_seq_group(group)
        expected_groups.append(group)

    # Prefill step: every prompt token of every group is batched.
    prompt_tokens = block_size * num_seq_group
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(expected_groups)
    assert out.num_batched_tokens == prompt_tokens
    assert not out.blocks_to_copy
    assert not out.blocks_to_swap_in
    assert not out.blocks_to_swap_out
    assert len(metadata) == num_seq_group
    append_new_token(out, 1)

    # Decode step: one token per group is batched.
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(expected_groups)
    assert out.num_batched_tokens == num_seq_group
    assert not out.blocks_to_copy
    assert not out.blocks_to_swap_in
    assert not out.blocks_to_swap_out
    assert len(metadata) == num_seq_group
    append_new_token(out, 1)


117
def test_scheduler_prefill_prioritized():
    """Verify running batched tokens are not applied to prefill requests."""
    block_size = 4
    max_model_len = 30
    max_batched_num_tokens = 30

    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=max_batched_num_tokens,
        max_num_seqs=2,
        max_model_len=max_model_len,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 16
    cache_cfg.num_gpu_blocks = 16
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Request A: a one-token prompt, scheduled on its own in the first step.
    _, group_a = create_dummy_prompt("1", 1, block_size=block_size)
    scheduler.add_seq_group(group_a)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [group_a]

    # Request B: a 30-token prompt, i.e. as long as the whole token budget
    # (max_batched_num_tokens == 30).
    _, group_b = create_dummy_prompt("2", 30, block_size=block_size)
    scheduler.add_seq_group(group_b)

    # Prefill is prioritized: the next step must contain only B, not A's
    # decode.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [group_b]
149
150


151
def test_scheduler_schedule_preempt_abort():
    """Exercise preemption followed by an abort.

    Two requests are prefilled with only two GPU blocks available; the decode
    step is expected to keep request "1" and preempt request "2". After "1"
    is aborted, "2" is expected to be scheduled again on its own.
    """
    block_size = 4
    max_model_len = 16

    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=2,
        max_model_len=max_model_len,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 2
    cache_cfg.num_gpu_blocks = 2
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Queue both requests; each prompt is exactly one block long.
    _, group_a = create_dummy_prompt("1",
                                     block_size,
                                     block_size=block_size)
    _, group_b = create_dummy_prompt("2",
                                     block_size,
                                     block_size=block_size)
    scheduler.add_seq_group(group_a)
    scheduler.add_seq_group(group_b)

    # Prefill step: both prompts are batched together.
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [group_a, group_b]
    assert out.num_batched_tokens == block_size * 2  # both prompts
    assert not out.blocks_to_copy
    assert not out.blocks_to_swap_in
    assert not out.blocks_to_swap_out
    assert len(metadata) == 2
    assert scheduler.get_num_unfinished_seq_groups() == 2

    # Appending a "generated" token lets the sequences mark their prompt
    # tokens as processed.
    append_new_token(out, 1)

    # Decode step: only group A runs, group B is preempted.
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [group_a]
    assert out.num_batched_tokens == 1
    assert not out.blocks_to_copy
    assert not out.blocks_to_swap_in
    assert not out.blocks_to_swap_out
    assert len(metadata) == 1
    assert scheduler.get_num_unfinished_seq_groups() == 2
    assert out.preempted == 1

    # Abort group A; group B's prompt is re-scheduled with recomputation.
    scheduler.abort_seq_group("1")
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert get_sequence_groups(out) == [group_b]
    assert out.num_batched_tokens == 5  # 4 prompt + 1 generation.
    assert not out.blocks_to_copy
    assert not out.blocks_to_swap_in
    assert not out.blocks_to_swap_out
    assert len(metadata) == 1
    assert scheduler.get_num_unfinished_seq_groups() == 1


209
def test_scheduler_max_seqs():
    """The scheduler never runs more than ``max_num_seqs`` groups at once."""
    block_size = 4
    num_seq_group = 4
    max_seq_group = 2
    max_model_len = 16

    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=64,
        max_num_seqs=max_seq_group,
        max_model_len=max_model_len,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 8
    cache_cfg.num_gpu_blocks = 8
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # Build the groups up front; they are handed to the scheduler in stages.
    all_seq_groups: list[SequenceGroup] = [
        create_dummy_prompt(str(idx),
                            prompt_length=block_size,
                            block_size=block_size)[1]
        for idx in range(num_seq_group)
    ]
    first, second, third = all_seq_groups[:3]

    # Stage 1: a single group goes through prefill and one decode step.
    scheduler.add_seq_group(first)

    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {first}
    append_new_token(out, 1)

    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {first}
    append_new_token(out, 1)

    # Stage 2: two more groups arrive.
    scheduler.add_seq_group(second)
    scheduler.add_seq_group(third)

    # Only one of the new prompts may be scheduled: max_seq_group is 2 and
    # one slot is already taken by the first group.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == {second}
255
256


257
def test_scheduler_delay_factor():
    """With ``delay_factor=0.5`` a new prompt is held back before prefill.

    The second prompt must not be scheduled immediately after it is added,
    but must be scheduled once a further 0.6 seconds have elapsed.
    """
    block_size = 4
    sched_cfg = SchedulerConfig(
        "generate",
        max_num_batched_tokens=100,
        max_num_seqs=64,
        max_model_len=16,
        delay_factor=0.5,
    )
    cache_cfg = CacheConfig(block_size, 1.0, 1, "auto")
    cache_cfg.num_cpu_blocks = 8
    cache_cfg.num_gpu_blocks = 8
    scheduler = Scheduler(sched_cfg, cache_cfg, None)

    # The first prompt is prefilled right away.
    _, first_group = create_dummy_prompt("0",
                                         prompt_length=block_size,
                                         block_size=block_size)
    scheduler.add_seq_group(first_group)
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups > 0
    assert metadata[0].request_id == '0'
    append_new_token(out, 1)

    # Let a second pass before the next prompt arrives.
    time.sleep(1)
    _, second_group = create_dummy_prompt("1",
                                          prompt_length=block_size,
                                          block_size=block_size)
    scheduler.add_seq_group(second_group)

    # The new prompt must *not* be scheduled yet; request '0' keeps decoding.
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups == 0
    assert metadata[0].request_id == '0'
    append_new_token(out, 1)

    # After waiting more than 0.5 seconds the second prompt is prefilled.
    time.sleep(0.6)
    metadata, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_prefill_groups > 0
    assert metadata[0].request_id == '1'
    append_new_token(out, 1)
300
301


302
303
304
305
306
307
308
309
310
def initialize_scheduler(
    *,
    max_num_seqs: int = 1000,
    max_token_budget: int = 1000,
    max_model_len: int = 1000,
    lora_config=None,
    block_size: int = 4,
    num_cpu_blocks: int = 8,
    num_gpu_blocks: int = 8,
    enable_prefix_caching: bool = False,
    enable_chunked_prefill: bool = False,
):
    """Build a ``Scheduler`` for tests from keyword-only settings.

    Args:
        max_num_seqs: Forwarded to ``SchedulerConfig`` as ``max_num_seqs``.
        max_token_budget: Forwarded to ``SchedulerConfig`` as
            ``max_num_batched_tokens``.
        max_model_len: Forwarded to ``SchedulerConfig`` as ``max_model_len``.
        lora_config: Passed unchanged as the third ``Scheduler`` argument;
            ``None`` by default.
        block_size: First positional argument of ``CacheConfig``.
        num_cpu_blocks: Assigned to ``cache_config.num_cpu_blocks``.
        num_gpu_blocks: Assigned to ``cache_config.num_gpu_blocks``.
        enable_prefix_caching: Forwarded to ``CacheConfig``.
        enable_chunked_prefill: Forwarded to ``SchedulerConfig``.

    Returns:
        The newly constructed ``Scheduler``.
    """
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens=max_token_budget,
        max_num_seqs=max_num_seqs,
        max_model_len=max_model_len,
        enable_chunked_prefill=enable_chunked_prefill,
    )
    cache_config = CacheConfig(
        block_size,
        1.0,
        1,
        "auto",
        enable_prefix_caching=enable_prefix_caching,
    )
    # The block counts are set on the config object after construction.
    cache_config.num_cpu_blocks = num_cpu_blocks
    cache_config.num_gpu_blocks = num_gpu_blocks
    return Scheduler(scheduler_config, cache_config, lora_config)


335
def create_token_budget(token_budget: int = 10000,
                        max_num_seqs: int = 10000) -> SchedulingBudget:
    """Return a fresh ``SchedulingBudget`` with the given limits."""
    limits = {
        "token_budget": token_budget,
        "max_num_seqs": max_num_seqs,
    }
    return SchedulingBudget(**limits)


343
344
345
346
347
348
349
350
351
def add_token_budget(budget: SchedulingBudget,
                     num_batched_tokens: int = 0,
                     num_curr_seqs: int = 0):
    """Record batched tokens and running seqs on ``budget``.

    Both amounts are booked under the request id of a throwaway dummy prompt
    (request id '10').
    """
    _, placeholder_group = create_dummy_prompt('10', prompt_length=60)
    request_id = placeholder_group.request_id
    budget.add_num_batched_tokens(request_id, num_batched_tokens)
    budget.add_num_seqs(request_id, num_curr_seqs)


352
def test_prefill_schedule_max_prompt_len():
    """
    A prompt longer than the model length limit is ignored, not scheduled.
    """
    block_size = 4
    scheduler = initialize_scheduler(max_model_len=30, block_size=block_size)

    # The 60-token prompt exceeds max_model_len=30.
    _, oversized_group = create_dummy_prompt("0",
                                             prompt_length=60,
                                             block_size=block_size)
    scheduler.add_seq_group(oversized_group)

    budget = create_token_budget()
    output = scheduler._schedule_prefills(budget, None)

    # The request ends up in the ignored list and consumes no budget.
    assert len(output.ignored_seq_groups) == 1
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(scheduler.waiting) == 0


372
def test_prefill_schedule_token_budget():
    """
    Test token budget respected.

    Covers a zero budget, a budget that fits exactly one prompt, and a
    budget that already has batched tokens recorded against it.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(token_budget=0)
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)

    # 0 token budget == nothing is scheduled.
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 2

    # 60 token budget == 1 request scheduled.
    budget = create_token_budget(token_budget=60)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 60
    assert budget.num_curr_seqs == 1
    assert len(remaining_waiting) == 1

    # Test when current_batched_tokens respected.
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=16,
                                     num_gpu_blocks=16)
    budget = create_token_budget(token_budget=60)
    add_token_budget(budget, 30, 0)
    # Explicit request id: the original reused the loop variable that leaked
    # out of the ``for`` above (its final value was 1).
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       block_size=block_size)
    # Cannot schedule a prompt that doesn't fit the budget
    # (30 tokens already used + 60 new > 60).
    scheduler.add_seq_group(seq_group)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 30
    assert budget.num_curr_seqs == 0
    assert len(remaining_waiting) == 1

    # With a 90 token budget the same prompt fits (30 used + 60 new == 90).
    budget = create_token_budget(token_budget=90)
    add_token_budget(budget, 30, 0)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.seq_groups) == 1
    assert budget.num_batched_tokens == 90
    assert budget.num_curr_seqs == 1
    assert len(remaining_waiting) == 0


434
def test_prefill_schedule_max_seqs():
    """
    Test max seq respected.

    Covers both the ``max_num_seqs`` cap on newly scheduled prefills and the
    case where the budget already records that many running sequences.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(max_num_seqs=2)
    for i in range(3):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)

    # Only two of the three waiting prompts fit under max_num_seqs=2.
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 1

    # Verify curr_num_seqs respected.
    scheduler.waiting = deque()
    budget = create_token_budget(max_num_seqs=2)
    add_token_budget(budget, 0, 2)
    # Explicit request id: the original reused the loop variable that leaked
    # out of the ``for`` above (its final value was 2).
    _, seq_group = create_dummy_prompt("2",
                                       prompt_length=60,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    output = scheduler._schedule_prefills(budget, None)
    remaining_waiting = scheduler.waiting
    # The budget already holds two seqs, so the new prompt stays waiting.
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 2
    assert len(remaining_waiting) == 1


473
def test_prefill_schedule_max_lora():
    """
    Test max lora is respected and prioritized.
    """
    block_size = 4
    lora_cfg = LoRAConfig(max_lora_rank=8, max_loras=1)
    scheduler = initialize_scheduler(lora_config=lora_cfg,
                                     block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    budget = create_token_budget(token_budget=120)
    curr_loras: set[int] = set()

    # Requests 0 and 1 each carry their own LoRA adapter.
    for idx in range(2):
        adapter = LoRARequest(lora_name=str(idx),
                              lora_int_id=idx + 1,
                              lora_path="abc")
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=block_size,
                                       lora_request=adapter)
        scheduler.add_seq_group(group)

    # Add two more requests to verify lora is prioritized.
    # 0: LoRA, 1: LoRA, 2: regular, 3: regular
    # In the first iteration, index 0, 2 is scheduled.
    # If a request is not scheduled because it hits max lora, it is
    # prioritized. Verify that.
    for idx in range(2, 4):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=block_size)
        scheduler.add_seq_group(group)

    # First pass: two requests are scheduled (0 and 2) and one LoRA is active.
    output = scheduler._schedule_prefills(budget, curr_loras)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 2
    assert budget.num_batched_tokens == 120
    assert budget.num_curr_seqs == 2
    assert len(scheduler.waiting) == 2
    assert len(curr_loras) == 1

    # The second lora request is scheduled next as FCFS policy.
    # Reset curr_loras so that it can be scheduled.
    curr_loras = set()
    budget = create_token_budget(token_budget=60)
    output = scheduler._schedule_prefills(budget, curr_loras)
    assert len(output.seq_groups) == 1
    assert output.seq_groups[0].seq_group.request_id == "1"
    assert len(scheduler.waiting) == 1
    assert len(curr_loras) == 1
    assert budget.num_batched_tokens == 60


526
def test_prefill_schedule_no_block_manager_capacity():
    """
    Test sequence cannot be scheduled due to block manager has no capacity.

    ``AllocStatus.LATER`` must leave every request waiting, while
    ``AllocStatus.NEVER`` must move every request to the ignored list.
    """
    block_size = 4

    # Case 1: the block manager reports LATER for every allocation.
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_gpu_blocks=128,
                                     num_cpu_blocks=128)
    budget = create_token_budget()
    for idx in range(3):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=block_size)
        scheduler.add_seq_group(group)
    scheduler.block_manager.can_allocate = MagicMock(
        return_value=AllocStatus.LATER)

    output = scheduler._schedule_prefills(budget, None)
    assert len(output.ignored_seq_groups) == 0
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(scheduler.waiting) == 3

    # Case 2: a fresh default scheduler whose block manager reports NEVER.
    scheduler = initialize_scheduler()
    budget = create_token_budget()
    for idx in range(3):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=block_size)
        scheduler.add_seq_group(group)
    scheduler.block_manager.can_allocate = MagicMock(
        return_value=AllocStatus.NEVER)

    output = scheduler._schedule_prefills(budget, None)
    assert len(output.ignored_seq_groups) == 3
    assert len(output.seq_groups) == 0
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(scheduler.waiting) == 0


568
def test_decode_schedule_preempted():
    """
    Test decodes cannot be scheduled and preempted.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=64,
                                     num_gpu_blocks=64)
    curr_loras = None

    # Put three groups straight into the running queue, each past prefill.
    for idx in range(3):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=block_size)
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._add_seq_group_to_running(group)

    def cannot_append_second_group(seq_group, num_lookahead_slots):
        # Only request "1" is refused additional slots.
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots = MagicMock(
        side_effect=cannot_append_second_group)

    # 1 cannot be scheduled, and the lowest priority (request 2)
    # should be preempted. 1 will also be preempted.
    budget = create_token_budget()
    output = scheduler._schedule_running(budget, curr_loras)
    remaining_running = scheduler.running
    assert len(remaining_running) == 0

    # Only request "0" is left to decode; the other two were preempted.
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert output.decode_seq_groups[0].seq_group.request_id == "0"
    assert len(output.preempted) == 2

    # Verify budgets are updated.
    assert budget.num_batched_tokens == 1
    # NOTE: When enable_chunk is False, num_seqs budget is not updated.
    # assert budget.num_curr_seqs == 1

    # Both should be preempted, not swapped.
    assert output.blocks_to_swap_out == []
    # Nothing is copied.
    assert output.blocks_to_copy == []
610
611


612
def test_schedule_decode_blocks_to_copy_update():
    """
    Verify blocks_to_copy is updated.

    ``append_slots`` is mocked to return a single (source, destination)
    pair, which must show up in the scheduling output's ``blocks_to_copy``.
    """
    block_size = 4
    # Use the local block_size rather than a second hard-coded literal so the
    # scheduler and the dummy prompt cannot drift apart.
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=16,
                                     num_gpu_blocks=16)
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       block_size=block_size)
    curr_loras = None
    scheduler._allocate_and_set_running(seq_group)
    append_new_token_seq_group(60, seq_group, 1)
    scheduler._add_seq_group_to_running(seq_group)

    # Make the block manager report one block copy when slots are appended.
    scheduler.block_manager.append_slots = MagicMock()
    scheduler.block_manager.append_slots.return_value = [(2, 3)]

    budget = create_token_budget()
    output = scheduler._schedule_running(budget, curr_loras)
    remaining_running = scheduler.running
    assert len(remaining_running) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    # Nothing is preempted or swapped out.
    assert len(output.preempted) == 0
    assert len(output.swapped_out) == 0
    assert output.blocks_to_swap_out == []
    # Since append_slots returns the source -> dest mapping, it should be
    # applied.
    assert output.blocks_to_copy == [(2, 3)]
645
646


647
def test_schedule_swapped_max_loras():
    """With ``max_loras=1`` only one of two swapped LoRA requests is resumed."""
    block_size = 4
    lora_cfg = LoRAConfig(max_lora_rank=8, max_loras=1)
    scheduler = initialize_scheduler(lora_config=lora_cfg,
                                     block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras: set[int] = set()
    blocks_to_swap_out: list[tuple[int, int]] = []

    # Two requests with distinct adapters, both moved to the swapped queue.
    for idx in range(2):
        adapter = LoRARequest(lora_name=str(idx),
                              lora_int_id=idx + 1,
                              lora_path="abc")
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=block_size,
                                       lora_request=adapter)
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._swap_out(group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(group)

    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)

    # One request is resumed as a decode; the other stays swapped.
    assert len(scheduler.swapped) == 1
    assert budget.num_batched_tokens == 1
    assert budget.num_curr_seqs == 1
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert len(curr_loras) == 1


680
def test_schedule_swapped_cannot_swap_in():
    """Swapped requests stay swapped while ``can_swap_in`` reports LATER."""
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None
    blocks_to_swap_out: list[tuple[int, int]] = []

    # Move two requests into the swapped queue.
    for idx in range(2):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=block_size)
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._swap_out(group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(group)

    # Make the block manager refuse every swap-in for now.
    scheduler.block_manager.can_swap_in = MagicMock(
        return_value=AllocStatus.LATER)

    # Since we cannot swap in, none of the requests are swapped in.
    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    assert len(scheduler.swapped) == 2
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0


710
def test_infeasible_swap():
    """Swapped requests become infeasible when ``can_swap_in`` reports NEVER."""
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None
    blocks_to_swap_out: list[tuple[int, int]] = []

    # Move two requests into the swapped queue.
    for idx in range(2):
        _, group = create_dummy_prompt(str(idx),
                                       prompt_length=60,
                                       block_size=block_size)
        scheduler._allocate_and_set_running(group)
        append_new_token_seq_group(60, group, 1)
        scheduler._swap_out(group, blocks_to_swap_out)
        scheduler._add_seq_group_to_swapped(group)

    # Make the block manager report that swap-in can never succeed.
    scheduler.block_manager.can_swap_in = MagicMock(
        return_value=AllocStatus.NEVER)

    # Both requests leave the swapped queue as infeasible; none is scheduled.
    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    assert len(scheduler.swapped) == 0
    assert len(output.infeasible_seq_groups) == 2
    assert budget.num_batched_tokens == 0
    assert budget.num_curr_seqs == 0
    assert len(output.decode_seq_groups) == 0
    assert len(output.prefill_seq_groups) == 0
739
740


741
def test_schedule_swapped_blocks_to_copy():
    """Block pairs returned by ``append_slots`` reach ``blocks_to_copy``.

    A single sequence group is allocated, given one generated token, swapped
    out and queued as swapped. ``block_manager.append_slots`` is mocked to
    return ``[(2, 3)]``; after ``_schedule_swapped`` the group must be
    scheduled for decode and the output must carry that same pair.
    """
    block_size = 4
    scheduler = initialize_scheduler(block_size=block_size,
                                     num_cpu_blocks=32,
                                     num_gpu_blocks=32)
    curr_loras = None
    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       block_size=block_size)
    scheduler._allocate_and_set_running(seq_group)
    append_new_token_seq_group(60, seq_group, 1)
    blocks_to_swap_out: list[tuple[int, int]] = []
    scheduler._swap_out(seq_group, blocks_to_swap_out)
    scheduler._add_seq_group_to_swapped(seq_group)

    # Stub append_slots so it reports one (src, dst) block pair to copy.
    scheduler.block_manager.append_slots = MagicMock()
    scheduler.block_manager.append_slots.return_value = [(2, 3)]

    budget = create_token_budget()
    output = scheduler._schedule_swapped(budget, curr_loras)
    remaining_swapped = scheduler.swapped

    # The group is swapped back in as a decode and the stubbed pair is
    # forwarded unchanged.
    assert len(remaining_swapped) == 0
    assert len(output.decode_seq_groups) == 1
    assert len(output.prefill_seq_groups) == 0
    assert output.blocks_to_copy == [(2, 3)]
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810


def test_scheduling_budget():
    """Check ``SchedulingBudget`` capacity tests and per-request accounting.

    Covers ``can_schedule`` against the token and sequence limits, and
    verifies that repeating an add or a subtract for the same request id
    leaves the counters unchanged.
    """
    token_limit = 4
    seq_limit = 4
    budget = SchedulingBudget(token_budget=token_limit,
                              max_num_seqs=seq_limit)

    # (num_new_tokens, num_new_seqs, expected to fit) on an empty budget.
    empty_budget_cases = (
        (1, 1, True),
        (4, 4, True),
        (1, 5, False),
        (5, 1, False),
        (5, 5, False),
    )
    for new_tokens, new_seqs, fits in empty_budget_cases:
        verdict = budget.can_schedule(num_new_tokens=new_tokens,
                                      num_new_seqs=new_seqs)
        assert bool(verdict) is fits
    assert budget.remaining_token_budget() == token_limit

    def check_tokens(used):
        # Remaining and consumed token counts are checked as a pair.
        assert budget.remaining_token_budget() == token_limit - used
        assert budget.num_batched_tokens == used

    def check_seqs(expected):
        assert budget.num_curr_seqs == expected

    # --- Batched-token accounting -------------------------------------
    _, seq_group = create_dummy_prompt("1", 3)
    request_id = seq_group.request_id

    budget.add_num_batched_tokens(request_id, 2)
    check_tokens(used=2)
    assert budget.can_schedule(num_new_tokens=2, num_new_seqs=1)
    assert not budget.can_schedule(num_new_tokens=3, num_new_seqs=1)

    # A second add for the same request id changes nothing.
    budget.add_num_batched_tokens(request_id, 2)
    check_tokens(used=2)

    budget.subtract_num_batched_tokens(request_id, 2)
    check_tokens(used=0)

    # Likewise a second subtract for the same request id.
    budget.subtract_num_batched_tokens(request_id, 2)
    check_tokens(used=0)

    # --- Sequence-count accounting ------------------------------------
    _, seq_group = create_dummy_prompt("1", 3)
    request_id = seq_group.request_id

    budget.add_num_seqs(request_id, 2)
    assert budget.can_schedule(num_new_tokens=1, num_new_seqs=2)
    assert not budget.can_schedule(num_new_tokens=1, num_new_seqs=3)
    check_seqs(2)

    # A second add for the same request id changes nothing.
    budget.add_num_seqs(request_id, 2)
    check_seqs(2)

    budget.subtract_num_seqs(request_id, 2)
    check_seqs(0)

    budget.subtract_num_seqs(request_id, 2)
    check_seqs(0)
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972


@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_prefix_caching_aware_prefills(enable_prefix_caching):
    """
    Prefill scheduling should account for prefix-cache hits.

    Three sequences (seqA, seqB, seqC) share their first block as a prefix.

    Steps:
    1.  seqA is scheduled on its own (prefill, then one decode step).
    2.  seqB and seqC are added together. With prefix caching enabled both
    are prefilled in the same round, although the token budget would not
    cover both prompts if the shared prefix were ignored. With prefix
    caching disabled only one of them is scheduled.
    """

    block_size = 4
    max_num_batched_tokens = 12
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=max_num_batched_tokens,
        max_num_seqs=max_seq_group,
        max_model_len=max_num_batched_tokens,
        enable_prefix_caching=enable_prefix_caching,
    )

    # seqB and seqC reuse the first `num_shared_tokens` tokens of seqA and
    # then diverge with their own suffixes.
    num_shared_tokens = 4
    seqA_tokens = list(range(8))
    shared_prefix = seqA_tokens[:num_shared_tokens]
    seqB_tokens = shared_prefix + list(range(12, 16))
    seqC_tokens = shared_prefix + list(range(16, 20))

    _, seqA_group = create_dummy_prompt("0",
                                        prompt_tokens=seqA_tokens,
                                        block_size=block_size)
    _, seqB_group = create_dummy_prompt("1",
                                        prompt_tokens=seqB_tokens,
                                        block_size=block_size)
    _, seqC_group = create_dummy_prompt("2",
                                        prompt_tokens=seqC_tokens,
                                        block_size=block_size)

    # Round 1: seqA prefill, whole prompt in one chunk.
    scheduler.add_seq_group(seqA_group)
    _, out, _ = scheduler.schedule()
    assert len(out.scheduled_seq_groups) == 1
    first = out.scheduled_seq_groups[0]
    assert first.seq_group == seqA_group
    assert first.token_chunk_size == len(seqA_tokens)

    # Round 2: seqA decode, a single token.
    append_new_token_seq_group(len(seqA_tokens), seqA_group, 999)
    _, out, _ = scheduler.schedule()
    assert len(out.scheduled_seq_groups) == 1
    first = out.scheduled_seq_groups[0]
    assert first.seq_group == seqA_group
    assert first.token_chunk_size == 1

    # Round 3: seqB and seqC arrive together.
    scheduler.add_seq_group(seqB_group)
    scheduler.add_seq_group(seqC_group)
    metas, out, _ = scheduler.schedule()

    if not enable_prefix_caching:
        # Without caching only one prompt is scheduled and nothing is
        # reported as already computed.
        assert len(out.scheduled_seq_groups) == 1
        assert len(metas) == 1
        assert metas[0].token_chunk_size == 8
        assert len(metas[0].computed_block_nums) == 0
        return

    # With caching both prompts are scheduled in the same round.
    assert len(out.scheduled_seq_groups) == 2
    scheduled_groups = {
        out.scheduled_seq_groups[0].seq_group,
        out.scheduled_seq_groups[1].seq_group,
    }
    assert scheduled_groups == {seqB_group, seqC_group}
    assert len(metas) == 2

    # One block: the shared 4-token prefix.
    expected_cached_blocks = num_shared_tokens // block_size
    for meta in metas:
        assert meta.token_chunk_size == 8
        assert len(meta.computed_block_nums) == expected_cached_blocks


def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
):
    """
    This test verifies that we don't schedule new prefills if there's already
    a continuous prefill in progress even though the new prefills with shared
    prefix can fit in the token budget:

    - SeqA is being chunked prefill.
    - SeqB with the same prompt shouldn't be scheduled for prefill even though
    there's enough token budget to prefill the cached tokens.
    - Neither should seqC be scheduled.

    - When seqA is in decoding phase, seqB and seqC can be scheduled.
        - Entire seqB should be prefilled since it's a full prefix cache hit.
        - SeqC would be partially prefilled with the prefix shared, and the
        remaining unique tokens would be prefilled (rounded down to be
        block-size aligned).
    """

    block_size = 2
    max_num_batched_tokens = 4
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=max_num_batched_tokens,
        max_num_seqs=max_seq_group,
        max_model_len=100,
        enable_prefix_caching=True,
        enable_chunked_prefill=True,
    )

    # seqB repeats seqA's prompt; seqC shares only the first 4 tokens.
    seqA_tokens = list(range(8))
    seqB_tokens = seqA_tokens
    seqC_shared_prefix_len = 4
    seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20))

    seqA, seqA_group = create_dummy_prompt("0",
                                           prompt_tokens=seqA_tokens,
                                           block_size=block_size)
    _, seqB_group = create_dummy_prompt("1",
                                        prompt_tokens=seqB_tokens,
                                        block_size=block_size)

    # Chunked prefill seqA: first chunk of 4 tokens.
    scheduler.add_seq_group(seqA_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # seqB should not be scheduled while seqA's prefill is still ongoing;
    # only seqA's second chunk runs.
    scheduler.add_seq_group(seqB_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # Both seqB and seqC can be scheduled now that seqA's prefill is over
    # and seqA is in its decoding phase.
    append_new_token_seq(seqA, 999)
    _, seqC_group = create_dummy_prompt("2",
                                        prompt_tokens=seqC_tokens,
                                        block_size=block_size)
    scheduler.add_seq_group(seqC_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 3

    metas = {meta.request_id: meta for meta in metas}
    assert metas[seqA_group.request_id].token_chunk_size == 1  # Decode
    assert (metas[seqB_group.request_id].token_chunk_size == 8
            )  # Fully cached prefill
    # The message must be parenthesised: without the parentheses only the
    # first string literal belongs to the assert and the rest is dropped.
    assert metas[seqC_group.request_id].token_chunk_size == 6, (
        "A partial prefix of C (4 tokens) should be prefilled, with the "
        "remaining tokens fit into 3 token budget (4-1 from the seqA). "
        "It will then be rounded down to 2 tokens on block size, thus 6 "
        "tokens in total.")
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042


def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds():
    """
    Check that a scheduled batch never mixes requests given as prompt
    tokens with requests given as prompt embeddings.
    """
    block_size = 2
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_num_seqs=max_seq_group,
        max_model_len=100,
        enable_prefix_caching=True,
    )

    # Even-indexed requests are passed in via embeddings (with all-zero
    # token ids); odd-indexed requests are passed in via token ids only.
    seq_length = 7
    embedding_size = 5
    num_seqs = 11
    has_embeds = [idx % 2 == 0 for idx in range(num_seqs)]
    seq_tokens: list[list[int]] = [
        [0] * seq_length if embeds else list(range(seq_length))
        for embeds in has_embeds
    ]
    seq_embeds: list[Optional[torch.Tensor]] = [
        torch.rand(embedding_size) if embeds else None
        for embeds in has_embeds
    ]

    seq_and_seq_groups = [
        create_dummy_prompt(f"{idx}",
                            prompt_tokens=tokens,
                            prompt_embeds=embeds,
                            block_size=block_size)
        for idx, (tokens, embeds) in enumerate(zip(seq_tokens, seq_embeds))
    ]

    for _, group in seq_and_seq_groups:
        scheduler.add_seq_group(group)

    # Keep scheduling until every sequence has been marked finished.
    while any(not seq.is_finished() for seq, _ in seq_and_seq_groups):
        # Snapshot the pending groups before this scheduling round.
        pending_groups = [
            group for _, group in seq_and_seq_groups
            if not group.is_finished()
        ]
        _, out = schedule_and_update_computed_tokens(scheduler)
        scheduled = out.scheduled_seq_groups
        assert len(scheduled) > 0

        # The first scheduled group decides which kind this batch is.
        batch_is_prompt_embeds = scheduled[0].seq_group.uses_prompt_embeds()
        same_kind_pending = [
            group for group in pending_groups
            if group.uses_prompt_embeds() == batch_is_prompt_embeds
        ]

        # The batch is as large as allowed, and of a single kind.
        assert len(scheduled) == min(max_seq_group, len(same_kind_pending))
        assert all(
            entry.seq_group.uses_prompt_embeds() == batch_is_prompt_embeds
            for entry in scheduled)

        # Mark everything that was scheduled as finished and release it.
        for entry in scheduled:
            for seq in entry.seq_group.seqs:
                seq.status = SequenceStatus.FINISHED_STOPPED
        scheduler.free_finished_seq_groups()