Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5bf8789b
Unverified
Commit
5bf8789b
authored
Sep 28, 2024
by
sroy745
Committed by
GitHub
Sep 29, 2024
Browse files
[Bugfix] Block manager v2 with preemption and lookahead slots (#8824)
parent
d1537039
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
133 additions
and
116 deletions
+133
-116
tests/basic_correctness/test_preemption.py
tests/basic_correctness/test_preemption.py
+7
-2
tests/core/block/test_block_manager_v2.py
tests/core/block/test_block_manager_v2.py
+46
-1
tests/core/block/test_naive_block.py
tests/core/block/test_naive_block.py
+10
-9
tests/core/block/test_prefix_caching_block.py
tests/core/block/test_prefix_caching_block.py
+13
-12
vllm/core/block/cpu_gpu_block_allocator.py
vllm/core/block/cpu_gpu_block_allocator.py
+7
-10
vllm/core/block/interfaces.py
vllm/core/block/interfaces.py
+3
-7
vllm/core/block/naive_block.py
vllm/core/block/naive_block.py
+10
-25
vllm/core/block/prefix_caching_block.py
vllm/core/block/prefix_caching_block.py
+15
-26
vllm/core/block_manager_v2.py
vllm/core/block_manager_v2.py
+22
-24
No files found.
tests/basic_correctness/test_preemption.py
View file @
5bf8789b
...
...
@@ -23,8 +23,10 @@ MODELS = [
@
pytest
.
fixture
(
scope
=
"module"
,
autouse
=
True
)
def
check_settings
():
assert
ENABLE_ARTIFICIAL_PREEMPT
is
True
,
(
"Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. "
"`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest "
"Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1, "
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1. "
"`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 "
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1 pytest "
"tests/basic_correctness/test_preemption.py`"
)
...
...
@@ -199,6 +201,7 @@ def test_swap(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"float"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
96
])
@
pytest
.
mark
.
parametrize
(
"beam_width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"use_v2_block_manager"
,
[
True
,
False
])
def
test_swap_infeasible
(
vllm_runner
,
example_prompts
,
...
...
@@ -207,6 +210,7 @@ def test_swap_infeasible(
max_tokens
:
int
,
beam_width
:
int
,
worker_use_ray
:
bool
,
use_v2_block_manager
:
bool
,
)
->
None
:
"""Verify infeasible swap request will be ignored."""
BLOCK_SIZE
=
16
...
...
@@ -223,6 +227,7 @@ def test_swap_infeasible(
num_gpu_blocks_override
=
prefill_blocks
+
decode_blocks
,
max_model_len
=
(
prefill_blocks
+
decode_blocks
)
*
BLOCK_SIZE
,
worker_use_ray
=
worker_use_ray
,
use_v2_block_manager
=
use_v2_block_manager
,
)
as
vllm_model
:
sampling_params
=
SamplingParams
(
n
=
beam_width
,
use_beam_search
=
True
,
...
...
tests/core/block/test_block_manager_v2.py
View file @
5bf8789b
...
...
@@ -373,6 +373,52 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots,
seq_group
,
num_lookahead_slots
)
==
AllocStatus
.
NEVER
@
pytest
.
mark
.
parametrize
(
"num_lookahead_slots"
,
[
0
,
2
,
10
])
@
pytest
.
mark
.
parametrize
(
"enable_caching"
,
[
False
,
True
])
def
test_swap_in_infeasible
(
num_lookahead_slots
,
enable_caching
):
"""Verifies that swapping fails if there are not enough free blocks
to account for unseen tokens and lookahead_slots.
"""
block_size
=
8
num_cpu_blocks
=
1
num_gpu_blocks
=
1
block_manager
=
BlockSpaceManagerV2
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
watermark
=
0
,
enable_caching
=
enable_caching
)
prompt_length
=
block_size
-
3
assert
prompt_length
>
0
prompt
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
prompt_length
)
prompt
.
status
=
SequenceStatus
.
WAITING
block_manager
.
allocate
(
seq_group
)
# Emulate a forward pass by appending a single token.
# The block manager then knows how many unprocessed
# tokens will be written in the next forward pass.
token_id
=
0
prompt
.
status
=
SequenceStatus
.
RUNNING
prompt
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
0.0
)})
# Swap seq group from GPU -> CPU.
assert
block_manager
.
can_swap_out
(
seq_group
)
block_manager
.
swap_out
(
seq_group
)
prompt
.
status
=
SequenceStatus
.
SWAPPED
# Swap seq group from CPU -> GPU.
# The number of unseen tokens is 1. If the number of existing
# tokens plus the unseen ones and number of lookahead slots exceeds
# the total number of available GPU blocks then the swap
# should fail.
num_unseen_tokens
=
1
if
(
num_lookahead_slots
+
num_unseen_tokens
+
prompt_length
)
<=
(
block_size
*
num_gpu_blocks
):
assert
block_manager
.
can_swap_in
(
seq_group
,
num_lookahead_slots
)
==
AllocStatus
.
OK
else
:
assert
block_manager
.
can_swap_in
(
seq_group
,
num_lookahead_slots
)
==
AllocStatus
.
NEVER
# TODO(cade/kaiyang): add comprehensive tests for swapping at allocator level.
...
...
@@ -400,7 +446,6 @@ def test_sliding_window(block_size, prompt_len, num_slots_to_append,
if
max_n
is
None
:
max_n
=
min_n
used
=
num_gpu_blocks
-
block_manager
.
get_num_free_gpu_blocks
()
#print("check", min_n, used, max_n)
assert
min_n
<=
used
assert
used
<=
max_n
...
...
tests/core/block/test_naive_block.py
View file @
5bf8789b
...
...
@@ -104,9 +104,9 @@ class TestNaiveBlockAllocator:
@
staticmethod
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
8
])
def
test_naive_block_get_num_blocks_touched
(
num_blocks
,
block_size
):
def
test_naive_block_get_num_
full_
blocks_touched
(
num_blocks
,
block_size
):
""" Verify the allocator can correctly return the number of
blocks touched
, with different lookahead slots
.
full
blocks touched.
"""
allocator_src
=
NaiveBlockAllocator
(
create_block
=
NaiveBlock
,
num_blocks
=
num_blocks
,
...
...
@@ -124,7 +124,7 @@ class TestNaiveBlockAllocator:
src_blocks
=
[
allocate_block
()
for
_
in
range
(
num_blocks
-
1
)]
# All blocks are cached
assert
allocator_dst
.
get_num_blocks_touched
(
assert
allocator_dst
.
get_num_
full_
blocks_touched
(
src_blocks
)
==
num_blocks
-
1
# Insert one non-full block in the src
...
...
@@ -136,9 +136,10 @@ class TestNaiveBlockAllocator:
src_blocks
.
append
(
allocate_non_full_block
())
src_blocks
[
-
1
].
append_token_ids
([
0
])
assert
allocator_dst
.
get_num_blocks_touched
(
src_blocks
,
num_lookahead_slots
=
1
)
==
num_blocks
assert
allocator_dst
.
get_num_blocks_touched
(
src_blocks
,
num_lookahead_slots
=
block_size
-
1
)
==
num_blocks
assert
allocator_dst
.
get_num_blocks_touched
(
src_blocks
,
num_lookahead_slots
=
block_size
)
==
(
num_blocks
+
1
)
assert
allocator_dst
.
get_num_full_blocks_touched
(
src_blocks
)
==
num_blocks
-
1
# Fill up the last source block and then invoke
# get_num_full_blocks_touched
src_blocks
[
-
1
].
append_token_ids
([
0
]
*
(
block_size
-
1
))
assert
allocator_dst
.
get_num_full_blocks_touched
(
src_blocks
)
==
num_blocks
tests/core/block/test_prefix_caching_block.py
View file @
5bf8789b
...
...
@@ -318,11 +318,10 @@ class TestPrefixCachingBlockAllocator:
@
staticmethod
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
8
])
def
test_prefix_caching_block_get_num_blocks_touched
(
def
test_prefix_caching_block_get_num_
full_
blocks_touched
(
num_blocks
,
block_size
):
""" Verify the allocator can correctly return the number of
blocks touched, when there are cached prefixes and different
lookahead slots.
blocks touched, when there are cached prefixes.
"""
allocator_src
=
PrefixCachingBlockAllocator
(
num_blocks
=
num_blocks
,
block_size
=
block_size
)
...
...
@@ -346,28 +345,30 @@ class TestPrefixCachingBlockAllocator:
token_ids
=
token_ids
,
allocator
=
allocator_src
,
)
# All blocks are cached
assert
allocator_dst
.
get_num_blocks_touched
(
blocks_to_swap_in
)
==
0
assert
allocator_dst
.
get_num_full_blocks_touched
(
blocks_to_swap_in
)
==
0
# Free the first block in the dst
allocator_dst
.
free
(
cached_blocks
[
0
])
# Now the first block becomes dangling, the swapped blocks need
# to reclaim the first block in the dst
assert
allocator_dst
.
get_num_blocks_touched
(
blocks_to_swap_in
)
==
1
assert
allocator_dst
.
get_num_full_blocks_touched
(
blocks_to_swap_in
)
==
1
# Insert one non-full block in the src
non_full_block
=
allocator_src
.
allocate_mutable_block
(
blocks_to_swap_in
[
-
1
])
non_full_block
.
append_token_ids
([
0
])
blocks_to_swap_in
.
append
(
non_full_block
)
assert
allocator_dst
.
get_num_blocks_touched
(
blocks_to_swap_in
,
num_lookahead_slots
=
1
)
==
2
assert
allocator_dst
.
get_num_blocks_touched
(
blocks_to_swap_in
,
num_lookahead_slots
=
block_size
-
1
)
==
2
assert
allocator_dst
.
get_num_blocks_touched
(
blocks_to_swap_in
,
num_lookahead_slots
=
block_size
)
==
3
assert
allocator_dst
.
get_num_full_blocks_touched
(
blocks_to_swap_in
)
==
1
# Fill up the last mutable block and invoke get_num_full_blocks_touched.
# Note: The last block is not cached so it will be touched.
non_full_block
.
append_token_ids
([
0
]
*
(
block_size
-
1
))
assert
allocator_dst
.
get_num_full_blocks_touched
(
blocks_to_swap_in
)
==
2
@
staticmethod
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
[
1024
])
...
...
vllm/core/block/cpu_gpu_block_allocator.py
View file @
5bf8789b
...
...
@@ -259,25 +259,22 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
current_swap_mapping
[
src_block_id
]
=
dst_block_id
return
current_swap_mapping
def
get_num_blocks_touched
(
self
,
blocks
:
List
[
Block
],
device
:
Device
,
num_lookahead_slots
:
int
=
0
)
->
int
:
"""Returns the number of blocks that will be touched by
def
get_num_full_blocks_touched
(
self
,
blocks
:
List
[
Block
],
device
:
Device
)
->
int
:
"""Returns the number of full blocks that will be touched by
swapping in/out the given blocks on to the 'device'.
Args:
blocks: List of blocks to be swapped.
device (Device): Device to swap the 'blocks' on.
num_lookahead_slots (int): Number of lookahead slots used in
speculative decoding, default to 0.
Returns:
int: the number of blocks that will be touched by
int: the number of
full
blocks that will be touched by
swapping in/out the given blocks on to the 'device'.
Non full blocks are ignored when deciding the number
of blocks to touch.
"""
return
self
.
_allocators
[
device
].
get_num_blocks_touched
(
blocks
,
num_lookahead_slots
)
return
self
.
_allocators
[
device
].
get_num_full_blocks_touched
(
blocks
)
def
clear_copy_on_writes
(
self
)
->
List
[
Tuple
[
int
,
int
]]:
"""Clears the copy-on-write (CoW) state and returns the mapping of
...
...
vllm/core/block/interfaces.py
View file @
5bf8789b
...
...
@@ -181,9 +181,7 @@ class BlockAllocator(ABC):
pass
@
abstractmethod
def
get_num_blocks_touched
(
self
,
blocks
:
List
[
Block
],
num_lookahead_slots
:
int
=
0
)
->
int
:
def
get_num_full_blocks_touched
(
self
,
blocks
:
List
[
Block
])
->
int
:
pass
@
abstractmethod
...
...
@@ -260,10 +258,8 @@ class DeviceAwareBlockAllocator(ABC):
pass
@
abstractmethod
def
get_num_blocks_touched
(
self
,
blocks
:
List
[
Block
],
device
:
Device
,
num_lookahead_slots
:
int
=
0
)
->
int
:
def
get_num_full_blocks_touched
(
self
,
blocks
:
List
[
Block
],
device
:
Device
)
->
int
:
pass
@
abstractmethod
...
...
vllm/core/block/naive_block.py
View file @
5bf8789b
...
...
@@ -4,7 +4,6 @@ from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple
from
vllm.core.block.common
import
(
BlockPool
,
CopyOnWriteTracker
,
RefCounter
,
get_all_blocks_recursively
)
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
,
BlockId
,
Device
from
vllm.utils
import
cdiv
Refcount
=
int
...
...
@@ -282,40 +281,26 @@ class NaiveBlockAllocator(BlockAllocator):
def
promote_to_immutable_block
(
self
,
block
:
Block
)
->
BlockId
:
raise
NotImplementedError
(
"There is no promotion for naive blocks"
)
def
get_num_blocks_touched
(
self
,
blocks
:
List
[
Block
],
num_lookahead_slots
:
int
=
0
)
->
int
:
"""Determine the number of blocks that will be touched by
swapping in/out the given blocks from certain sequence
group with the provided num_lookahead_slots.
def
get_num_full_blocks_touched
(
self
,
blocks
:
List
[
Block
])
->
int
:
"""Returns the number of full blocks that will be touched by
swapping in/out.
Args:
blocks (List[Block]): The potential blocks to swap.
num_lookahead_slots (int): number of lookahead slots (0 for swap
out).
blocks: List of blocks to be swapped.
Returns:
int: the number of blocks that will be touched by
swapping in/out the given blocks and num_lookahead_slots.
int: the number of full blocks that will be touched by
swapping in/out the given blocks. Non full blocks are ignored
when deciding the number of blocks to touch.
"""
# NOTE: for naive block, we use set to eliminate common blocks among
# seqs, also we compare the empty slots in the mutable blocks with
# lookahead slots to get the number of unique new block that are
# needed.
old_block_set
=
set
()
new_block_count
=
0
# TODO(cade): make sure the logic is correct and clean it up.
for
block
in
blocks
:
if
not
block
.
is_full
and
num_lookahead_slots
!=
0
:
new_block_count
+=
1
if
num_lookahead_slots
>
block
.
num_empty_slots
:
new_block_count
+=
cdiv
(
num_lookahead_slots
-
block
.
num_empty_slots
,
self
.
_block_size
)
else
:
old_block_set
.
add
(
block
.
block_id
)
num_touched_blocks
=
new_block_count
+
len
(
old_block_set
)
return
num_touched_blocks
if
block
.
is_full
:
old_block_set
.
add
(
block
)
return
len
(
old_block_set
)
def
swap_out
(
self
,
blocks
:
List
[
Block
])
->
None
:
for
block
in
blocks
:
...
...
vllm/core/block/prefix_caching_block.py
View file @
5bf8789b
...
...
@@ -8,7 +8,6 @@ from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from
vllm.core.block.naive_block
import
(
BlockPool
,
NaiveBlock
,
NaiveBlockAllocator
)
from
vllm.core.evictor_v2
import
EvictionPolicy
,
Evictor
,
make_evictor
from
vllm.utils
import
cdiv
PrefixHash
=
int
...
...
@@ -576,37 +575,27 @@ class PrefixCachingBlockAllocator(BlockAllocator):
if
ids
])
def
get_num_blocks_touched
(
self
,
blocks
:
List
[
Block
],
num_lookahead_slots
:
int
=
0
)
->
int
:
"""Determine the number of blocks that will be touched by
swapping in/out the given blocks from certain sequence
group with the provided num_lookahead_slots.
def
get_num_full_blocks_touched
(
self
,
blocks
:
List
[
Block
])
->
int
:
"""Returns the number of full blocks that will be touched by
swapping in/out.
Args:
blocks (List[Block]): The potential blocks to swap.
num_lookahead_slots (int): number of lookahead slots (0 for
swap out).
blocks: List of blocks to be swapped.
Returns:
int: the number of blocks that will be touched by
swapping in/out the given blocks and num_lookahead_slots.
int: the number of full blocks that will be touched by
swapping in/out the given blocks. Non full blocks are ignored
when deciding the number of blocks to touch.
"""
num_touched_blocks
=
0
num_touched_blocks
:
int
=
0
for
block
in
blocks
:
if
not
block
.
is_full
:
# If the block has a match in the cache and the cached
# block is not referenced, then we still count it as a
# touched block
if
block
.
is_full
and
(
not
self
.
is_block_cached
(
block
)
or
\
(
block
.
content_hash
is
not
None
and
\
self
.
_cached_blocks
[
block
.
content_hash
]
in
\
self
.
evictor
)):
num_touched_blocks
+=
1
if
num_lookahead_slots
>
block
.
num_empty_slots
:
num_touched_blocks
+=
cdiv
(
num_lookahead_slots
-
block
.
num_empty_slots
,
self
.
_block_size
)
else
:
# If the block has a match in the cache and the cached block
# is not referenced, then we still count it as a touched block
if
not
self
.
is_block_cached
(
block
)
or
\
(
block
.
content_hash
is
not
None
and
\
self
.
_cached_blocks
[
block
.
content_hash
]
in
self
.
evictor
):
num_touched_blocks
+=
1
return
num_touched_blocks
def
swap_out
(
self
,
blocks
:
List
[
Block
])
->
None
:
...
...
vllm/core/block_manager_v2.py
View file @
5bf8789b
"""A block manager that manages token blocks."""
from
itertools
import
chain
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Tuple
...
...
@@ -470,12 +469,31 @@ class BlockSpaceManagerV2(BlockSpaceManager):
AllocStatus: The AllocStatus for swapping in/out the given
sequence_group on to the 'device'.
"""
blocks
=
self
.
_get_blocks_for_swap
(
seq_group
,
status
)
num_blocks_touched
=
self
.
block_allocator
.
get_num_blocks_touched
(
blocks
,
device
,
num_lookahead_slots
)
# First determine the number of blocks that will be touched by this
# swap. Then verify if there are available blocks in the device
# to perform the swap.
num_blocks_touched
=
0
blocks
:
List
[
Block
]
=
[]
for
seq
in
seq_group
.
get_seqs
(
status
=
status
):
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
if
block_table
.
blocks
is
not
None
:
# Compute the number of blocks to touch for the tokens to be
# appended. This does NOT include the full blocks that need
# to be touched for the swap.
num_blocks_touched
+=
\
block_table
.
get_num_blocks_touched_by_append_slots
(
block_table
.
get_unseen_token_ids
(
seq
.
get_token_ids
()),
num_lookahead_slots
=
num_lookahead_slots
)
blocks
.
extend
(
block_table
.
blocks
)
# Compute the number of full blocks to touch and add it to the
# existing count of blocks to touch.
num_blocks_touched
+=
self
.
block_allocator
.
get_num_full_blocks_touched
(
blocks
,
device
=
device
)
watermark_blocks
=
0
if
device
==
Device
.
GPU
:
watermark_blocks
=
self
.
watermark_blocks
if
self
.
block_allocator
.
get_num_total_blocks
(
device
)
<
num_blocks_touched
:
return
AllocStatus
.
NEVER
...
...
@@ -484,23 +502,3 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return
AllocStatus
.
OK
else
:
return
AllocStatus
.
LATER
def
_get_blocks_for_swap
(
self
,
seq_group
:
SequenceGroup
,
status
:
SequenceStatus
)
->
List
[
Block
]:
"""Returns the list of blocks that are touched by the seq_group
Args:
sequence_group (SequenceGroup): The sequence group to swap in.
status (SequenceStatus): The status of sequence which is needed
for action. RUNNING for swap out and SWAPPED for swap in
Returns:
The list of blocks that are touched by the seq_group.
"""
blocks
:
Dict
[
int
,
List
[
Block
]]
=
{}
for
seq
in
seq_group
.
get_seqs
(
status
=
status
):
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
if
block_table
.
blocks
is
not
None
:
blocks
[
seq
.
seq_id
]
=
block_table
.
blocks
combined_blocks
=
list
(
chain
(
*
blocks
.
values
()))
return
combined_blocks
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment