Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3476ed08
Unverified
Commit
3476ed08
authored
Jul 01, 2024
by
Alexander Matveev
Committed by
GitHub
Jul 01, 2024
Browse files
[Core] Optimize block_manager_v2 vs block_manager_v1 (to make V2 default) (#5602)
parent
54600709
Changes
19
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
1189 additions
and
532 deletions
+1189
-532
benchmarks/benchmark_latency.py
benchmarks/benchmark_latency.py
+4
-0
tests/conftest.py
tests/conftest.py
+1
-1
tests/core/block/test_block_table.py
tests/core/block/test_block_table.py
+3
-2
tests/core/block/test_cpu_gpu_block_allocator.py
tests/core/block/test_cpu_gpu_block_allocator.py
+12
-12
tests/core/block/test_naive_block.py
tests/core/block/test_naive_block.py
+3
-3
tests/core/block/test_prefix_caching_block.py
tests/core/block/test_prefix_caching_block.py
+67
-39
tests/spec_decode/test_batch_expansion.py
tests/spec_decode/test_batch_expansion.py
+4
-4
vllm/core/block/block_table.py
vllm/core/block/block_table.py
+56
-29
vllm/core/block/common.py
vllm/core/block/common.py
+154
-44
vllm/core/block/cpu_gpu_block_allocator.py
vllm/core/block/cpu_gpu_block_allocator.py
+56
-28
vllm/core/block/interfaces.py
vllm/core/block/interfaces.py
+44
-12
vllm/core/block/naive_block.py
vllm/core/block/naive_block.py
+149
-67
vllm/core/block/prefix_caching_block.py
vllm/core/block/prefix_caching_block.py
+479
-214
vllm/core/block_manager_v2.py
vllm/core/block_manager_v2.py
+101
-49
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+4
-1
vllm/entrypoints/openai/serving_completion.py
vllm/entrypoints/openai/serving_completion.py
+1
-1
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+2
-2
vllm/outputs.py
vllm/outputs.py
+2
-2
vllm/sequence.py
vllm/sequence.py
+47
-22
No files found.
benchmarks/benchmark_latency.py
View file @
3476ed08
...
...
@@ -46,6 +46,7 @@ def main(args: argparse.Namespace):
load_format
=
args
.
load_format
,
distributed_executor_backend
=
args
.
distributed_executor_backend
,
otlp_traces_endpoint
=
args
.
otlp_traces_endpoint
,
enable_prefix_caching
=
args
.
enable_prefix_caching
,
)
sampling_params
=
SamplingParams
(
...
...
@@ -220,6 +221,9 @@ if __name__ == '__main__':
action
=
'store_true'
,
help
=
'If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens'
)
parser
.
add_argument
(
"--enable-prefix-caching"
,
action
=
'store_true'
,
help
=
"Enable automatic prefix caching"
)
parser
.
add_argument
(
'--use-v2-block-manager'
,
action
=
'store_true'
)
parser
.
add_argument
(
"--ray-workers-use-nsight"
,
...
...
tests/conftest.py
View file @
3476ed08
...
...
@@ -474,7 +474,7 @@ class VllmRunner:
req_sample_output_strs
:
List
[
str
]
=
[]
for
sample
in
req_output
.
outputs
:
output_str
=
sample
.
text
output_ids
=
sample
.
token_ids
output_ids
=
list
(
sample
.
token_ids
)
req_sample_output_ids
.
append
(
prompt_ids
+
output_ids
)
req_sample_output_strs
.
append
(
prompt_str
+
output_str
)
outputs
.
append
((
req_sample_output_ids
,
req_sample_output_strs
))
...
...
tests/core/block/test_block_table.py
View file @
3476ed08
...
...
@@ -373,8 +373,9 @@ def test_cow(block_size: int, sequence_len: int, append_len: int,
block_size
)
-
(
sequence_len
//
block_size
)
original_block_table
.
allocate
(
token_ids
=
token_ids
,
device
=
Device
.
GPU
)
original_block_ids
=
original_block_table
.
physical_block_ids
original_block_ids
=
original_block_table
.
physical_block_ids
[:]
print
(
"original_block_ids = {}"
.
format
(
original_block_ids
))
forked_block_table
=
original_block_table
.
fork
()
# Expect no additional allocation (copy on _write_).
...
...
@@ -457,7 +458,7 @@ def test_cow_lookahead_simple(block_size: int, sequence_len: int,
# Allocate lookahead slots.
original_block_table
.
ensure_num_empty_slots
(
lookahead_slots
)
original_block_ids
=
original_block_table
.
physical_block_ids
original_block_ids
=
original_block_table
.
physical_block_ids
[:]
forked_block_table
=
original_block_table
.
fork
()
...
...
tests/core/block/test_cpu_gpu_block_allocator.py
View file @
3476ed08
...
...
@@ -8,8 +8,8 @@ from vllm.utils import Device, chunk_list
@
pytest
.
mark
.
parametrize
(
"num_gpu_blocks"
,
[
1024
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"allocator_type"
,
[
"naive"
,
"prefix_caching"
])
def
test_allocate_mutable
(
num_cpu_blocks
:
int
,
num_gpu_blocks
:
int
,
block_size
:
int
,
allocator_type
:
str
):
def
test_allocate_mutable
_block
(
num_cpu_blocks
:
int
,
num_gpu_blocks
:
int
,
block_size
:
int
,
allocator_type
:
str
):
allocator
=
CpuGpuBlockAllocator
.
create
(
allocator_type
=
allocator_type
,
num_gpu_blocks
=
num_gpu_blocks
,
...
...
@@ -21,14 +21,14 @@ def test_allocate_mutable(num_cpu_blocks: int, num_gpu_blocks: int,
assert
allocator
.
get_num_free_blocks
(
Device
.
GPU
)
==
num_gpu_blocks
cpu_blocks
=
[
allocator
.
allocate_mutable
(
prev_block
=
None
,
device
=
Device
.
CPU
)
allocator
.
allocate_mutable
_block
(
prev_block
=
None
,
device
=
Device
.
CPU
)
for
_
in
range
(
num_cpu_blocks
)
]
assert
allocator
.
get_num_free_blocks
(
Device
.
CPU
)
==
0
assert
allocator
.
get_num_free_blocks
(
Device
.
GPU
)
==
num_gpu_blocks
gpu_blocks
=
[
allocator
.
allocate_mutable
(
prev_block
=
None
,
device
=
Device
.
GPU
)
allocator
.
allocate_mutable
_block
(
prev_block
=
None
,
device
=
Device
.
GPU
)
for
_
in
range
(
num_gpu_blocks
)
]
assert
allocator
.
get_num_free_blocks
(
Device
.
CPU
)
==
0
...
...
@@ -47,8 +47,8 @@ def test_allocate_mutable(num_cpu_blocks: int, num_gpu_blocks: int,
@
pytest
.
mark
.
parametrize
(
"num_gpu_blocks"
,
[
1024
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"allocator_type"
,
[
"naive"
,
"prefix_caching"
])
def
test_allocate_immutable
(
num_cpu_blocks
:
int
,
num_gpu_blocks
:
int
,
block_size
:
int
,
allocator_type
:
str
):
def
test_allocate_immutable
_block
(
num_cpu_blocks
:
int
,
num_gpu_blocks
:
int
,
block_size
:
int
,
allocator_type
:
str
):
allocator
=
CpuGpuBlockAllocator
.
create
(
allocator_type
=
allocator_type
,
num_gpu_blocks
=
num_gpu_blocks
,
...
...
@@ -67,18 +67,18 @@ def test_allocate_immutable(num_cpu_blocks: int, num_gpu_blocks: int,
assert
allocator
.
get_num_free_blocks
(
Device
.
GPU
)
==
num_gpu_blocks
cpu_blocks
=
[
allocator
.
allocate_immutable
(
prev_block
=
None
,
token_ids
=
token_ids
,
device
=
Device
.
CPU
)
allocator
.
allocate_immutable
_block
(
prev_block
=
None
,
token_ids
=
token_ids
,
device
=
Device
.
CPU
)
for
token_ids
in
cpu_token_ids
]
assert
allocator
.
get_num_free_blocks
(
Device
.
CPU
)
==
0
assert
allocator
.
get_num_free_blocks
(
Device
.
GPU
)
==
num_gpu_blocks
gpu_blocks
=
[
allocator
.
allocate_immutable
(
prev_block
=
None
,
token_ids
=
token_ids
,
device
=
Device
.
GPU
)
allocator
.
allocate_immutable
_block
(
prev_block
=
None
,
token_ids
=
token_ids
,
device
=
Device
.
GPU
)
for
token_ids
in
gpu_token_ids
]
assert
allocator
.
get_num_free_blocks
(
Device
.
CPU
)
==
0
...
...
tests/core/block/test_naive_block.py
View file @
3476ed08
...
...
@@ -14,11 +14,11 @@ class TestNaiveBlockAllocator:
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
]):
if
allocate_type
==
"immutable"
:
allocate_block
=
lambda
:
allocator
.
allocate_immutable
(
allocate_block
=
lambda
:
allocator
.
allocate_immutable
_block
(
prev_block
=
prev_block
,
token_ids
=
token_ids
)
elif
allocate_type
==
"mutable"
:
allocate_block
=
lambda
:
allocator
.
allocate_mutable
(
prev
_block
=
prev_block
)
allocate_block
=
lambda
:
allocator
.
allocate_mutable_block
(
prev_block
=
prev_block
)
else
:
raise
ValueError
()
...
...
tests/core/block/test_prefix_caching_block.py
View file @
3476ed08
...
...
@@ -26,11 +26,10 @@ class TestPrefixCachingBlock:
token_ids
=
list
(
range
(
num_to_fill
))
mock_allocator
=
MagicMock
(
spec
=
PrefixCachingBlockAllocator
)
block_with_prev
=
PrefixCachingBlock
(
prev_block
=
None
,
token_ids
=
token_ids
,
block_size
=
block_size
,
prefix_caching_allocator
=
mock_allocator
)
block_with_prev
=
PrefixCachingBlock
(
prev_block
=
None
,
token_ids
=
token_ids
,
block_size
=
block_size
,
allocator
=
mock_allocator
)
if
is_curr_block_full
:
# Expect hash since block is full.
...
...
@@ -71,7 +70,7 @@ class TestPrefixCachingBlock:
prev_block
=
previous_block
,
token_ids
=
token_ids
,
block_size
=
block_size
,
prefix_caching_
allocator
=
mock_allocator
,
allocator
=
mock_allocator
,
)
if
is_curr_block_full
and
prev_block_has_hash
:
...
...
@@ -138,7 +137,7 @@ class TestPrefixCachingBlock:
prev_block
=
prev_block
,
token_ids
=
[],
block_size
=
block_size
,
prefix_caching_
allocator
=
allocator
,
allocator
=
allocator
,
)
tokens_to_append
=
token_ids
[
block_number
*
...
...
@@ -159,11 +158,11 @@ class TestPrefixCachingBlockAllocator:
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
]):
if
allocate_type
==
"immutable"
:
allocate_block
=
lambda
:
allocator
.
allocate_immutable
(
allocate_block
=
lambda
:
allocator
.
allocate_immutable
_block
(
prev_block
=
prev_block
,
token_ids
=
token_ids
)
elif
allocate_type
==
"mutable"
:
allocate_block
=
lambda
:
allocator
.
allocate_mutable
(
prev
_block
=
prev_block
)
allocate_block
=
lambda
:
allocator
.
allocate_mutable_block
(
prev_block
=
prev_block
)
else
:
raise
ValueError
()
...
...
@@ -233,12 +232,13 @@ class TestPrefixCachingBlockAllocator:
# Expect allocation with unseen hash to fail.
with
pytest
.
raises
(
BlockAllocator
.
NoFreeBlocksError
):
allocator
.
allocate_immutable
(
prev_block
=
chain
[
-
1
],
token_ids
=
list
(
range
(
block_size
)))
allocator
.
allocate_immutable_block
(
prev_block
=
chain
[
-
1
],
token_ids
=
list
(
range
(
block_size
)))
# Expect mutable allocation to fail.
with
pytest
.
raises
(
BlockAllocator
.
NoFreeBlocksError
):
allocator
.
allocate_mutable
(
prev_block
=
chain
[
-
1
])
allocator
.
allocate_mutable
_block
(
prev_block
=
chain
[
-
1
])
# Expect allocation of exact same chain to pass.
second_chain
=
TestPrefixCachingBlockAllocator
.
create_immutable_chain
(
...
...
@@ -270,7 +270,7 @@ class TestPrefixCachingBlockAllocator:
# Expect mutable allocation to fail.
with
pytest
.
raises
(
BlockAllocator
.
NoFreeBlocksError
):
allocator
.
allocate_mutable
(
prev_block
=
None
)
allocator
.
allocate_mutable
_block
(
prev_block
=
None
)
block_to_free
=
chain
[
-
1
]
...
...
@@ -280,11 +280,11 @@ class TestPrefixCachingBlockAllocator:
allocator
.
free
(
block_to_free
)
assert
block_to_free
.
block_id
is
None
,
i
new_block
=
allocator
.
allocate_mutable
(
prev_block
=
None
)
new_block
=
allocator
.
allocate_mutable
_block
(
prev_block
=
None
)
assert
new_block
.
block_id
==
block_id
,
i
with
pytest
.
raises
(
BlockAllocator
.
NoFreeBlocksError
):
allocator
.
allocate_mutable
(
prev_block
=
None
)
allocator
.
allocate_mutable
_block
(
prev_block
=
None
)
block_to_free
=
new_block
...
...
@@ -376,7 +376,6 @@ class TestPrefixCachingBlockAllocator:
# Create token ids that will exhaust all blocks.
token_ids
=
list
(
range
(
num_blocks_to_consume
*
block_size
))
blocks
=
list
(
range
(
num_blocks_to_consume
))
first_chain
=
TestPrefixCachingBlockAllocator
.
create_immutable_chain
(
block_size
=
block_size
,
...
...
@@ -384,9 +383,6 @@ class TestPrefixCachingBlockAllocator:
allocator
=
allocator
,
)
# mark all blocks in first chain as computed
allocator
.
mark_blocks_as_computed
(
blocks
)
# After zero_point, second_chain's token_ids would be set -1, which
# make it different from here comparing with first_chain
zero_point
=
random
.
randint
(
1
,
len
(
token_ids
)
-
1
)
...
...
@@ -424,15 +420,16 @@ class TestPrefixCachingBlockAllocator:
block_size
=
block_size
)
token_ids
=
list
(
range
(
block_size
))
block
=
allocator
.
allocate_immutable
(
prev_block
=
None
,
token_ids
=
token_ids
)
block
=
allocator
.
allocate_immutable
_block
(
prev_block
=
None
,
token_ids
=
token_ids
)
assert
allocator
.
_refcounter
.
get
(
block
.
block_id
)
==
1
m
=
allocator
.
allocate_mutable
(
prev_block
=
None
)
m
=
allocator
.
allocate_mutable
_block
(
prev_block
=
None
)
block_id
=
m
.
block_id
for
i
in
range
(
block_size
):
m
.
append_token_ids
([
i
])
# After block get promoted to immutable from mutable, if there is
# already same content hash block, then it shall be released into
# hashless_allocator
...
...
@@ -452,48 +449,79 @@ class TestPrefixCachingBlockAllocator:
all_blocks_list
=
[
i
for
i
in
range
(
num_blocks
)]
zero_ref
=
{
i
:
0
for
i
in
range
(
num_blocks
)}
one_ref
=
{
i
:
1
for
i
in
range
(
num_blocks
)}
allocator
=
PrefixCachingBlockAllocator
(
num_blocks
=
num_blocks
,
block_size
=
block_size
)
token_ids
=
list
(
range
(
num_blocks
*
block_size
))
#
now we have num_blocks free blocks in hashless allocator
# with internal tracking list _blocks _cached_blocks and evictor
#
empty and
block
'
s re
f shall be 0
#
Verify initial/pre-alloc state
#
Ensure all
blocks
a
re
free inside hashless allocator
assert
list
(
allocator
.
_hashless_allocator
.
_free_block_indices
)
==
all_blocks_list
assert
len
(
allocator
.
_blocks
.
keys
())
==
0
# Ensure no tracked blocks
assert
len
(
allocator
.
_block_tracker
.
keys
())
==
num_blocks
for
block_id
in
range
(
num_blocks
):
assert
not
allocator
.
_block_tracker
[
block_id
].
active
# Ensure no cached blocks
assert
len
(
allocator
.
_cached_blocks
.
values
())
==
0
# Ensure no evicted blocks
assert
len
(
allocator
.
evictor
.
free_table
.
keys
())
==
0
# Ensure 0s ref counts for all blocks
assert
allocator
.
_refcounter
.
_refcounts
==
zero_ref
# Allocate immutable chains with only one block residuled in
new_block
=
[]
for
i
in
range
(
num_blocks
):
block
=
allocator
.
allocate_immutable
(
block
=
allocator
.
allocate_immutable
_block
(
prev_block
=
None
,
token_ids
=
token_ids
[
block_size
*
i
:
block_size
*
(
i
+
1
)])
new_block
.
append
(
block
)
# Verify post-alloc state
# Ensure no blocks are free inside hashless allocator
assert
(
len
(
allocator
.
_hashless_allocator
.
_free_block_indices
)
==
0
)
# Ensure all blocks are tracked
assert
len
(
allocator
.
_block_tracker
.
keys
())
==
num_blocks
for
block_id
in
range
(
num_blocks
):
assert
allocator
.
_block_tracker
[
block_id
].
active
# Ensure all blocks are cached (all promoted)
assert
len
(
allocator
.
_cached_blocks
.
values
())
==
num_blocks
# Ensure no evicted blocks
assert
len
(
allocator
.
evictor
.
free_table
.
keys
())
==
0
# Ensure 1s ref counts for all blocks
assert
allocator
.
_refcounter
.
_refcounts
==
one_ref
# Free all blocks, and now all blocks shall be in the evictor
# there shall be no tracking data left in _block
s
# there shall be no tracking data left in _block
_tracker
# all blocks shall be tracked in _cached_blocks
# all blocks' ref shall be zero
for
block
in
new_block
:
allocator
.
free
(
block
)
assert
len
(
allocator
.
_blocks
.
keys
())
==
0
# Verify post-free state
# Ensure no tracked blocks
assert
len
(
allocator
.
_block_tracker
.
keys
())
==
num_blocks
for
block_id
in
range
(
num_blocks
):
assert
not
allocator
.
_block_tracker
[
block_id
].
active
# Ensure no blocks in hashless allocator (all promoted)
assert
len
(
allocator
.
_hashless_allocator
.
_free_block_indices
)
==
0
# Ensure all blocks are cached
assert
list
(
allocator
.
_cached_blocks
.
values
())
==
all_blocks_list
# Ensure all blocks are inside the evictor
assert
list
(
allocator
.
evictor
.
free_table
.
keys
())
==
all_blocks_list
# Ensure 0s refcounts
assert
allocator
.
_refcounter
.
_refcounts
==
zero_ref
# Allocate a mutable block, and the first block shall be evicted
# and set its content hash into None, ref to 1
mutable
=
allocator
.
allocate_mutable
(
prev_block
=
None
)
mutable
=
allocator
.
allocate_mutable
_block
(
prev_block
=
None
)
assert
mutable
.
block_id
==
0
assert
mutable
.
content_hash
is
None
assert
0
in
allocator
.
_block
s
assert
allocator
.
_block
_tracker
[
0
].
active
assert
allocator
.
_refcounter
.
get
(
0
)
==
1
assert
0
not
in
allocator
.
_cached_blocks
assert
0
not
in
allocator
.
evictor
...
...
@@ -502,27 +530,27 @@ class TestPrefixCachingBlockAllocator:
# hashless allocator
allocator
.
free
(
mutable
)
assert
len
(
allocator
.
_block
s
.
keys
())
==
0
assert
not
allocator
.
_block
_tracker
[
0
].
active
assert
allocator
.
_refcounter
.
_refcounts
==
zero_ref
assert
0
not
in
allocator
.
_cached_blocks
assert
0
not
in
allocator
.
evictor
assert
0
in
allocator
.
_hashless_allocator
.
_free_block_indices
#
w
hen allocate immutable with first block_size tokens, we
#
W
hen allocate immutable with first block_size tokens, we
# shall get free block from hashless allocator, thus no block left
# in hashless
block
=
allocator
.
allocate_immutable
(
prev
_block
=
None
,
token_ids
=
token_ids
[:
block_size
])
block
=
allocator
.
allocate_immutable_block
(
prev_block
=
None
,
token_ids
=
token_ids
[:
block_size
])
assert
block
.
block_id
==
0
assert
len
(
allocator
.
_hashless_allocator
.
_free_block_indices
)
==
0
assert
0
in
allocator
.
_block
s
assert
allocator
.
_block
_tracker
[
0
].
active
assert
0
in
allocator
.
_cached_blocks
.
values
()
assert
allocator
.
_refcounter
.
get
(
0
)
==
1
assert
0
not
in
allocator
.
evictor
# allocate mutable block again, it shall be popped from evictor
mutable
=
allocator
.
allocate_mutable
(
prev_block
=
None
)
mutable
=
allocator
.
allocate_mutable
_block
(
prev_block
=
None
)
assert
len
(
allocator
.
_hashless_allocator
.
_free_block_indices
)
==
0
assert
mutable
.
block_id
not
in
allocator
.
evictor
.
free_table
assert
allocator
.
_refcounter
.
get
(
mutable
.
block_id
)
==
1
...
...
@@ -619,7 +647,7 @@ class TestPrefixCachingBlockAllocator:
block_token_ids
=
token_ids
[
block_number
*
block_size
:(
block_number
+
1
)
*
block_size
]
prev_block
=
allocator
.
allocate_immutable
(
prev_block
=
allocator
.
allocate_immutable
_block
(
prev_block
=
prev_block
,
token_ids
=
block_token_ids
)
blocks
.
append
(
prev_block
)
...
...
tests/spec_decode/test_batch_expansion.py
View file @
3476ed08
...
...
@@ -90,10 +90,10 @@ def test_create_single_target_seq_group_metadata(k: int):
assert
output
.
request_id
==
input_seq_group_metadata
.
request_id
assert
len
(
output
.
seq_data
)
==
1
assert
output
.
seq_data
[
target_seq_id
].
get_prompt_token_ids
(
)
==
prompt_tokens
assert
output
.
seq_data
[
target_seq_id
].
get_output_token_ids
(
)
==
prev_output_tokens
+
token_ids
assert
output
.
seq_data
[
target_seq_id
].
get_prompt_token_ids
(
)
==
tuple
(
prompt_tokens
)
assert
output
.
seq_data
[
target_seq_id
].
get_output_token_ids
(
)
==
tuple
(
prev_output_tokens
+
token_ids
)
assert
len
(
output
.
block_tables
)
==
1
assert
output
.
block_tables
[
...
...
vllm/core/block/block_table.py
View file @
3476ed08
from
typing
import
List
,
Optional
from
vllm.core.block.common
import
BlockList
from
vllm.core.block.interfaces
import
Block
,
DeviceAwareBlockAllocator
from
vllm.utils
import
Device
,
cdiv
,
chunk_list
...
...
@@ -47,12 +48,10 @@ class BlockTable:
self
.
_allocator
=
block_allocator
if
_blocks
is
None
:
_blocks
=
[]
self
.
_blocks
:
List
[
Block
]
=
_blocks
self
.
_blocks
:
Block
List
=
Block
List
(
_blocks
)
self
.
_max_block_sliding_window
=
max_block_sliding_window
# Use helper method instead of directly calculating, as blocks
# may not be allocated.
self
.
_num_full_slots
=
len
(
self
.
_get_all_token_ids
())
self
.
_num_full_slots
=
self
.
_get_num_token_ids
()
@
staticmethod
def
get_num_required_blocks
(
token_ids
:
List
[
int
],
block_size
:
int
)
->
int
:
...
...
@@ -88,11 +87,18 @@ class BlockTable:
"""
assert
not
self
.
_is_allocated
assert
token_ids
self
.
_blocks
=
self
.
_allocate_blocks_for_token_ids
(
prev_block
=
None
,
token_ids
=
token_ids
,
device
=
device
)
blocks
=
self
.
_allocate_blocks_for_token_ids
(
prev_block
=
None
,
token_ids
=
token_ids
,
device
=
device
)
self
.
update
(
blocks
)
self
.
_num_full_slots
=
len
(
token_ids
)
def
update
(
self
,
blocks
:
List
[
Block
])
->
None
:
"""Resets the table to the newly provided blocks
(with their corresponding block ids)
"""
self
.
_blocks
.
update
(
blocks
)
def
append_token_ids
(
self
,
token_ids
:
List
[
int
],
num_lookahead_slots
:
int
=
0
,
...
...
@@ -140,11 +146,11 @@ class BlockTable:
num_lookahead_slots
)
# Update the blocks with the new tokens
blocks
=
self
.
_blocks
[
self
.
_num_full_slots
//
self
.
_block_size
:]
first_block_idx
=
self
.
_num_full_slots
//
self
.
_block_size
token_blocks
=
self
.
_chunk_token_blocks_for_append
(
token_ids
)
for
block
,
token_block
in
zip
(
blocks
,
token_blocks
):
block
.
append_token_ids
(
token_block
)
for
i
,
token_block
in
enumerate
(
token_blocks
):
self
.
_
block
s
.
append_token_ids
(
first_block_idx
+
i
,
token_block
)
self
.
_num_full_slots
+=
len
(
token_ids
)
...
...
@@ -174,8 +180,8 @@ class BlockTable:
for
_
in
range
(
blocks_to_allocate
):
assert
len
(
self
.
_blocks
)
>
0
self
.
_blocks
.
append
(
self
.
_allocator
.
allocate_mutable
(
prev
_block
=
self
.
_blocks
[
-
1
],
device
=
device
))
self
.
_allocator
.
allocate_mutable_block
(
prev_block
=
self
.
_blocks
[
-
1
],
device
=
device
))
def
fork
(
self
)
->
"BlockTable"
:
"""Creates a new BlockTable instance with a copy of the blocks from the
...
...
@@ -209,12 +215,12 @@ class BlockTable:
is set to `None`.
"""
assert
self
.
_is_allocated
for
block
in
self
.
_
blocks
:
for
block
in
self
.
blocks
:
self
.
_allocator
.
free
(
block
)
self
.
_blocks
=
[]
self
.
_blocks
.
reset
()
@
property
def
physical_block_ids
(
self
)
->
List
[
Optional
[
int
]
]
:
def
physical_block_ids
(
self
)
->
List
[
int
]:
"""Returns a list of physical block indices for the blocks in the
BlockTable.
...
...
@@ -228,7 +234,7 @@ class BlockTable:
BlockTable.
"""
assert
self
.
_is_allocated
return
[
block
.
block_id
for
block
in
self
.
_blocks
]
return
self
.
_blocks
.
ids
()
def
get_unseen_token_ids
(
self
,
sequence_token_ids
:
List
[
int
])
->
List
[
int
]:
"""Get the number of "unseen" tokens in the sequence.
...
...
@@ -253,17 +259,31 @@ class BlockTable:
token_ids
:
List
[
int
],
device
:
Device
)
->
List
[
Block
]:
blocks
:
List
[
Block
]
=
[]
for
block_token_ids
in
chunk_list
(
token_ids
,
self
.
_block_size
):
if
len
(
block_token_ids
)
==
self
.
_block_size
:
# If the block is full, create an immutable block.
prev_block
=
self
.
_allocator
.
allocate_immutable
(
prev_block
,
token_ids
=
block_token_ids
,
device
=
device
)
block_token_ids
=
[]
tail_token_ids
=
[]
for
cur_token_ids
in
chunk_list
(
token_ids
,
self
.
_block_size
):
if
len
(
cur_token_ids
)
==
self
.
_block_size
:
block_token_ids
.
append
(
cur_token_ids
)
else
:
# Else, partially fill a mutable block with token ids.
prev_block
=
self
.
_allocator
.
allocate_mutable
(
prev_block
=
prev_block
,
device
=
device
)
prev_block
.
append_token_ids
(
block_token_ids
)
blocks
.
append
(
prev_block
)
tail_token_ids
.
append
(
cur_token_ids
)
if
block_token_ids
:
blocks
.
extend
(
self
.
_allocator
.
allocate_immutable_blocks
(
prev_block
,
block_token_ids
=
block_token_ids
,
device
=
device
))
prev_block
=
blocks
[
-
1
]
if
tail_token_ids
:
assert
len
(
tail_token_ids
)
==
1
cur_token_ids
=
tail_token_ids
[
0
]
block
=
self
.
_allocator
.
allocate_mutable_block
(
prev_block
=
prev_block
,
device
=
device
)
block
.
append_token_ids
(
cur_token_ids
)
blocks
.
append
(
block
)
return
blocks
...
...
@@ -274,18 +294,25 @@ class BlockTable:
if
not
self
.
_is_allocated
:
return
token_ids
for
block
in
self
.
_
blocks
:
for
block
in
self
.
blocks
:
token_ids
.
extend
(
block
.
token_ids
)
return
token_ids
def
_get_num_token_ids
(
self
)
->
int
:
res
=
0
for
block
in
self
.
blocks
:
res
+=
len
(
block
.
token_ids
)
return
res
@
property
def
_is_allocated
(
self
)
->
bool
:
return
len
(
self
.
_blocks
)
>
0
@
property
def
blocks
(
self
)
->
Optional
[
List
[
Block
]
]
:
return
self
.
_blocks
def
blocks
(
self
)
->
List
[
Block
]:
return
self
.
_blocks
.
list
()
@
property
def
_num_empty_slots
(
self
)
->
int
:
...
...
vllm/core/block/common.py
View file @
3476ed08
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Protocol
,
Tuple
from
collections
import
deque
from
typing
import
Deque
,
Dict
,
Iterable
,
List
,
Optional
,
Protocol
,
Tuple
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
...
...
@@ -95,64 +96,40 @@ class CopyOnWriteTracker:
The CopyOnWriteTracker class maintains a mapping of source block indices to
their corresponding copy-on-write destination block indices. It works in
conjunction with a RefCounter and a BlockAllocator to handle reference
counting and block allocation.
conjunction with a RefCounter.
Args:
refcounter (RefCounter): The reference counter used to track block
reference counts.
allocator (BlockAllocator): The block allocator used to allocate and
free blocks.
"""
def
__init__
(
self
,
refcounter
:
RefCounterProtocol
,
allocator
:
BlockAllocator
,
):
def
__init__
(
self
,
refcounter
:
RefCounterProtocol
):
self
.
_copy_on_writes
:
List
[
Tuple
[
BlockId
,
BlockId
]]
=
[]
self
.
_refcounter
=
refcounter
self
.
_allocator
=
allocator
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
BlockId
]:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
This method checks the reference count of the given block. If the
reference count is greater than 1, indicating that the block is shared,
a copy-on-write operation is performed. The original block is freed,
and a new block is allocated with the same content. The new block index
is returned.
Args:
block (Block): The block to check for copy-on-write.
Returns:
Optional[BlockId]: The block index of the new block if a copy-on
-write operation was performed, or the original block index if
no copy-on-write was necessary.
def
is_appendable
(
self
,
block
:
Block
)
->
bool
:
"""Checks if the block is shared or not. If shared, then it cannot
be appended and needs to be duplicated via copy-on-write
"""
block_id
=
block
.
block_id
if
block_id
is
None
:
return
block_id
return
True
refcount
=
self
.
_refcounter
.
get
(
block_id
)
assert
refcount
!=
0
if
refcount
>
1
:
src_block_id
=
block_id
# Decrement refcount of the old block.
self
.
_allocator
.
free
(
block
)
# Allocate a fresh new block.
block_id
=
self
.
_allocator
.
allocate_mutable
(
prev_block
=
block
.
prev_block
).
block_id
return
refcount
<=
1
# Track src/dst copy.
assert
src_block_id
is
not
None
assert
block_id
is
not
None
self
.
_copy_on_writes
.
append
((
src_block_id
,
block_id
))
return
block_id
def
record_cow
(
self
,
src_block_id
:
Optional
[
BlockId
],
trg_block_id
:
Optional
[
BlockId
])
->
None
:
"""Records a copy-on-write operation from source to target block id
Args:
src_block_id (BlockId): The source block id from which to copy
the data
trg_block_id (BlockId): The target block id to which the data
is copied
"""
assert
src_block_id
is
not
None
assert
trg_block_id
is
not
None
self
.
_copy_on_writes
.
append
((
src_block_id
,
trg_block_id
))
def
clear_cows
(
self
)
->
List
[
Tuple
[
BlockId
,
BlockId
]]:
"""Clears the copy-on-write tracking information and returns the current
...
...
@@ -172,6 +149,139 @@ class CopyOnWriteTracker:
return
cows
class
BlockPool
:
"""Used to pre-allocate block objects, in order to avoid excessive python
object allocations/deallocations.
The pool starts from "pool_size" objects and will increase to more objects
if necessary
Note that multiple block objects may point to the same physical block id,
which is why this pool is needed, so that it will be easier to support
prefix caching and more complicated sharing of physical blocks.
"""
def
__init__
(
self
,
block_size
:
int
,
create_block
:
Block
.
Factory
,
allocator
:
BlockAllocator
,
pool_size
:
int
):
self
.
_block_size
=
block_size
self
.
_create_block
=
create_block
self
.
_allocator
=
allocator
self
.
_pool_size
=
pool_size
assert
self
.
_pool_size
>=
0
self
.
_free_ids
:
Deque
[
int
]
=
deque
(
range
(
self
.
_pool_size
))
self
.
_pool
=
[]
for
i
in
range
(
self
.
_pool_size
):
self
.
_pool
.
append
(
self
.
_create_block
(
prev_block
=
None
,
token_ids
=
[],
block_size
=
self
.
_block_size
,
allocator
=
self
.
_allocator
,
block_id
=
None
))
def
increase_pool
(
self
):
"""Doubles the internal pool size
"""
cur_pool_size
=
self
.
_pool_size
new_pool_size
=
cur_pool_size
*
2
self
.
_pool_size
=
new_pool_size
self
.
_free_ids
+=
deque
(
range
(
cur_pool_size
,
new_pool_size
))
for
i
in
range
(
cur_pool_size
,
new_pool_size
):
self
.
_pool
.
append
(
self
.
_create_block
(
prev_block
=
None
,
token_ids
=
[],
block_size
=
self
.
_block_size
,
allocator
=
self
.
_allocator
,
block_id
=
None
))
def
init_block
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
block_size
:
int
,
physical_block_id
:
Optional
[
int
])
->
Block
:
if
len
(
self
.
_free_ids
)
==
0
:
self
.
increase_pool
()
assert
len
(
self
.
_free_ids
)
>
0
pool_id
=
self
.
_free_ids
.
popleft
()
block
=
self
.
_pool
[
pool_id
]
block
.
__init__
(
# type: ignore[misc]
prev_block
=
prev_block
,
token_ids
=
token_ids
,
block_size
=
block_size
,
allocator
=
block
.
_allocator
,
# type: ignore[attr-defined]
block_id
=
physical_block_id
)
block
.
pool_id
=
pool_id
# type: ignore[attr-defined]
return
block
def
free_block
(
self
,
block
:
Block
)
->
None
:
self
.
_free_ids
.
appendleft
(
block
.
pool_id
)
# type: ignore[attr-defined]
class
BlockList
:
"""This class is an optimization to allow fast-access to physical
block ids. It maintains a block id list that is updated with the
block list and this avoids the need to reconstruct the block id
list on every iteration of the block manager
"""
def
__init__
(
self
,
blocks
:
List
[
Block
]):
self
.
_blocks
:
List
[
Block
]
=
[]
self
.
_block_ids
:
List
[
int
]
=
[]
self
.
update
(
blocks
)
def
_add_block_id
(
self
,
block_id
:
Optional
[
BlockId
])
->
None
:
assert
block_id
is
not
None
self
.
_block_ids
.
append
(
block_id
)
def
_update_block_id
(
self
,
block_index
:
int
,
new_block_id
:
Optional
[
BlockId
])
->
None
:
assert
new_block_id
is
not
None
self
.
_block_ids
[
block_index
]
=
new_block_id
def
update
(
self
,
blocks
:
List
[
Block
]):
self
.
_blocks
=
blocks
# Cache block ids for fast query
self
.
_block_ids
=
[]
for
block
in
self
.
_blocks
:
self
.
_add_block_id
(
block
.
block_id
)
def
append_token_ids
(
self
,
block_index
:
int
,
token_ids
:
List
[
int
])
->
None
:
block
=
self
.
_blocks
[
block_index
]
prev_block_id
=
block
.
block_id
block
.
append_token_ids
(
token_ids
)
# CoW or promotion may update the internal block_id
if
prev_block_id
!=
block
.
block_id
:
self
.
_update_block_id
(
block_index
,
block
.
block_id
)
def
append
(
self
,
new_block
:
Block
):
self
.
_blocks
.
append
(
new_block
)
self
.
_add_block_id
(
new_block
.
block_id
)
def
__len__
(
self
)
->
int
:
return
len
(
self
.
_blocks
)
def
__getitem__
(
self
,
block_index
:
int
)
->
Block
:
return
self
.
_blocks
[
block_index
]
def
__setitem__
(
self
,
block_index
:
int
,
new_block
:
Block
)
->
None
:
self
.
_blocks
[
block_index
]
=
new_block
self
.
_update_block_id
(
block_index
,
new_block
.
block_id
)
def
reset
(
self
):
self
.
_blocks
=
[]
self
.
_block_ids
=
[]
def
list
(
self
)
->
List
[
Block
]:
return
self
.
_blocks
def
ids
(
self
)
->
List
[
int
]:
return
self
.
_block_ids
def
get_all_blocks_recursively
(
last_block
:
Block
)
->
List
[
Block
]:
"""Retrieves all the blocks in a sequence starting from the last block.
...
...
vllm/core/block/cpu_gpu_block_allocator.py
View file @
3476ed08
...
...
@@ -113,11 +113,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
def
allocate_or_get_null_block
(
self
)
->
Block
:
if
self
.
_null_block
is
None
:
self
.
_null_block
=
NullBlock
(
self
.
allocate_mutable
(
None
,
Device
.
GPU
))
self
.
allocate_mutable
_block
(
None
,
Device
.
GPU
))
return
self
.
_null_block
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Device
)
->
Block
:
def
allocate_mutable
_block
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Device
)
->
Block
:
"""Allocates a new mutable block on the specified device.
Args:
...
...
@@ -128,10 +128,31 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Returns:
Block: The newly allocated mutable block.
"""
return
self
.
_allocators
[
device
].
allocate_mutable
(
prev_block
)
return
self
.
_allocators
[
device
].
allocate_mutable
_block
(
prev_block
)
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Device
)
->
Block
:
def
allocate_immutable_blocks
(
self
,
prev_block
:
Optional
[
Block
],
block_token_ids
:
List
[
List
[
int
]],
device
:
Optional
[
Device
])
->
List
[
Block
]:
"""Allocates a new group of immutable blocks with the provided block
token IDs on the specified device.
Args:
prev_block (Optional[Block]): The previous block in the sequence.
Used for prefix hashing.
block_token_ids (List[int]): The list of block token IDs to be
stored in the new blocks.
device (Device): The device on which to allocate the new block.
Returns:
List[Block]: The newly allocated list of immutable blocks
containing the provided block token IDs.
"""
return
self
.
_allocators
[
device
].
allocate_immutable_blocks
(
prev_block
,
block_token_ids
)
def
allocate_immutable_block
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Device
)
->
Block
:
"""Allocates a new immutable block with the provided token IDs on the
specified device.
...
...
@@ -146,7 +167,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Block: The newly allocated immutable block containing the provided
token IDs.
"""
return
self
.
_allocators
[
device
].
allocate_immutable
(
return
self
.
_allocators
[
device
].
allocate_immutable
_block
(
prev_block
,
token_ids
)
def
free
(
self
,
block
:
Block
)
->
None
:
...
...
@@ -161,7 +182,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
block_id
=
block
.
block_id
assert
block_id
is
not
None
allocator
=
self
.
_block_ids_to_allocator
[
block_id
]
return
allocator
.
free
(
block
)
allocator
.
free
(
block
)
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
"""Creates a new sequence of blocks that shares the same underlying
...
...
@@ -210,8 +231,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
"""
return
self
.
_allocators
[
device
].
get_physical_block_id
(
absolute_id
)
def
swap
(
self
,
blocks
:
List
[
Block
],
s
ou
rc
e
_device
:
Device
,
d
e
st_device
:
Device
)
->
Dict
[
int
,
int
]:
def
swap
(
self
,
blocks
:
List
[
Block
],
src_device
:
Device
,
dst_device
:
Device
)
->
Dict
[
int
,
int
]:
"""Execute the swap for the given blocks from source_device
on to dest_device, save the current swap mapping and append
them to the accumulated `self._swap_mapping` for each
...
...
@@ -219,23 +240,23 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args:
blocks: List of blocks to be swapped.
s
ou
rc
e
_device (Device): Device to swap the 'blocks' from.
d
e
st_device (Device): Device to swap the 'blocks' to.
src_device (Device): Device to swap the 'blocks' from.
dst_device (Device): Device to swap the 'blocks' to.
Returns:
Dict[int, int]: Swap mapping from source_device
on to dest_device.
"""
s
ou
rc
e
_block_ids
=
[
block
.
block_id
for
block
in
blocks
]
self
.
_allocators
[
s
ou
rc
e
_device
].
swap_out
(
blocks
)
self
.
_allocators
[
d
e
st_device
].
swap_in
(
blocks
)
d
e
st_block_ids
=
[
block
.
block_id
for
block
in
blocks
]
src_block_ids
=
[
block
.
block_id
for
block
in
blocks
]
self
.
_allocators
[
src_device
].
swap_out
(
blocks
)
self
.
_allocators
[
dst_device
].
swap_in
(
blocks
)
dst_block_ids
=
[
block
.
block_id
for
block
in
blocks
]
current_swap_mapping
:
Dict
[
int
,
int
]
=
{}
for
src
,
d
e
st
in
zip
(
s
ou
rc
e
_block_ids
,
d
e
st_block_ids
):
if
src
is
not
None
and
d
e
st
is
not
None
:
self
.
_swap_mapping
[
src
]
=
d
e
st
current_swap_mapping
[
src
]
=
d
e
st
for
src
_block_id
,
dst
_block_id
in
zip
(
src_block_ids
,
dst_block_ids
):
if
src
_block_id
is
not
None
and
dst
_block_id
is
not
None
:
self
.
_swap_mapping
[
src
_block_id
]
=
dst
_block_id
current_swap_mapping
[
src
_block_id
]
=
dst
_block_id
return
current_swap_mapping
def
get_num_blocks_touched
(
self
,
...
...
@@ -283,23 +304,25 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
device
=
Device
.
GPU
return
self
.
_allocators
[
device
].
mark_blocks_as_computed
(
block_ids
)
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
)
->
List
[
int
]:
# Prefix caching only supported on GPU.
device
=
Device
.
GPU
return
self
.
_allocators
[
device
].
get_computed_block_ids
(
prev_computed_block_ids
,
block_ids
,
skip_last_block_id
)
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
computed_
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
# Prefix caching only supported on GPU.
device
=
Device
.
GPU
return
self
.
_allocators
[
device
].
get_common_computed_block_ids
(
seq_block_ids
)
computed_
seq_block_ids
)
@
property
def
all_block_ids
(
self
)
->
FrozenSet
[
int
]:
return
frozenset
(
self
.
_block_ids_to_allocator
.
keys
())
def
promote_to_immutable_block
(
self
,
block
:
Block
)
->
BlockId
:
raise
NotImplementedError
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
BlockId
]:
raise
NotImplementedError
def
get_and_reset_swaps
(
self
)
->
List
[
Tuple
[
int
,
int
]]:
"""Returns and clears the mapping of source to destination block IDs.
Will be called after every swapping operations for now, and after every
...
...
@@ -341,6 +364,11 @@ class NullBlock(Block):
def
token_ids
(
self
)
->
List
[
BlockId
]:
return
self
.
_proxy
.
token_ids
@
property
def
num_tokens_total
(
self
)
->
int
:
raise
NotImplementedError
(
"num_tokens_total is not used for null block"
)
@
property
def
num_empty_slots
(
self
)
->
BlockId
:
return
self
.
_proxy
.
num_empty_slots
...
...
vllm/core/block/interfaces.py
View file @
3476ed08
...
...
@@ -28,6 +28,13 @@ class Block(ABC):
def
token_ids
(
self
)
->
List
[
int
]:
pass
@
property
@
abstractmethod
def
num_tokens_total
(
self
)
->
int
:
"""The number of tokens till the current block (inclusive)
"""
pass
@
property
@
abstractmethod
def
num_empty_slots
(
self
)
->
int
:
...
...
@@ -92,12 +99,18 @@ class Block(ABC):
class
BlockAllocator
(
ABC
):
@
abstractmethod
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
])
->
Block
:
def
allocate_mutable
_block
(
self
,
prev_block
:
Optional
[
Block
])
->
Block
:
pass
@
abstractmethod
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
])
->
Block
:
def
allocate_immutable_block
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
])
->
Block
:
pass
@
abstractmethod
def
allocate_immutable_blocks
(
self
,
prev_block
:
Optional
[
Block
],
block_token_ids
:
List
[
List
[
int
]])
->
List
[
Block
]:
pass
@
abstractmethod
...
...
@@ -146,13 +159,19 @@ class BlockAllocator(ABC):
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
pass
@
abstractmethod
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
)
->
List
[
int
]:
pass
@
abstractmethod
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
computed_
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
pass
@
abstractmethod
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
"
BlockId
"
]
:
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
BlockId
:
"""NOTE: This should not be used besides Block"""
pass
...
...
@@ -174,13 +193,20 @@ class BlockAllocator(ABC):
class
DeviceAwareBlockAllocator
(
ABC
):
@
abstractmethod
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Device
)
->
Block
:
def
allocate_mutable_block
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Device
)
->
Block
:
pass
@
abstractmethod
def
allocate_immutable_block
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Device
)
->
Block
:
pass
@
abstractmethod
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Device
)
->
Block
:
def
allocate_immutable_blocks
(
self
,
prev_block
:
Optional
[
Block
],
block_token_ids
:
List
[
List
[
int
]],
device
:
Device
)
->
List
[
Block
]:
pass
@
abstractmethod
...
...
@@ -217,9 +243,15 @@ class DeviceAwareBlockAllocator(ABC):
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
pass
@
abstractmethod
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
)
->
List
[
int
]:
pass
@
abstractmethod
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
computed_
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
pass
@
abstractmethod
...
...
@@ -230,8 +262,8 @@ class DeviceAwareBlockAllocator(ABC):
pass
@
abstractmethod
def
swap
(
self
,
blocks
:
List
[
Block
],
s
ou
rc
e
_device
:
Device
,
d
e
st_device
:
Device
)
->
Dict
[
int
,
int
]:
def
swap
(
self
,
blocks
:
List
[
Block
],
src_device
:
Device
,
dst_device
:
Device
)
->
Dict
[
int
,
int
]:
pass
@
abstractmethod
...
...
vllm/core/block/naive_block.py
View file @
3476ed08
from
typing
import
FrozenSet
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
collections
import
deque
from
typing
import
Deque
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Tuple
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
RefCounter
,
from
vllm.core.block.common
import
(
BlockPool
,
CopyOnWriteTracker
,
RefCounter
,
get_all_blocks_recursively
)
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
,
BlockId
,
Device
from
vllm.utils
import
cdiv
...
...
@@ -31,28 +32,39 @@ class NaiveBlockAllocator(BlockAllocator):
num_blocks
:
int
,
block_size
:
int
,
block_ids
:
Optional
[
Iterable
[
int
]]
=
None
,
block_pool
:
Optional
[
BlockPool
]
=
None
,
):
if
block_ids
is
None
:
block_ids
=
range
(
num_blocks
)
self
.
_free_block_indices
:
Set
[
BlockId
]
=
set
(
block_ids
)
self
.
_free_block_indices
:
Deque
[
BlockId
]
=
deque
(
block_ids
)
self
.
_all_block_indices
=
frozenset
(
block_ids
)
assert
len
(
self
.
_all_block_indices
)
==
num_blocks
self
.
_refcounter
=
RefCounter
(
all_block_indices
=
self
.
_free_block_indices
)
self
.
_create_block
=
create_block
self
.
_block_size
=
block_size
self
.
_cow_tracker
=
CopyOnWriteTracker
(
refcounter
=
self
.
_refcounter
.
as_readonly
(),
allocator
=
self
,
)
def
allocate_immutable
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
refcounter
=
self
.
_refcounter
.
as_readonly
())
if
block_pool
is
None
:
extra_factor
=
4
# Pre-allocate "num_blocks * extra_factor" block objects.
# The "* extra_factor" is a buffer to allow more block objects
# than physical blocks
self
.
_block_pool
=
BlockPool
(
self
.
_block_size
,
create_block
,
self
,
num_blocks
*
extra_factor
)
else
:
# In this case, the block pool is provided by the caller,
# which means that there is most likely a need to share
# a block pool between allocators
self
.
_block_pool
=
block_pool
def
allocate_immutable_block
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
"""Allocates a new immutable block with the given token IDs, linked to
the previous block.
...
...
@@ -66,13 +78,36 @@ class NaiveBlockAllocator(BlockAllocator):
Block: The newly allocated immutable block.
"""
assert
device
is
None
block
=
self
.
allocate_mutable
(
prev_block
=
prev_block
)
block
=
self
.
allocate_mutable
_block
(
prev_block
=
prev_block
)
block
.
append_token_ids
(
token_ids
)
return
block
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
def
allocate_immutable_blocks
(
self
,
prev_block
:
Optional
[
Block
],
block_token_ids
:
List
[
List
[
int
]],
device
:
Optional
[
Device
]
=
None
)
->
List
[
Block
]:
assert
device
is
None
num_blocks
=
len
(
block_token_ids
)
block_ids
=
[]
for
i
in
range
(
num_blocks
):
block_ids
.
append
(
self
.
_allocate_block_id
())
blocks
=
[]
for
i
in
range
(
num_blocks
):
prev_block
=
self
.
_block_pool
.
init_block
(
prev_block
=
prev_block
,
token_ids
=
block_token_ids
[
i
],
block_size
=
self
.
_block_size
,
physical_block_id
=
block_ids
[
i
])
blocks
.
append
(
prev_block
)
return
blocks
def
allocate_mutable_block
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Optional
[
Device
]
=
None
)
->
Block
:
"""Allocates a new mutable block, linked to the previous block.
Args:
...
...
@@ -84,20 +119,39 @@ class NaiveBlockAllocator(BlockAllocator):
Block: The newly allocated mutable block.
"""
assert
device
is
None
block_id
=
self
.
_allocate_new_block_id
()
return
self
.
_create_block
(
prev_block
=
prev_block
,
token_ids
=
[],
block_id
=
block_id
,
block_size
=
self
.
_block_size
,
allocator
=
self
,
)
def
free
(
self
,
block
:
Block
)
->
None
:
assert
block
.
block_id
is
not
None
self
.
_free_block_id
(
block
.
block_id
)
block_id
=
self
.
_allocate_block_id
()
block
=
self
.
_block_pool
.
init_block
(
prev_block
=
prev_block
,
token_ids
=
[],
block_size
=
self
.
_block_size
,
physical_block_id
=
block_id
)
return
block
def
_allocate_block_id
(
self
)
->
BlockId
:
if
not
self
.
_free_block_indices
:
raise
BlockAllocator
.
NoFreeBlocksError
()
block_id
=
self
.
_free_block_indices
.
popleft
()
self
.
_refcounter
.
incr
(
block_id
)
return
block_id
def
_free_block_id
(
self
,
block
:
Block
)
->
None
:
block_id
=
block
.
block_id
assert
block_id
is
not
None
refcount
=
self
.
_refcounter
.
decr
(
block_id
)
if
refcount
==
0
:
self
.
_free_block_indices
.
appendleft
(
block_id
)
block
.
block_id
=
None
def
free
(
self
,
block
:
Block
,
keep_block_object
:
bool
=
False
)
->
None
:
# Release the physical block id
self
.
_free_block_id
(
block
)
# Release the block object
if
not
keep_block_object
:
self
.
_block_pool
.
free_block
(
block
)
def
fork
(
self
,
last_block
:
Block
)
->
List
[
Block
]:
"""Creates a new sequence of blocks that shares the same underlying
memory as the original sequence.
...
...
@@ -120,14 +174,13 @@ class NaiveBlockAllocator(BlockAllocator):
refcount
=
self
.
_refcounter
.
incr
(
block
.
block_id
)
assert
refcount
!=
1
,
"can't fork free'd block"
forked_blocks
.
append
(
self
.
_create_block
(
prev_block
=
prev_block
,
token_ids
=
block
.
token_ids
,
block_id
=
block
.
block_id
,
block_size
=
self
.
_block_size
,
allocator
=
self
,
))
forked_block
=
self
.
_block_pool
.
init_block
(
prev_block
=
prev_block
,
token_ids
=
block
.
token_ids
,
block_size
=
self
.
_block_size
,
physical_block_id
=
block
.
block_id
)
forked_blocks
.
append
(
forked_block
)
prev_block
=
forked_blocks
[
-
1
]
return
forked_blocks
...
...
@@ -138,20 +191,6 @@ class NaiveBlockAllocator(BlockAllocator):
def
get_num_total_blocks
(
self
)
->
int
:
return
len
(
self
.
_all_block_indices
)
def
_allocate_new_block_id
(
self
)
->
BlockId
:
if
not
self
.
_free_block_indices
:
raise
BlockAllocator
.
NoFreeBlocksError
()
block_id
=
next
(
iter
(
self
.
_free_block_indices
))
self
.
_refcounter
.
incr
(
block_id
)
self
.
_free_block_indices
.
remove
(
block_id
)
return
block_id
def
_free_block_id
(
self
,
block_id
:
BlockId
)
->
None
:
refcount
=
self
.
_refcounter
.
decr
(
block_id
)
if
refcount
==
0
:
self
.
_free_block_indices
.
add
(
block_id
)
def
get_physical_block_id
(
self
,
absolute_id
:
int
)
->
int
:
"""Returns the zero-offset block id on certain block allocator
given the absolute block id.
...
...
@@ -173,7 +212,7 @@ class NaiveBlockAllocator(BlockAllocator):
def
all_block_ids
(
self
)
->
FrozenSet
[
int
]:
return
self
.
_all_block_indices
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
BlockId
]
:
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
BlockId
:
"""Performs a copy-on-write operation on the given block if it is not
appendable.
...
...
@@ -181,11 +220,22 @@ class NaiveBlockAllocator(BlockAllocator):
block (Block): The block to check for copy-on-write.
Returns:
Optional[
BlockId
]
: The block index of the new block if a copy-on
-write
operation was performed, or the original block index if
BlockId: The block index of the new block if a copy-on
-write
operation was performed, or the original block index if
no copy-on-write was necessary.
"""
return
self
.
_cow_tracker
.
cow_block_if_not_appendable
(
block
)
src_block_id
=
block
.
block_id
assert
src_block_id
is
not
None
if
self
.
_cow_tracker
.
is_appendable
(
block
):
return
src_block_id
self
.
_free_block_id
(
block
)
trg_block_id
=
self
.
_allocate_block_id
()
self
.
_cow_tracker
.
record_cow
(
src_block_id
,
trg_block_id
)
return
trg_block_id
def
clear_copy_on_writes
(
self
)
->
List
[
Tuple
[
BlockId
,
BlockId
]]:
"""Returns the copy-on-write source->destination mapping and clears it.
...
...
@@ -213,8 +263,15 @@ class NaiveBlockAllocator(BlockAllocator):
"""
pass
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
)
->
List
[
int
]:
"""No prefix caching here => return empty list
"""
return
[]
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
computed_
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
"""Determine blocks that can be skipped in prefill.
Since the naive allocator does not support prefix caching, always return
...
...
@@ -223,7 +280,7 @@ class NaiveBlockAllocator(BlockAllocator):
return
[]
def
promote_to_immutable_block
(
self
,
block
:
Block
)
->
BlockId
:
raise
NotImplementedError
raise
NotImplementedError
(
"There is no promotion for naive blocks"
)
def
get_num_blocks_touched
(
self
,
blocks
:
List
[
Block
],
...
...
@@ -263,17 +320,27 @@ class NaiveBlockAllocator(BlockAllocator):
def
swap_out
(
self
,
blocks
:
List
[
Block
])
->
None
:
for
block
in
blocks
:
self
.
free
(
block
)
self
.
_
free
_block_id
(
block
)
def
swap_in
(
self
,
blocks
:
List
[
Block
])
->
None
:
for
block
in
blocks
:
# Here we allocate either immutable or mutable block and then
# extract its block_id. Note that the block object is released
# and the block_id is assigned to "block" to allow reusing the
# existing "block" object
if
block
.
is_full
:
al
loc
=
self
.
allocate_immutable
(
block
.
prev
_block
,
block
.
token_ids
)
tmp_b
loc
k
=
self
.
allocate_immutable_block
(
prev_block
=
block
.
prev_block
,
token_ids
=
block
.
token_ids
)
else
:
alloc
=
self
.
allocate_mutable
(
block
.
prev_block
)
alloc
.
append_token_ids
(
block
.
token_ids
)
block
.
block_id
=
alloc
.
block_id
tmp_block
=
self
.
allocate_mutable_block
(
prev_block
=
block
.
prev_block
)
tmp_block
.
append_token_ids
(
block
.
token_ids
)
block_id
=
tmp_block
.
block_id
tmp_block
.
block_id
=
None
self
.
_block_pool
.
free_block
(
tmp_block
)
block
.
block_id
=
block_id
# Assign block_id
class
NaiveBlock
(
Block
):
...
...
@@ -315,11 +382,12 @@ class NaiveBlock(Block):
self
.
_append_token_ids_no_cow
(
token_ids
)
def
append_token_ids
(
self
,
token_ids
:
List
[
int
])
->
None
:
"""Appends the given token IDs to the block
, instructing the allocator
to perform a
copy-on-write if necessary.
"""Appends the given token IDs to the block
and performs a
copy-on-write if necessary.
Args:
token_ids (List[int]): The token IDs to be appended to the block.
token_ids (Optional[List[int]]): The token IDs to be appended
to the block.
"""
self
.
_append_token_ids_no_cow
(
token_ids
)
...
...
@@ -328,7 +396,16 @@ class NaiveBlock(Block):
self
.
_cow_target
))
def
_append_token_ids_no_cow
(
self
,
token_ids
:
List
[
int
])
->
None
:
assert
self
.
num_empty_slots
>=
len
(
token_ids
)
"""Appends the given token IDs to the block
Args:
token_ids (List[int]): The token IDs to be appended to the block.
"""
if
len
(
token_ids
)
==
0
:
return
assert
len
(
token_ids
)
<=
self
.
num_empty_slots
self
.
_token_ids
.
extend
(
token_ids
)
@
property
...
...
@@ -361,12 +438,17 @@ class NaiveBlock(Block):
@
property
def
num_empty_slots
(
self
)
->
int
:
return
self
.
_block_size
-
len
(
self
.
_
token_ids
)
return
self
.
_block_size
-
len
(
self
.
token_ids
)
@
property
def
token_ids
(
self
)
->
List
[
int
]:
return
self
.
_token_ids
@
property
def
num_tokens_total
(
self
)
->
int
:
raise
NotImplementedError
(
"num_tokens_total is not used for naive block"
)
@
property
def
block_size
(
self
)
->
int
:
return
self
.
_block_size
...
...
vllm/core/block/prefix_caching_block.py
View file @
3476ed08
This diff is collapsed.
Click to expand it.
vllm/core/block_manager_v2.py
View file @
3476ed08
...
...
@@ -7,6 +7,8 @@ from typing import Tuple
from
vllm.core.block.block_table
import
BlockTable
from
vllm.core.block.cpu_gpu_block_allocator
import
CpuGpuBlockAllocator
from
vllm.core.block.interfaces
import
Block
from
vllm.core.block.prefix_caching_block
import
(
ComputedBlocksTracker
,
LastAccessBlocksTracker
)
from
vllm.core.block.utils
import
check_no_caching_or_swa_for_blockmgr_encdec
from
vllm.core.interfaces
import
AllocStatus
,
BlockSpaceManager
from
vllm.sequence
import
Sequence
,
SequenceGroup
,
SequenceStatus
...
...
@@ -100,6 +102,11 @@ class BlockSpaceManagerV2(BlockSpaceManager):
self
.
block_tables
:
Dict
[
SeqId
,
BlockTable
]
=
{}
self
.
cross_block_tables
:
Dict
[
EncoderSeqId
,
BlockTable
]
=
{}
self
.
_computed_blocks_tracker
=
ComputedBlocksTracker
(
self
.
block_allocator
)
self
.
_last_access_blocks_tracker
=
LastAccessBlocksTracker
(
self
.
block_allocator
)
def
can_allocate
(
self
,
seq_group
:
SequenceGroup
)
->
AllocStatus
:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
...
...
@@ -157,10 +164,18 @@ class BlockSpaceManagerV2(BlockSpaceManager):
block_table
:
BlockTable
=
self
.
_allocate_sequence
(
seq
)
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
# Track seq
self
.
_computed_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
self
.
_last_access_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
# Assign the block table for each sequence.
for
seq
in
waiting_seqs
[
1
:]:
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
fork
()
# Track seq
self
.
_computed_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
self
.
_last_access_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
# Allocate cross-attention block table for encoder sequence
#
# NOTE: Here we assume that all sequences in the group have the same
...
...
@@ -224,11 +239,23 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return
new_cows
def
free
(
self
,
seq
:
Sequence
)
->
None
:
if
seq
.
seq_id
not
in
self
.
block_tables
:
seq_id
=
seq
.
seq_id
if
seq_id
not
in
self
.
block_tables
:
# Already freed or haven't been scheduled yet.
return
self
.
block_tables
[
seq
.
seq_id
].
free
()
del
self
.
block_tables
[
seq
.
seq_id
]
# Update seq block ids with the latest access time
self
.
_last_access_blocks_tracker
.
update_seq_blocks_last_access
(
seq_id
,
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
)
# Untrack seq
self
.
_last_access_blocks_tracker
.
remove_seq
(
seq_id
)
self
.
_computed_blocks_tracker
.
remove_seq
(
seq_id
)
# Free table/blocks
self
.
block_tables
[
seq_id
].
free
()
del
self
.
block_tables
[
seq_id
]
def
free_cross
(
self
,
seq_group
:
SequenceGroup
)
->
None
:
request_id
=
seq_group
.
request_id
...
...
@@ -239,9 +266,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
del
self
.
cross_block_tables
[
request_id
]
def
get_block_table
(
self
,
seq
:
Sequence
)
->
List
[
int
]:
assert
seq
.
seq_id
in
self
.
block_tables
block_ids
=
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
assert
all
(
b
is
not
None
for
b
in
block_ids
)
return
block_ids
# type: ignore
def
get_cross_block_table
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
int
]:
...
...
@@ -252,20 +277,14 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return
block_ids
# type: ignore
def
access_all_blocks_in_seq
(
self
,
seq
:
Sequence
,
now
:
float
):
# Update the last accessed time of all the blocks accessed
# in this step.
# And the accessed time is only useful for prefix caching now,
# as it support internal evictor policy for which cached
# block could be refilled, to keep cached content could be reused
# at max extend.
if
self
.
enable_caching
:
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
block
_
ids
:
List
[
Optional
[
int
]]
=
[]
for
block
_
id
in
block_table
.
physical_block_ids
:
block_ids
.
append
(
block_i
d
)
self
.
block_allocator
.
mark_blocks_as_accessed
(
block_ids
,
# type: ignore
now
)
# Record the latest access time for the sequence. The actual update
# of the
block
ids
is deferred to the sequence free(..) call, since
# only during freeing of
block
id
s, the blocks are actually added to
# the evictor (which is when the most updated time is require
d)
# (This avoids expensive calls to
mark_blocks_as_accessed(
..))
self
.
_last_access_blocks_tracker
.
update_last_access
(
seq
.
seq_id
,
now
)
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
):
# The only need for mark block as computed is for prefix caching,
...
...
@@ -285,17 +304,26 @@ class BlockSpaceManagerV2(BlockSpaceManager):
This method determines which blocks can be safely skipped for all
sequences in the sequence group.
"""
seq_block_ids
=
[
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
for
seq
in
seqs
]
computed_seq_block_ids
=
[]
for
seq
in
seqs
:
computed_seq_block_ids
.
append
(
self
.
_computed_blocks_tracker
.
get_cached_computed_blocks_and_update
(
seq
.
seq_id
,
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
))
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return
self
.
block_allocator
.
get_common_computed_block_ids
(
seq_block_ids
)
# type: ignore
computed_
seq_block_ids
)
# type: ignore
def
fork
(
self
,
parent_seq
:
Sequence
,
child_seq
:
Sequence
)
->
None
:
src_block_table
=
self
.
block_tables
[
parent_seq
.
seq_id
]
self
.
block_tables
[
child_seq
.
seq_id
]
=
src_block_table
.
fork
()
# Track child seq
self
.
_computed_blocks_tracker
.
add_seq
(
child_seq
.
seq_id
)
self
.
_last_access_blocks_tracker
.
add_seq
(
child_seq
.
seq_id
)
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
AllocStatus
:
"""Returns the AllocStatus for the given sequence_group
...
...
@@ -323,19 +351,31 @@ class BlockSpaceManagerV2(BlockSpaceManager):
List[Tuple[int, int]]: The mapping of swapping block from CPU
to GPU.
"""
blocks
=
self
.
_get_blocks_for_swap
(
seq_group
,
SequenceStatus
.
SWAPPED
)
current_swap_mapping
=
self
.
block_allocator
.
swap
(
blocks
=
blocks
,
source_device
=
Device
.
CPU
,
dest_device
=
Device
.
GPU
)
block_number_mapping
=
{
self
.
block_allocator
.
get_physical_block_id
(
Device
.
CPU
,
cpu_block_id
):
self
.
block_allocator
.
get_physical_block_id
(
Device
.
GPU
,
gpu_block_id
)
for
cpu_block_id
,
gpu_block_id
in
current_swap_mapping
.
items
()
}
# convert to list of tuples once here
return
list
(
block_number_mapping
.
items
())
physical_block_id_mapping
=
[]
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
SWAPPED
):
blocks
=
self
.
block_tables
[
seq
.
seq_id
].
blocks
if
len
(
blocks
)
==
0
:
continue
seq_swap_mapping
=
self
.
block_allocator
.
swap
(
blocks
=
blocks
,
src_device
=
Device
.
CPU
,
dst_device
=
Device
.
GPU
)
# Refresh the block ids of the table (post-swap)
self
.
block_tables
[
seq
.
seq_id
].
update
(
blocks
)
seq_physical_block_id_mapping
=
{
self
.
block_allocator
.
get_physical_block_id
(
Device
.
CPU
,
cpu_block_id
):
self
.
block_allocator
.
get_physical_block_id
(
Device
.
GPU
,
gpu_block_id
)
for
cpu_block_id
,
gpu_block_id
in
seq_swap_mapping
.
items
()
}
physical_block_id_mapping
.
extend
(
list
(
seq_physical_block_id_mapping
.
items
()))
return
physical_block_id_mapping
def
can_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
"""Returns whether we can swap out the given sequence_group
...
...
@@ -355,7 +395,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return
True
return
False
def
swap_out
(
self
,
seq
uence
_group
:
SequenceGroup
)
->
List
[
Tuple
[
int
,
int
]]:
def
swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
Tuple
[
int
,
int
]]:
"""Returns the block id mapping (from GPU to CPU) generated by
swapping out the given sequence_group with num_lookahead_slots.
...
...
@@ -366,19 +406,31 @@ class BlockSpaceManagerV2(BlockSpaceManager):
List[Tuple[int, int]]: The mapping of swapping block from
GPU to CPU.
"""
blocks
=
self
.
_get_blocks_for_swap
(
sequence_group
,
SequenceStatus
.
RUNNING
)
current_swap_mapping
=
self
.
block_allocator
.
swap
(
blocks
=
blocks
,
source_device
=
Device
.
GPU
,
dest_device
=
Device
.
CPU
)
block_number_mapping
=
{
self
.
block_allocator
.
get_physical_block_id
(
Device
.
GPU
,
gpu_block_id
):
self
.
block_allocator
.
get_physical_block_id
(
Device
.
CPU
,
cpu_block_id
)
for
gpu_block_id
,
cpu_block_id
in
current_swap_mapping
.
items
()
}
# convert to list of tuples once here
return
list
(
block_number_mapping
.
items
())
physical_block_id_mapping
=
[]
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
blocks
=
self
.
block_tables
[
seq
.
seq_id
].
blocks
if
len
(
blocks
)
==
0
:
continue
seq_swap_mapping
=
self
.
block_allocator
.
swap
(
blocks
=
blocks
,
src_device
=
Device
.
GPU
,
dst_device
=
Device
.
CPU
)
# Refresh the block ids of the table (post-swap)
self
.
block_tables
[
seq
.
seq_id
].
update
(
blocks
)
seq_physical_block_id_mapping
=
{
self
.
block_allocator
.
get_physical_block_id
(
Device
.
GPU
,
gpu_block_id
):
self
.
block_allocator
.
get_physical_block_id
(
Device
.
CPU
,
cpu_block_id
)
for
gpu_block_id
,
cpu_block_id
in
seq_swap_mapping
.
items
()
}
physical_block_id_mapping
.
extend
(
list
(
seq_physical_block_id_mapping
.
items
()))
return
physical_block_id_mapping
def
get_num_free_gpu_blocks
(
self
)
->
int
:
return
self
.
block_allocator
.
get_num_free_blocks
(
Device
.
GPU
)
...
...
vllm/engine/llm_engine.py
View file @
3476ed08
...
...
@@ -177,7 +177,8 @@ class LLMEngine:
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s)"
,
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"enable_prefix_caching=%s)"
,
VLLM_VERSION
,
model_config
.
model
,
speculative_config
,
...
...
@@ -204,6 +205,8 @@ class LLMEngine:
observability_config
,
model_config
.
seed
,
model_config
.
served_model_name
,
scheduler_config
.
use_v2_block_manager
,
cache_config
.
enable_prefix_caching
,
)
# TODO(woosuk): Print more configs in debug mode.
...
...
vllm/entrypoints/openai/serving_completion.py
View file @
3476ed08
...
...
@@ -345,7 +345,7 @@ class OpenAIServingCompletion(OpenAIServing):
out_logprobs
=
prompt_logprobs
output_text
=
prompt_text
elif
request
.
echo
and
request
.
max_tokens
>
0
:
token_ids
=
prompt_token_ids
+
output
.
token_ids
token_ids
=
prompt_token_ids
+
list
(
output
.
token_ids
)
out_logprobs
=
(
prompt_logprobs
+
output
.
logprobs
if
request
.
logprobs
is
not
None
else
None
)
output_text
=
prompt_text
+
output
.
text
...
...
vllm/model_executor/sampling_metadata.py
View file @
3476ed08
...
...
@@ -427,8 +427,8 @@ class SamplingTensors:
if
seq_group
.
do_sample
:
for
seq_id
in
seq_ids
:
seq_data
=
seq_group
.
seq_data
[
seq_id
]
prompt_tokens
.
append
(
seq_data
.
prompt_token_ids
)
output_tokens
.
append
(
seq_data
.
output_token_ids
)
prompt_tokens
.
append
(
list
(
seq_data
.
prompt_token_ids
)
)
output_tokens
.
append
(
list
(
seq_data
.
output_token_ids
)
)
sampling_tensors
=
SamplingTensors
.
from_lists
(
temperatures
,
top_ps
,
top_ks
,
min_ps
,
presence_penalties
,
...
...
vllm/outputs.py
View file @
3476ed08
import
time
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
(
PromptLogprobs
,
RequestMetrics
,
SampleLogprobs
,
...
...
@@ -28,7 +28,7 @@ class CompletionOutput:
index
:
int
text
:
str
token_ids
:
List
[
int
]
token_ids
:
Tuple
[
int
,
...
]
cumulative_logprob
:
float
logprobs
:
Optional
[
SampleLogprobs
]
finish_reason
:
Optional
[
str
]
=
None
...
...
vllm/sequence.py
View file @
3476ed08
...
...
@@ -116,41 +116,66 @@ class SequenceData:
prompt_token_ids
:
List
[
int
],
output_token_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
None
:
if
output_token_ids
is
None
:
output_token_ids
=
[]
self
.
_prompt_token_ids
:
List
[
int
]
=
list
(
prompt_token_ids
)
self
.
_prompt_token_ids_tuple
:
Tuple
[
int
,
...]
=
tuple
(
prompt_token_ids
)
self
.
_output_token_ids
:
List
[
int
]
=
(
list
(
output_token_ids
)
if
output_token_ids
is
not
None
else
[])
self
.
prompt_token_ids
=
prompt_token_ids
self
.
_prompt_token_ids_tuple
=
tuple
(
prompt_token_ids
)
self
.
output_token_ids
=
output_token_ids
self
.
cumulative_logprob
=
0.0
# The number of tokens that are computed (that run against the model).
self
.
_num_computed_tokens
=
0
self
.
_stage
:
SequenceStage
=
SequenceStage
.
PREFILL
self
.
_update_cached_all_tokens
()
def
_update_cached_all_tokens
(
self
):
self
.
_cached_all_token_ids
:
List
[
int
]
=
(
self
.
_prompt_token_ids
+
self
.
_output_token_ids
)
@
property
def
prompt_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
_prompt_token_ids_tuple
@
prompt_token_ids
.
setter
def
prompt_token_ids
(
self
,
new_prompt_token_ids
)
->
None
:
self
.
_prompt_token_ids
=
list
(
new_prompt_token_ids
)
self
.
_prompt_token_ids_tuple
=
tuple
(
new_prompt_token_ids
)
self
.
_update_cached_all_tokens
()
@
property
def
output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
tuple
(
self
.
_output_token_ids
)
@
output_token_ids
.
setter
def
output_token_ids
(
self
,
new_output_token_ids
)
->
None
:
self
.
_output_token_ids
=
list
(
new_output_token_ids
)
self
.
_update_cached_all_tokens
()
def
append_token_id
(
self
,
token_id
:
int
,
logprob
:
float
)
->
None
:
self
.
output_token_ids
.
append
(
token_id
)
self
.
_output_token_ids
.
append
(
token_id
)
self
.
_cached_all_token_ids
.
append
(
token_id
)
self
.
cumulative_logprob
+=
logprob
def
get_len
(
self
)
->
int
:
return
len
(
self
.
output_token_ids
)
+
len
(
self
.
prompt_token_ids
)
return
len
(
self
.
_
output_token_ids
)
+
len
(
self
.
_
prompt_token_ids
)
def
get_prompt_len
(
self
)
->
int
:
return
len
(
self
.
prompt_token_ids
)
return
len
(
self
.
_
prompt_token_ids
)
def
get_output_len
(
self
)
->
int
:
return
len
(
self
.
output_token_ids
)
return
len
(
self
.
_
output_token_ids
)
def
get_token_ids
(
self
)
->
List
[
int
]:
return
self
.
prompt_token_ids
+
self
.
output
_token_ids
return
self
.
_cached_all
_token_ids
def
get_prefix_token_ids
(
self
,
num_tokens
:
int
)
->
Tuple
[
Tuple
[
int
,
...],
Optional
[
Tuple
[
int
,
...]]]:
"""Get prefix tokens, and make the return value hashable"""
prompt_length
=
len
(
self
.
prompt_
token_ids
)
prompt_length
=
self
.
get_
prompt_
len
(
)
if
num_tokens
>
prompt_length
:
return
(
self
.
_prompt_token_ids_tuple
,
tuple
(
self
.
output_token_ids
[:
num_tokens
-
prompt_length
]))
tuple
(
self
.
_
output_token_ids
[:
num_tokens
-
prompt_length
]))
else
:
return
(
self
.
_prompt_token_ids_tuple
[:
num_tokens
],
None
)
...
...
@@ -183,14 +208,14 @@ class SequenceData:
return
self
.
get_len
()
-
self
.
get_num_computed_tokens
()
def
get_last_token_id
(
self
)
->
int
:
if
not
self
.
output_token_ids
:
return
self
.
prompt_token_ids
[
-
1
]
return
self
.
output_token_ids
[
-
1
]
if
not
self
.
_
output_token_ids
:
return
self
.
_
prompt_token_ids
[
-
1
]
return
self
.
_
output_token_ids
[
-
1
]
def
get_prompt_token_ids
(
self
)
->
List
[
int
]:
def
get_prompt_token_ids
(
self
)
->
Tuple
[
int
,
...
]:
return
self
.
prompt_token_ids
def
get_output_token_ids
(
self
)
->
List
[
int
]:
def
get_output_token_ids
(
self
)
->
Tuple
[
int
,
...
]:
return
self
.
output_token_ids
@
property
...
...
@@ -199,8 +224,8 @@ class SequenceData:
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceData("
f
"prompt_token_ids=
{
self
.
prompt_token_ids
}
, "
f
"output_token_ids=
{
self
.
output_token_ids
}
, "
f
"prompt_token_ids=
{
self
.
_
prompt_token_ids
}
, "
f
"output_token_ids=
{
self
.
_
output_token_ids
}
, "
f
"cumulative_logprob=
{
self
.
cumulative_logprob
}
)"
)
...
...
@@ -306,14 +331,14 @@ class Sequence:
def
get_token_ids
(
self
)
->
List
[
int
]:
return
self
.
data
.
get_token_ids
()
def
get_prompt_token_ids
(
self
)
->
List
[
int
]:
def
get_prompt_token_ids
(
self
)
->
Tuple
[
int
,
...
]:
return
self
.
data
.
get_prompt_token_ids
()
def
get_last_token_id
(
self
)
->
int
:
return
self
.
data
.
get_last_token_id
()
def
get_output_token_ids
(
self
)
->
List
[
int
]:
return
self
.
data
.
output_token_ids
def
get_output_token_ids
(
self
)
->
Tuple
[
int
,
...
]:
return
self
.
data
.
get_
output_token_ids
()
def
get_cumulative_logprob
(
self
)
->
float
:
return
self
.
data
.
cumulative_logprob
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment