Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
93deb0b3
Unverified
Commit
93deb0b3
authored
Apr 01, 2024
by
Cade Daniel
Committed by
GitHub
Apr 01, 2024
Browse files
[Speculative decoding 4/9] Lookahead scheduling for speculative decoding (#3250)
parent
ccb58b23
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
579 additions
and
73 deletions
+579
-73
tests/core/block/e2e/test_correctness.py
tests/core/block/e2e/test_correctness.py
+153
-0
tests/core/block/test_block_manager_v2.py
tests/core/block/test_block_manager_v2.py
+103
-0
tests/core/block/test_block_table.py
tests/core/block/test_block_table.py
+75
-0
tests/core/test_block_manager.py
tests/core/test_block_manager.py
+12
-12
tests/core/utils.py
tests/core/utils.py
+2
-2
vllm/config.py
vllm/config.py
+16
-2
vllm/core/block/block_table.py
vllm/core/block/block_table.py
+52
-6
vllm/core/block_manager_v1.py
vllm/core/block_manager_v1.py
+24
-9
vllm/core/block_manager_v2.py
vllm/core/block_manager_v2.py
+59
-23
vllm/core/interfaces.py
vllm/core/interfaces.py
+10
-6
vllm/core/scheduler.py
vllm/core/scheduler.py
+63
-12
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+10
-1
No files found.
tests/core/block/e2e/test_correctness.py
View file @
93deb0b3
...
@@ -77,6 +77,159 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
...
@@ -77,6 +77,159 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
assert
baseline_token_ids
==
test_token_ids
assert
baseline_token_ids
==
test_token_ids
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Use a small model for a fast test.
"model"
:
"facebook/opt-125m"
,
# skip cuda graph creation for fast test.
"enforce_eager"
:
True
,
# Use a large block size to trigger more copy-on-writes.
"block_size"
:
32
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
"use_v2_block_manager"
:
False
}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
10
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_v1_v2_greedy_equality_with_cow
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
):
"""Verify beam search equality with block manager v1 and v2.
This requires copy-on-writes; if the v1 and v2 output is the same, then
we have some confidence cow is working.
"""
output_len
=
128
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
sampling_params
=
SamplingParams
(
max_tokens
=
output_len
,
ignore_eos
=
True
,
temperature
=
temperature
,
use_beam_search
=
True
,
best_of
=
2
,
)
print
(
'Getting token ids from block manager v1'
)
baseline_token_ids
=
get_token_ids_from_llm_generator
(
baseline_llm_generator
,
prompts
,
sampling_params
)
print
(
'Getting token ids from block manager v2'
)
test_token_ids
=
get_token_ids_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
for
expected_token_ids
,
actual_token_ids
in
zip
(
baseline_token_ids
,
test_token_ids
):
assert
expected_token_ids
==
actual_token_ids
assert
baseline_token_ids
==
test_token_ids
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
# Use a small model for a fast test.
"model"
:
"facebook/opt-125m"
,
# Our prompts will generate 128 tokens; since the prompts themselves are
# small, we don't need much KV space beyond 128.
"max_model_len"
:
160
,
# skip cuda graph creation for fast test.
"enforce_eager"
:
True
,
# Lookahead scheduling only supported in v2 block manager.
"use_v2_block_manager"
:
True
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[
{
"block_size"
:
16
,
# Allow only 2 sequences of ~128 tokens in worst case.
# Note 8 = 128/block_size
"forced_num_gpu_blocks"
:
2
*
(
8
+
1
),
},
{
"block_size"
:
8
,
# Allow only 2 sequences of ~128 tokens in worst case.
# Note 16 = 128/block_size
"forced_num_gpu_blocks"
:
2
*
(
16
+
1
),
}
])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
"num_lookahead_slots"
:
0
,
}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
# We run one test with block_size < lookahead_slots, one test with
# block_size > lookahead_slots
"num_lookahead_slots"
:
10
,
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_lookahead_greedy_equality_with_preemption
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
):
"""Verify vLLM produces the same output with greedy sampling, when lookahead
scheduling is used vs. not.
Lookahead scheduling is not expected to modify the output, as it simply
allocates empty slots ahead of the known token ids in a sliding fashion.
This test constrains the total number of blocks to force preemption. It also
varies the block size so that the lookahead size is less than and greater
than the block size.
"""
output_len
=
128
temperature
=
0.0
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
prompts
=
[
prompt
for
prompt
,
_
in
zip
(
cycle
(
prompts
),
range
(
batch_size
))]
sampling_params
=
SamplingParams
(
max_tokens
=
output_len
,
ignore_eos
=
True
,
temperature
=
temperature
,
)
print
(
'Getting token ids without lookahead scheduling'
)
baseline_token_ids
=
get_token_ids_from_llm_generator
(
baseline_llm_generator
,
prompts
,
sampling_params
)
print
(
'Getting token ids with lookahead scheduling'
)
test_token_ids
=
get_token_ids_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
for
expected_token_ids
,
actual_token_ids
in
zip
(
baseline_token_ids
,
test_token_ids
):
assert
expected_token_ids
==
actual_token_ids
assert
baseline_token_ids
==
test_token_ids
def
get_token_ids_from_llm_generator
(
llm_generator
,
prompts
,
sampling_params
):
def
get_token_ids_from_llm_generator
(
llm_generator
,
prompts
,
sampling_params
):
for
llm
in
llm_generator
:
for
llm
in
llm_generator
:
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
...
...
tests/core/block/test_block_
space_
manager.py
→
tests/core/block/test_block_manager
_v2
.py
View file @
93deb0b3
...
@@ -2,6 +2,8 @@ import pytest
...
@@ -2,6 +2,8 @@ import pytest
from
vllm.core.block_manager_v2
import
BlockSpaceManagerV2
from
vllm.core.block_manager_v2
import
BlockSpaceManagerV2
from
vllm.core.interfaces
import
AllocStatus
from
vllm.core.interfaces
import
AllocStatus
from
vllm.sequence
import
Logprob
,
SequenceStatus
from
vllm.utils
import
chunk_list
from
..utils
import
create_seq_group
from
..utils
import
create_seq_group
...
@@ -29,7 +31,7 @@ def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
...
@@ -29,7 +31,7 @@ def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
for
num_prompt_blocks
in
range
(
1
,
num_gpu_blocks
-
num_output_blocks
):
for
num_prompt_blocks
in
range
(
1
,
num_gpu_blocks
-
num_output_blocks
):
seq_group
=
create_seq_group
(
seq_group
=
create_seq_group
(
seq_prompt_len
s
=
block_size
*
num_prompt_blocks
,
seq_prompt_len
=
block_size
*
num_prompt_blocks
,
seq_output_lens
=
[
seq_output_lens
=
[
block_size
*
num_output_blocks_per_seq
block_size
*
num_output_blocks_per_seq
for
_
in
range
(
num_seqs_per_group
)
for
_
in
range
(
num_seqs_per_group
)
...
@@ -48,3 +50,54 @@ def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
...
@@ -48,3 +50,54 @@ def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
assert
can_allocate_result
==
AllocStatus
.
OK
assert
can_allocate_result
==
AllocStatus
.
OK
else
:
else
:
assert
can_allocate_result
==
AllocStatus
.
LATER
assert
can_allocate_result
==
AllocStatus
.
LATER
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
1
,
8
])
@
pytest
.
mark
.
parametrize
(
"prompt_len"
,
[
1
,
7
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_slots_to_append"
,
[
1
,
8
,
129
])
@
pytest
.
mark
.
parametrize
(
"num_lookahead_slots"
,
[
0
,
10
])
def
test_append_slots
(
block_size
,
prompt_len
,
num_slots_to_append
,
num_lookahead_slots
):
"""Verify append_slots consumes the correct number of blocks from the block
table.
"""
num_gpu_blocks
=
1024
watermark
=
0.1
block_manager
=
BlockSpaceManagerV2
(
block_size
=
block_size
,
num_gpu_blocks
=
num_gpu_blocks
,
num_cpu_blocks
=
0
,
watermark
=
watermark
,
)
seq_group
=
create_seq_group
(
seq_prompt_len
=
prompt_len
,
seq_output_lens
=
[
0
],
)
# Allocate seq
assert
block_manager
.
can_allocate
(
seq_group
)
block_manager
.
allocate
(
seq_group
)
# Seq seq to RUNNING
seq
=
seq_group
.
get_seqs
()[
0
]
seq
.
status
=
SequenceStatus
.
RUNNING
# Append tokens to the sequeqnce
for
token_id
in
range
(
num_slots_to_append
):
seq
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
0.0
)})
# Append slots for new tokens and lookahead slots.
free_blocks_before_append
=
block_manager
.
get_num_free_gpu_blocks
()
block_manager
.
append_slots
(
seq
,
num_lookahead_slots
)
num_consumed_blocks
=
(
free_blocks_before_append
-
block_manager
.
get_num_free_gpu_blocks
())
# Expect consumed blocks to be new blocks required to support the new slots.
expected_consumed_blocks
=
len
(
chunk_list
(
list
(
range
(
prompt_len
+
num_slots_to_append
+
num_lookahead_slots
)),
block_size
))
-
len
(
chunk_list
(
list
(
range
(
prompt_len
)),
block_size
))
assert
num_consumed_blocks
==
expected_consumed_blocks
tests/core/block/test_block_table.py
View file @
93deb0b3
...
@@ -498,3 +498,78 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int,
...
@@ -498,3 +498,78 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int,
# After free, expect all blocks to be freed.
# After free, expect all blocks to be freed.
assert
allocator
.
get_num_free_blocks
(
Device
.
GPU
)
==
num_gpu_blocks
assert
allocator
.
get_num_free_blocks
(
Device
.
GPU
)
==
num_gpu_blocks
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
1
,
8
])
@
pytest
.
mark
.
parametrize
(
"sequence_len"
,
[
1
,
16
,
129
])
@
pytest
.
mark
.
parametrize
(
"num_new_tokens"
,
[
1
,
16
,
129
])
@
pytest
.
mark
.
parametrize
(
"num_lookahead_slots"
,
[
1
,
7
,
8
])
@
pytest
.
mark
.
parametrize
(
"allocator_type"
,
[
"naive"
,
"prefix_caching"
])
def
test_num_blocks_touched_by_append_slots
(
block_size
:
int
,
sequence_len
:
int
,
num_new_tokens
:
int
,
num_lookahead_slots
:
int
,
allocator_type
:
str
):
"""Verify correct calculation of get_num_blocks_touched_by_append_slots.
This is done by using copy-on-write, which requires any modified block to
be copied before write if the refcount > 1. We set the refcount>1 by forking
a sequence, then measure the free blocks before and after an append. If the
number of consumed blocks equals what `get_num_blocks_touched_by_append_
slots` returns, then the calculation is correct.
"""
num_gpu_blocks
=
1024
allocator
=
CpuGpuBlockAllocator
.
create
(
allocator_type
=
allocator_type
,
num_gpu_blocks
=
num_gpu_blocks
,
num_cpu_blocks
=
0
,
block_size
=
block_size
,
)
token_ids
=
list
(
range
(
sequence_len
))
token_ids_to_append
=
list
(
range
(
num_new_tokens
))
block_table
=
BlockTable
(
block_size
=
block_size
,
block_allocator
=
allocator
,
)
block_table
.
allocate
(
token_ids
=
token_ids
,
device
=
Device
.
GPU
)
# Add lookahead before fork so both sequences have the same lookahead
# blocks.
block_table
.
ensure_num_empty_slots
(
num_empty_slots
=
num_lookahead_slots
)
# Fork sequence so that every block has refcount > 1.
_
=
block_table
.
fork
()
# Determine how many blocks should be touched.
expected_num_touched_blocks
=
(
block_table
.
get_num_blocks_touched_by_append_slots
(
token_ids
=
token_ids_to_append
,
num_lookahead_slots
=
num_lookahead_slots
))
# Measure how many blocks are touched by measuring num_free_blocks before
# and after the append.
#
# We expect append_token_ids to CoW all mutated blocks that have refcount>1.
num_free_blocks_before_append
=
allocator
.
get_num_free_blocks
(
Device
.
GPU
)
block_table
.
append_token_ids
(
token_ids_to_append
,
num_lookahead_slots
)
num_consumed_blocks
=
(
num_free_blocks_before_append
-
allocator
.
get_num_free_blocks
(
Device
.
GPU
))
# TODO(cade) ensure equality when num_lookahead_slots > 0.
# The reason we have < is because lookahead blocks are not copied eagerly;
# they are copied on first write. This will cause issues for beam search +
# speculative decoding. This is acceptable for now as it is a large effort
# to combine the two. To fix this, we can ensure single sequence ownership
# of lookahead blocks by appending empty slots to each block, which will
# trigger the CoW.
#
# Until then, we can accept that the consumed tokens are <= the expected
# tokens when appending with lookahead.
if
num_lookahead_slots
>
0
:
assert
num_consumed_blocks
<=
expected_num_touched_blocks
else
:
assert
num_consumed_blocks
==
expected_num_touched_blocks
tests/core/test_block_manager.py
View file @
93deb0b3
...
@@ -103,9 +103,9 @@ def test_append_slot_single_seq():
...
@@ -103,9 +103,9 @@ def test_append_slot_single_seq():
block_manager
.
allocate
(
seq_group
)
block_manager
.
allocate
(
seq_group
)
# Nothing to append. Sequence has no new logical blocks.
# Nothing to append. Sequence has no new logical blocks.
assert
block_manager
.
can_append_slot
(
seq_group
)
assert
block_manager
.
can_append_slot
s
(
seq_group
)
before_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
before_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
assert
not
block_manager
.
append_slot
(
prompt
)
assert
not
block_manager
.
append_slot
s
(
prompt
)
after_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
after_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
assert
before_blocks
==
after_blocks
assert
before_blocks
==
after_blocks
...
@@ -114,9 +114,9 @@ def test_append_slot_single_seq():
...
@@ -114,9 +114,9 @@ def test_append_slot_single_seq():
token_id
=
i
+
5
token_id
=
i
+
5
prompt
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
0.0
)})
prompt
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
0.0
)})
assert
block_manager
.
can_append_slot
(
seq_group
)
assert
block_manager
.
can_append_slot
s
(
seq_group
)
before_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
before_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
assert
not
block_manager
.
append_slot
(
prompt
)
assert
not
block_manager
.
append_slot
s
(
prompt
)
after_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
after_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
assert
before_blocks
-
after_blocks
==
1
assert
before_blocks
-
after_blocks
==
1
...
@@ -150,13 +150,13 @@ def test_append_slot_cow():
...
@@ -150,13 +150,13 @@ def test_append_slot_cow():
child
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
0.0
)})
child
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
0.0
)})
block_manager
.
fork
(
prompt
,
child
)
block_manager
.
fork
(
prompt
,
child
)
assert
block_manager
.
can_append_slot
(
seq_group
)
assert
block_manager
.
can_append_slot
s
(
seq_group
)
before_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
before_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
maybe_src_dst_block
=
block_manager
.
append_slot
(
child
)
cows
=
block_manager
.
append_slot
s
(
child
)
assert
maybe_src_dst_block
is
not
None
assert
cows
src_block
,
dst_block
=
maybe_src_dst_block
for
src_block
,
dst_block
s
in
cows
.
items
():
assert
src_block
!=
dst_block
assert
src_block
not
in
dst_block
s
after_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
after_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
assert
before_blocks
-
after_blocks
==
1
assert
before_blocks
-
after_blocks
==
1
...
@@ -184,7 +184,7 @@ def test_fork():
...
@@ -184,7 +184,7 @@ def test_fork():
token_id
=
4
token_id
=
4
# Append token to child. Block is shared so copy on write occurs.
# Append token to child. Block is shared so copy on write occurs.
child
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
0.0
)})
child
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
0.0
)})
block_manager
.
append_slot
(
child
)
block_manager
.
append_slot
s
(
child
)
assert
block_manager
.
get_block_table
(
assert
block_manager
.
get_block_table
(
prompt
)
!=
block_manager
.
get_block_table
(
child
)
prompt
)
!=
block_manager
.
get_block_table
(
child
)
...
@@ -325,7 +325,7 @@ def test_sliding_window_multi_seq():
...
@@ -325,7 +325,7 @@ def test_sliding_window_multi_seq():
token_id
=
4
token_id
=
4
# Append token to child. Block is shared so copy on write occurs.
# Append token to child. Block is shared so copy on write occurs.
child
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
0.0
)})
child
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
0.0
)})
block_manager
.
append_slot
(
child
)
block_manager
.
append_slot
s
(
child
)
# assert the number of blocks allocated is correct
# assert the number of blocks allocated is correct
# we will use now one block more. Each seq will use 2 blocks,
# we will use now one block more. Each seq will use 2 blocks,
...
@@ -335,7 +335,7 @@ def test_sliding_window_multi_seq():
...
@@ -335,7 +335,7 @@ def test_sliding_window_multi_seq():
token_id
=
5
token_id
=
5
parent
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
0.0
)})
parent
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
0.0
)})
block_manager
.
append_slot
(
parent
)
block_manager
.
append_slot
s
(
parent
)
# assert the number of blocks allocated is correct
# assert the number of blocks allocated is correct
# no change, because both sequences are still just sharing one block
# no change, because both sequences are still just sharing one block
...
...
tests/core/utils.py
View file @
93deb0b3
...
@@ -24,7 +24,7 @@ def create_dummy_prompt(
...
@@ -24,7 +24,7 @@ def create_dummy_prompt(
def
create_seq_group
(
def
create_seq_group
(
seq_prompt_len
s
=
1024
,
seq_prompt_len
=
1024
,
seq_output_lens
=
(
128
,
),
seq_output_lens
=
(
128
,
),
request_id
=
'0'
,
request_id
=
'0'
,
seq_id_start
=
0
,
seq_id_start
=
0
,
...
@@ -32,7 +32,7 @@ def create_seq_group(
...
@@ -32,7 +32,7 @@ def create_seq_group(
assert
len
(
seq_output_lens
)
>
0
assert
len
(
seq_output_lens
)
>
0
prompt_token_ids
=
[
0
]
*
seq_prompt_len
s
prompt_token_ids
=
[
0
]
*
seq_prompt_len
seqs
=
[]
seqs
=
[]
for
seq_id_offset
,
output_len
in
enumerate
(
seq_output_lens
):
for
seq_id_offset
,
output_len
in
enumerate
(
seq_output_lens
):
...
...
vllm/config.py
View file @
93deb0b3
...
@@ -530,9 +530,13 @@ class SchedulerConfig:
...
@@ -530,9 +530,13 @@ class SchedulerConfig:
iteration.
iteration.
max_model_len: Maximum length of a sequence (including prompt
max_model_len: Maximum length of a sequence (including prompt
and generated text).
and generated text).
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
num_lookahead_slots: The number of slots to allocate per sequence per
step, beyond the known token ids. This is used in speculative
decoding to store KV activations of tokens which may or may not be
accepted.
delay_factor: Apply a delay (of delay factor multiplied by previous
delay_factor: Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt.
prompt latency) before scheduling next prompt.
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
enable_chunked_prefill: If True, prefill requests can be chunked based
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
on the remaining max_num_batched_tokens.
"""
"""
...
@@ -543,6 +547,7 @@ class SchedulerConfig:
...
@@ -543,6 +547,7 @@ class SchedulerConfig:
max_num_seqs
:
int
,
max_num_seqs
:
int
,
max_model_len
:
int
,
max_model_len
:
int
,
use_v2_block_manager
:
bool
=
False
,
use_v2_block_manager
:
bool
=
False
,
num_lookahead_slots
:
int
=
0
,
delay_factor
:
float
=
0.0
,
delay_factor
:
float
=
0.0
,
enable_chunked_prefill
:
bool
=
False
,
enable_chunked_prefill
:
bool
=
False
,
)
->
None
:
)
->
None
:
...
@@ -554,9 +559,11 @@ class SchedulerConfig:
...
@@ -554,9 +559,11 @@ class SchedulerConfig:
self
.
max_num_batched_tokens
=
max
(
max_model_len
,
2048
)
self
.
max_num_batched_tokens
=
max
(
max_model_len
,
2048
)
self
.
max_num_seqs
=
max_num_seqs
self
.
max_num_seqs
=
max_num_seqs
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
self
.
delay_factor
=
delay_factor
self
.
use_v2_block_manager
=
use_v2_block_manager
self
.
use_v2_block_manager
=
use_v2_block_manager
self
.
num_lookahead_slots
=
num_lookahead_slots
self
.
delay_factor
=
delay_factor
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
chunked_prefill_enabled
=
enable_chunked_prefill
self
.
_verify_args
()
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
...
@@ -568,12 +575,19 @@ class SchedulerConfig:
...
@@ -568,12 +575,19 @@ class SchedulerConfig:
"max_num_batched_tokens and makes vLLM reject longer "
"max_num_batched_tokens and makes vLLM reject longer "
"sequences. Please increase max_num_batched_tokens or "
"sequences. Please increase max_num_batched_tokens or "
"decrease max_model_len."
)
"decrease max_model_len."
)
if
self
.
max_num_batched_tokens
<
self
.
max_num_seqs
:
if
self
.
max_num_batched_tokens
<
self
.
max_num_seqs
:
raise
ValueError
(
raise
ValueError
(
f
"max_num_batched_tokens (
{
self
.
max_num_batched_tokens
}
) must "
f
"max_num_batched_tokens (
{
self
.
max_num_batched_tokens
}
) must "
"be greater than or equal to max_num_seqs "
"be greater than or equal to max_num_seqs "
f
"(
{
self
.
max_num_seqs
}
)."
)
f
"(
{
self
.
max_num_seqs
}
)."
)
if
self
.
num_lookahead_slots
<
0
:
raise
ValueError
(
"num_lookahead_slots "
f
"(
{
self
.
num_lookahead_slots
}
) must be greater than or "
"equal to 0."
)
class
DeviceConfig
:
class
DeviceConfig
:
...
...
vllm/core/block/block_table.py
View file @
93deb0b3
...
@@ -85,7 +85,9 @@ class BlockTable:
...
@@ -85,7 +85,9 @@ class BlockTable:
device
=
device
)
device
=
device
)
self
.
_num_full_slots
=
len
(
token_ids
)
self
.
_num_full_slots
=
len
(
token_ids
)
def
append_token_ids
(
self
,
token_ids
:
List
[
int
])
->
None
:
def
append_token_ids
(
self
,
token_ids
:
List
[
int
],
num_lookahead_slots
:
int
=
0
)
->
None
:
"""Appends a sequence of token IDs to the existing blocks in the
"""Appends a sequence of token IDs to the existing blocks in the
BlockTable.
BlockTable.
...
@@ -102,14 +104,13 @@ class BlockTable:
...
@@ -102,14 +104,13 @@ class BlockTable:
token_ids (List[int]): The sequence of token IDs to be appended.
token_ids (List[int]): The sequence of token IDs to be appended.
"""
"""
assert
self
.
_is_allocated
assert
self
.
_is_allocated
assert
token_ids
,
"can't append empty token ids"
self
.
ensure_num_empty_slots
(
num_empty_slots
=
len
(
token_ids
))
self
.
ensure_num_empty_slots
(
num_empty_slots
=
len
(
token_ids
)
+
num_lookahead_slots
)
blocks
=
self
.
_blocks
[
self
.
_num_full_slots
//
self
.
_block_size
:]
blocks
=
self
.
_blocks
[
self
.
_num_full_slots
//
self
.
_block_size
:]
first_chunk_size
=
self
.
_block_size
-
(
self
.
_num_full_slots
%
token_blocks
=
self
.
_chunk_token_blocks_for_append
(
token_ids
)
self
.
_block_size
)
token_blocks
=
[
token_ids
[:
first_chunk_size
]]
+
chunk_list
(
token_ids
[
first_chunk_size
:],
self
.
_block_size
)
for
block
,
token_block
in
zip
(
blocks
,
token_blocks
):
for
block
,
token_block
in
zip
(
blocks
,
token_blocks
):
block
.
append_token_ids
(
token_block
)
block
.
append_token_ids
(
token_block
)
...
@@ -195,6 +196,25 @@ class BlockTable:
...
@@ -195,6 +196,25 @@ class BlockTable:
assert
self
.
_is_allocated
assert
self
.
_is_allocated
return
[
block
.
block_id
for
block
in
self
.
_blocks
]
return
[
block
.
block_id
for
block
in
self
.
_blocks
]
def
get_unseen_token_ids
(
self
,
sequence_token_ids
:
List
[
int
])
->
List
[
int
]:
"""Get the number of "unseen" tokens in the sequence.
Unseen tokens are tokens in the sequence corresponding to this block
table, but are not yet appended to this block table.
Args:
sequence_token_ids (List[int]): The list of token ids in the
sequence.
Returns:
List[int]: The postfix of sequence_token_ids that has not yet been
appended to the block table.
"""
# Since the block table is append-only, the unseen token ids are the
# ones after the appended ones.
return
sequence_token_ids
[
self
.
num_full_slots
:]
def
_allocate_blocks_for_token_ids
(
self
,
prev_block
:
Optional
[
Block
],
def
_allocate_blocks_for_token_ids
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
token_ids
:
List
[
int
],
device
:
Device
)
->
List
[
Block
]:
device
:
Device
)
->
List
[
Block
]:
...
@@ -243,3 +263,29 @@ class BlockTable:
...
@@ -243,3 +263,29 @@ class BlockTable:
int: The total number of tokens currently stored in the BlockTable.
int: The total number of tokens currently stored in the BlockTable.
"""
"""
return
self
.
_num_full_slots
return
self
.
_num_full_slots
def
get_num_blocks_touched_by_append_slots
(
self
,
token_ids
:
List
[
int
],
num_lookahead_slots
:
int
)
->
int
:
"""Determine how many blocks will be "touched" by appending the token
ids.
This is required for the scheduler to determine whether a sequence can
continue generation, or if it must be preempted.
"""
all_token_ids
=
token_ids
+
[
-
1
]
*
num_lookahead_slots
token_blocks
=
self
.
_chunk_token_blocks_for_append
(
all_token_ids
)
return
len
(
token_blocks
)
def
_chunk_token_blocks_for_append
(
self
,
token_ids
:
List
[
int
])
->
List
[
List
[
int
]]:
"""Split the token ids into block-sized chunks so they can be easily
appended to blocks. The first such "token block" may have less token ids
than the block size, since the last allocated block may be partially
full.
"""
first_chunk_size
=
self
.
_block_size
-
(
self
.
_num_full_slots
%
self
.
_block_size
)
token_blocks
=
[
token_ids
[:
first_chunk_size
]]
+
chunk_list
(
token_ids
[
first_chunk_size
:],
self
.
_block_size
)
return
token_blocks
vllm/core/block_manager_v1.py
View file @
93deb0b3
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
itertools
import
count
,
takewhile
from
itertools
import
count
,
takewhile
from
os.path
import
commonprefix
from
os.path
import
commonprefix
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Set
from
vllm.block
import
BlockTable
,
PhysicalTokenBlock
from
vllm.block
import
BlockTable
,
PhysicalTokenBlock
from
vllm.core.evictor
import
EvictionPolicy
,
Evictor
,
make_evictor
from
vllm.core.evictor
import
EvictionPolicy
,
Evictor
,
make_evictor
...
@@ -292,7 +292,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -292,7 +292,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
):
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
copy
()
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
copy
()
def
can_append_slot
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_append_slots
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
=
0
)
->
bool
:
assert
(
num_lookahead_slots
==
0
),
"lookahead allocation not supported in BlockSpaceManagerV1"
# Simple heuristic: If there is at least one free block
# Simple heuristic: If there is at least one free block
# for each sequence, we can append.
# for each sequence, we can append.
num_free_gpu_blocks
=
self
.
gpu_allocator
.
get_num_free_blocks
()
num_free_gpu_blocks
=
self
.
gpu_allocator
.
get_num_free_blocks
()
...
@@ -364,10 +369,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -364,10 +369,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
assert
new_block
.
ref_count
==
1
assert
new_block
.
ref_count
==
1
return
new_block
return
new_block
def
append_slot
(
def
append_slot
s
(
self
,
self
,
seq
:
Sequence
,
seq
:
Sequence
,
)
->
Optional
[
Tuple
[
int
,
int
]]:
num_lookahead_slots
:
int
=
0
,
)
->
Dict
[
int
,
List
[
int
]]:
"""Allocate a physical slot for a new token."""
"""Allocate a physical slot for a new token."""
logical_blocks
=
seq
.
logical_token_blocks
logical_blocks
=
seq
.
logical_token_blocks
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
...
@@ -386,7 +392,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -386,7 +392,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
# Allocate a new physical block.
# Allocate a new physical block.
new_block
=
self
.
_allocate_last_physical_block
(
seq
)
new_block
=
self
.
_allocate_last_physical_block
(
seq
)
block_table
.
append
(
new_block
)
block_table
.
append
(
new_block
)
return
None
return
{}
# We want to append the token to the last physical block.
# We want to append the token to the last physical block.
last_block
=
block_table
[
-
1
]
last_block
=
block_table
[
-
1
]
...
@@ -399,7 +405,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -399,7 +405,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
maybe_new_block
=
self
.
_maybe_promote_last_block
(
maybe_new_block
=
self
.
_maybe_promote_last_block
(
seq
,
last_block
)
seq
,
last_block
)
block_table
[
-
1
]
=
maybe_new_block
block_table
[
-
1
]
=
maybe_new_block
return
None
return
{}
else
:
else
:
# The last block is shared with other sequences.
# The last block is shared with other sequences.
# Copy on Write: Allocate a new block and copy the tokens.
# Copy on Write: Allocate a new block and copy the tokens.
...
@@ -407,7 +413,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -407,7 +413,7 @@ class BlockSpaceManagerV1(BlockSpaceManager):
block_table
[
-
1
]
=
new_block
block_table
[
-
1
]
=
new_block
self
.
gpu_allocator
.
free
(
last_block
)
self
.
gpu_allocator
.
free
(
last_block
)
return
last_block
.
block_number
,
new_block
.
block_number
return
{
last_block
.
block_number
:
[
new_block
.
block_number
]}
def
fork
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
def
fork
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
# NOTE: fork does not allocate a new physical block.
# NOTE: fork does not allocate a new physical block.
...
@@ -433,7 +439,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -433,7 +439,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
blocks
.
update
(
self
.
block_tables
[
seq
.
seq_id
])
blocks
.
update
(
self
.
block_tables
[
seq
.
seq_id
])
return
list
(
blocks
)
return
list
(
blocks
)
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
=
0
)
->
bool
:
assert
(
num_lookahead_slots
==
0
),
"BlockSpaceManagerV1 does not support lookahead allocation"
blocks
=
self
.
_get_physical_blocks
(
seq_group
)
blocks
=
self
.
_get_physical_blocks
(
seq_group
)
num_swapped_seqs
=
seq_group
.
num_seqs
(
status
=
SequenceStatus
.
SWAPPED
)
num_swapped_seqs
=
seq_group
.
num_seqs
(
status
=
SequenceStatus
.
SWAPPED
)
num_free_blocks
=
self
.
gpu_allocator
.
get_num_free_blocks
()
num_free_blocks
=
self
.
gpu_allocator
.
get_num_free_blocks
()
...
@@ -443,7 +453,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -443,7 +453,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
num_required_blocks
=
len
(
blocks
)
+
num_swapped_seqs
num_required_blocks
=
len
(
blocks
)
+
num_swapped_seqs
return
num_free_blocks
-
num_required_blocks
>=
self
.
watermark_blocks
return
num_free_blocks
-
num_required_blocks
>=
self
.
watermark_blocks
def
swap_in
(
self
,
seq_group
:
SequenceGroup
)
->
Dict
[
int
,
int
]:
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
=
0
)
->
Dict
[
int
,
int
]:
assert
(
num_lookahead_slots
==
0
),
"BlockSpaceManagerV1 does not support lookahead allocation"
# CPU block -> GPU block.
# CPU block -> GPU block.
mapping
:
Dict
[
PhysicalTokenBlock
,
PhysicalTokenBlock
]
=
{}
mapping
:
Dict
[
PhysicalTokenBlock
,
PhysicalTokenBlock
]
=
{}
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
SWAPPED
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
SWAPPED
):
...
...
vllm/core/block_manager_v2.py
View file @
93deb0b3
"""A block manager that manages token blocks."""
"""A block manager that manages token blocks."""
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
from
vllm.core.block.block_table
import
BlockTable
from
vllm.core.block.block_table
import
BlockTable
from
vllm.core.block.cpu_gpu_block_allocator
import
CpuGpuBlockAllocator
from
vllm.core.block.cpu_gpu_block_allocator
import
CpuGpuBlockAllocator
...
@@ -21,6 +21,24 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -21,6 +21,24 @@ class BlockSpaceManagerV2(BlockSpaceManager):
sliding-window are not feature complete. This class implements the design
sliding-window are not feature complete. This class implements the design
described in https://github.com/vllm-project/vllm/pull/3492.
described in https://github.com/vllm-project/vllm/pull/3492.
Lookahead slots
The block manager has the notion of a "lookahead slot". These are slots
in the KV cache that are allocated for a sequence. Unlike the other
allocated slots, the content of these slots is undefined -- the worker
may use the memory allocations in any way.
In practice, a worker could use these lookahead slots to run multiple
forward passes for a single scheduler invocation. Each successive
forward pass would write KV activations to the corresponding lookahead
slot. This allows low inter-token latency use-cases, where the overhead
of continuous batching scheduling is amortized over >1 generated tokens.
Speculative decoding uses lookahead slots to store KV activations of
proposal tokens.
See https://github.com/vllm-project/vllm/pull/3250 for more information
on lookahead scheduling.
Args:
Args:
block_size (int): The size of each memory block.
block_size (int): The size of each memory block.
num_gpu_blocks (int): The number of memory blocks allocated on GPU.
num_gpu_blocks (int): The number of memory blocks allocated on GPU.
...
@@ -116,35 +134,51 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -116,35 +134,51 @@ class BlockSpaceManagerV2(BlockSpaceManager):
for
seq
in
waiting_seqs
[
1
:]:
for
seq
in
waiting_seqs
[
1
:]:
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
fork
()
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
fork
()
def
can_append_slot
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_append_slots
(
self
,
seq_group
:
SequenceGroup
,
# Simple heuristic: If there is at least one free block
num_lookahead_slots
:
int
)
->
bool
:
# for each sequence, we can append.
"""Determine if there is enough space in the GPU KV cache to continue
generation of the specified sequence group.
We use a worst-case heuristic: assume each touched block will require a
new allocation (either via CoW or new block). We can append slots if the
number of touched blocks is less than the number of free blocks.
"Lookahead slots" are slots that are allocated in addition to the slots
for known tokens. The contents of the lookahead slots are not defined.
This is used by speculative decoding when speculating future tokens.
"""
num_touched_blocks
=
0
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
num_touched_blocks
+=
(
block_table
.
get_num_blocks_touched_by_append_slots
(
token_ids
=
block_table
.
get_unseen_token_ids
(
seq
.
get_token_ids
()),
num_lookahead_slots
=
num_lookahead_slots
,
))
num_free_gpu_blocks
=
self
.
block_allocator
.
get_num_free_blocks
(
num_free_gpu_blocks
=
self
.
block_allocator
.
get_num_free_blocks
(
Device
.
GPU
)
Device
.
GPU
)
num_seqs
=
seq_group
.
num_seqs
(
status
=
SequenceStatus
.
RUNNING
)
return
num_touched_blocks
<=
num_free_gpu_blocks
return
num_seqs
<=
num_free_gpu_blocks
def
append_slot
(
def
append_slot
s
(
self
,
self
,
seq
:
Sequence
,
seq
:
Sequence
,
)
->
Optional
[
Tuple
[
int
,
int
]]:
num_lookahead_slots
:
int
,
)
->
Dict
[
int
,
List
[
int
]]:
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
# Get unseen token ids.
block_table
.
append_token_ids
(
num_full_slots
=
block_table
.
num_full_slots
token_ids
=
block_table
.
get_unseen_token_ids
(
seq
.
get_token_ids
()),
unseen_token_ids
=
seq
.
get_token_ids
()[
num_full_slots
:]
num_lookahead_slots
=
num_lookahead_slots
,
assert
unseen_token_ids
)
block_table
.
append_token_ids
(
unseen_token_ids
)
# Return any copy-on-writes.
_
=
self
.
block_allocator
.
clear_copy_on_writes
()
# TODO extend append_slot interface to append_slots
# @cadedaniel will do in https://github.com/vllm-project/vllm/pull/3250
return
None
# Return any new copy-on-writes.
new_cows
=
self
.
block_allocator
.
clear_copy_on_writes
()
return
new_cows
def
free
(
self
,
seq
:
Sequence
)
->
None
:
def
free
(
self
,
seq
:
Sequence
)
->
None
:
if
seq
.
seq_id
not
in
self
.
block_tables
:
if
seq
.
seq_id
not
in
self
.
block_tables
:
...
@@ -191,10 +225,12 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -191,10 +225,12 @@ class BlockSpaceManagerV2(BlockSpaceManager):
src_block_table
=
self
.
block_tables
[
parent_seq
.
seq_id
]
src_block_table
=
self
.
block_tables
[
parent_seq
.
seq_id
]
self
.
block_tables
[
child_seq
.
seq_id
]
=
src_block_table
.
fork
()
self
.
block_tables
[
child_seq
.
seq_id
]
=
src_block_table
.
fork
()
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
bool
:
return
False
return
False
def
swap_in
(
self
,
seq_group
:
SequenceGroup
)
->
Dict
[
int
,
int
]:
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
Dict
[
int
,
int
]:
raise
NotImplementedError
raise
NotImplementedError
def
can_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
...
...
vllm/core/interfaces.py
View file @
93deb0b3
import
enum
import
enum
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
from
vllm.sequence
import
Sequence
,
SequenceGroup
from
vllm.sequence
import
Sequence
,
SequenceGroup
...
@@ -44,14 +44,16 @@ class BlockSpaceManager(ABC):
...
@@ -44,14 +44,16 @@ class BlockSpaceManager(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
can_append_slot
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_append_slots
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
bool
:
pass
pass
@
abstractmethod
@
abstractmethod
def
append_slot
(
def
append_slot
s
(
self
,
self
,
seq
:
Sequence
,
seq
:
Sequence
,
)
->
Optional
[
Tuple
[
int
,
int
]]:
num_lookahead_slots
:
int
,
)
->
Dict
[
int
,
List
[
int
]]:
pass
pass
@
abstractmethod
@
abstractmethod
...
@@ -59,11 +61,13 @@ class BlockSpaceManager(ABC):
...
@@ -59,11 +61,13 @@ class BlockSpaceManager(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
bool
:
pass
pass
@
abstractmethod
@
abstractmethod
def
swap_in
(
self
,
seq_group
:
SequenceGroup
)
->
Dict
[
int
,
int
]:
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
Dict
[
int
,
int
]:
pass
pass
@
abstractmethod
@
abstractmethod
...
...
vllm/core/scheduler.py
View file @
93deb0b3
...
@@ -52,6 +52,7 @@ class SchedulerOutputs:
...
@@ -52,6 +52,7 @@ class SchedulerOutputs:
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
ignored_seq_groups
:
List
[
SequenceGroup
],
ignored_seq_groups
:
List
[
SequenceGroup
],
num_lookahead_slots
:
int
,
)
->
None
:
)
->
None
:
"""A list of sequence groups to be scheduled as a single batch.
"""A list of sequence groups to be scheduled as a single batch.
...
@@ -86,6 +87,7 @@ class SchedulerOutputs:
...
@@ -86,6 +87,7 @@ class SchedulerOutputs:
# Swap in and swap out should never happen at the same time.
# Swap in and swap out should never happen at the same time.
assert
not
(
blocks_to_swap_in
and
blocks_to_swap_out
)
assert
not
(
blocks_to_swap_in
and
blocks_to_swap_out
)
self
.
num_lookahead_slots
=
num_lookahead_slots
self
.
num_loras
:
int
=
len
(
self
.
lora_requests
)
self
.
num_loras
:
int
=
len
(
self
.
lora_requests
)
if
self
.
num_loras
>
0
:
if
self
.
num_loras
>
0
:
...
@@ -309,6 +311,8 @@ class Scheduler:
...
@@ -309,6 +311,8 @@ class Scheduler:
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
blocks_to_copy
=
blocks_to_copy
,
ignored_seq_groups
=
ignored_seq_groups
,
ignored_seq_groups
=
ignored_seq_groups
,
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
True
),
)
)
return
scheduler_outputs
return
scheduler_outputs
...
@@ -323,7 +327,7 @@ class Scheduler:
...
@@ -323,7 +327,7 @@ class Scheduler:
preempted
:
List
[
SequenceGroup
]
=
[]
preempted
:
List
[
SequenceGroup
]
=
[]
while
self
.
running
:
while
self
.
running
:
seq_group
=
self
.
running
.
popleft
()
seq_group
=
self
.
running
.
popleft
()
while
not
self
.
block_manager
.
can_append_slot
(
seq_group
):
while
not
self
.
_
can_append_slot
s
(
seq_group
):
if
self
.
running
:
if
self
.
running
:
# Preempt the lowest-priority sequence groups.
# Preempt the lowest-priority sequence groups.
victim_seq_group
=
self
.
running
.
pop
()
victim_seq_group
=
self
.
running
.
pop
()
...
@@ -337,7 +341,7 @@ class Scheduler:
...
@@ -337,7 +341,7 @@ class Scheduler:
break
break
else
:
else
:
# Append new slots to the sequence group.
# Append new slots to the sequence group.
self
.
_append_slot
(
seq_group
,
blocks_to_copy
)
self
.
_append_slot
s
(
seq_group
,
blocks_to_copy
)
running
.
append
(
seq_group
)
running
.
append
(
seq_group
)
self
.
running
=
running
self
.
running
=
running
...
@@ -366,7 +370,7 @@ class Scheduler:
...
@@ -366,7 +370,7 @@ class Scheduler:
continue
continue
# If the sequence group cannot be swapped in, stop.
# If the sequence group cannot be swapped in, stop.
if
not
self
.
block_manager
.
can_swap_in
(
seq_group
):
if
not
self
.
_
can_swap_in
(
seq_group
):
break
break
# The total number of sequences in the RUNNING state should not
# The total number of sequences in the RUNNING state should not
...
@@ -380,7 +384,7 @@ class Scheduler:
...
@@ -380,7 +384,7 @@ class Scheduler:
curr_loras
.
add
(
lora_int_id
)
curr_loras
.
add
(
lora_int_id
)
self
.
swapped
.
popleft
()
self
.
swapped
.
popleft
()
self
.
_swap_in
(
seq_group
,
blocks_to_swap_in
)
self
.
_swap_in
(
seq_group
,
blocks_to_swap_in
)
self
.
_append_slot
(
seq_group
,
blocks_to_copy
)
self
.
_append_slot
s
(
seq_group
,
blocks_to_copy
)
num_curr_seqs
+=
num_new_seqs
num_curr_seqs
+=
num_new_seqs
self
.
running
.
append
(
seq_group
)
self
.
running
.
append
(
seq_group
)
...
@@ -405,9 +409,32 @@ class Scheduler:
...
@@ -405,9 +409,32 @@ class Scheduler:
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
blocks_to_copy
=
blocks_to_copy
,
ignored_seq_groups
=
[],
ignored_seq_groups
=
[],
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
False
),
)
)
return
scheduler_outputs
return
scheduler_outputs
def
_can_append_slots
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
"""Determine whether or not we have enough space in the KV cache to
continue generation of the sequence group.
"""
# Appending slots only occurs in decoding.
is_prefill
=
False
return
self
.
block_manager
.
can_append_slots
(
seq_group
=
seq_group
,
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
),
)
def
_can_swap_in
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
# Swapping in is considered decode.
is_prefill
=
False
return
self
.
block_manager
.
can_swap_in
(
seq_group
=
seq_group
,
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
),
)
def
schedule
(
self
)
->
Tuple
[
List
[
SequenceGroupMetadata
],
SchedulerOutputs
]:
def
schedule
(
self
)
->
Tuple
[
List
[
SequenceGroupMetadata
],
SchedulerOutputs
]:
# Schedule sequence groups.
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# This function call changes the internal states of the scheduler
...
@@ -482,19 +509,30 @@ class Scheduler:
...
@@ -482,19 +509,30 @@ class Scheduler:
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
WAITING
):
seq
.
status
=
SequenceStatus
.
RUNNING
seq
.
status
=
SequenceStatus
.
RUNNING
def
_append_slot
(
def
_append_slot
s
(
self
,
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
blocks_to_copy
:
Dict
[
int
,
List
[
int
]],
)
->
None
:
)
->
None
:
"""Appends new slots to the sequences in the given sequence group.
Args:
seq_group (SequenceGroup): The sequence group containing the
sequences to append slots to.
blocks_to_copy (Dict[int, List[int]]): A dictionary mapping source
block indices to lists of destination block indices. This
dictionary is updated with the new source and destination block
indices for the appended slots.
"""
num_lookahead_slots
=
self
.
_get_num_lookahead_slots
(
is_prefill
=
False
)
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
ret
=
self
.
block_manager
.
append_slot
(
seq
)
cows
=
self
.
block_manager
.
append_slots
(
seq
,
num_lookahead_slots
)
if
ret
is
not
None
:
src_block
,
dst_block
=
ret
for
src
,
dests
in
cows
.
items
():
if
src_block
in
blocks_to_copy
:
if
src
not
in
blocks_to_copy
:
blocks_to_copy
[
src_block
].
append
(
dst_block
)
blocks_to_copy
[
src
]
=
[]
else
:
blocks_to_copy
[
src
].
extend
(
dests
)
blocks_to_copy
[
src_block
]
=
[
dst_block
]
def
_preempt
(
def
_preempt
(
self
,
self
,
...
@@ -588,3 +626,16 @@ class Scheduler:
...
@@ -588,3 +626,16 @@ class Scheduler:
else
:
else
:
passed_delay
=
True
passed_delay
=
True
return
passed_delay
return
passed_delay
def
_get_num_lookahead_slots
(
self
,
is_prefill
:
bool
)
->
int
:
"""The number of slots to allocate per sequence per step, beyond known
token ids. Speculative decoding uses these slots to store KV activations
of tokens which may or may not be accepted.
Speculative decoding does not yet support prefill, so we do not perform
lookahead allocation for prefill.
"""
if
is_prefill
:
return
0
return
self
.
scheduler_config
.
num_lookahead_slots
vllm/engine/arg_utils.py
View file @
93deb0b3
...
@@ -53,8 +53,8 @@ class EngineArgs:
...
@@ -53,8 +53,8 @@ class EngineArgs:
max_cpu_loras
:
Optional
[
int
]
=
None
max_cpu_loras
:
Optional
[
int
]
=
None
device
:
str
=
'auto'
device
:
str
=
'auto'
ray_workers_use_nsight
:
bool
=
False
ray_workers_use_nsight
:
bool
=
False
forced_num_gpu_blocks
:
Optional
[
int
]
=
None
forced_num_gpu_blocks
:
Optional
[
int
]
=
None
num_lookahead_slots
:
int
=
0
# Related to Vision-language models such as llava
# Related to Vision-language models such as llava
image_input_type
:
Optional
[
str
]
=
None
image_input_type
:
Optional
[
str
]
=
None
...
@@ -202,6 +202,14 @@ class EngineArgs:
...
@@ -202,6 +202,14 @@ class EngineArgs:
parser
.
add_argument
(
'--use-v2-block-manager'
,
parser
.
add_argument
(
'--use-v2-block-manager'
,
action
=
'store_true'
,
action
=
'store_true'
,
help
=
'Use BlockSpaceMangerV2'
)
help
=
'Use BlockSpaceMangerV2'
)
parser
.
add_argument
(
'--num-lookahead-slots'
,
type
=
int
,
default
=
EngineArgs
.
num_lookahead_slots
,
help
=
'Experimental scheduling config necessary for '
'speculative decoding. This will be replaced by '
'speculative config in the future; it is present '
'to enable correctness tests until then.'
)
parser
.
add_argument
(
'--seed'
,
parser
.
add_argument
(
'--seed'
,
type
=
int
,
type
=
int
,
...
@@ -406,6 +414,7 @@ class EngineArgs:
...
@@ -406,6 +414,7 @@ class EngineArgs:
self
.
max_num_seqs
,
self
.
max_num_seqs
,
model_config
.
max_model_len
,
model_config
.
max_model_len
,
self
.
use_v2_block_manager
,
self
.
use_v2_block_manager
,
num_lookahead_slots
=
self
.
num_lookahead_slots
,
delay_factor
=
self
.
scheduler_delay_factor
,
delay_factor
=
self
.
scheduler_delay_factor
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
enable_chunked_prefill
=
self
.
enable_chunked_prefill
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment