Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
14ccd94c
Unverified
Commit
14ccd94c
authored
Mar 27, 2024
by
Cade Daniel
Committed by
GitHub
Mar 27, 2024
Browse files
[Core][Bugfix]Refactor block manager for better testability (#3492)
parent
8267b06c
Changes
30
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2888 additions
and
42 deletions
+2888
-42
tests/core/block/__init__.py
tests/core/block/__init__.py
+0
-0
tests/core/block/e2e/conftest.py
tests/core/block/e2e/conftest.py
+56
-0
tests/core/block/e2e/test_correctness.py
tests/core/block/e2e/test_correctness.py
+86
-0
tests/core/block/test_block_space_manager.py
tests/core/block/test_block_space_manager.py
+50
-0
tests/core/block/test_block_table.py
tests/core/block/test_block_table.py
+500
-0
tests/core/block/test_common.py
tests/core/block/test_common.py
+42
-0
tests/core/block/test_cpu_gpu_block_allocator.py
tests/core/block/test_cpu_gpu_block_allocator.py
+93
-0
tests/core/block/test_naive_block.py
tests/core/block/test_naive_block.py
+102
-0
tests/core/block/test_prefix_caching_block.py
tests/core/block/test_prefix_caching_block.py
+384
-0
tests/core/test_block_manager.py
tests/core/test_block_manager.py
+40
-39
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+1
-1
tests/core/utils.py
tests/core/utils.py
+38
-1
tests/prefix_caching/test_prefix_caching.py
tests/prefix_caching/test_prefix_caching.py
+1
-1
vllm/config.py
vllm/config.py
+7
-0
vllm/core/block/block_table.py
vllm/core/block/block_table.py
+245
-0
vllm/core/block/common.py
vllm/core/block/common.py
+185
-0
vllm/core/block/cpu_gpu_block_allocator.py
vllm/core/block/cpu_gpu_block_allocator.py
+206
-0
vllm/core/block/interfaces.py
vllm/core/block/interfaces.py
+105
-0
vllm/core/block/naive_block.py
vllm/core/block/naive_block.py
+275
-0
vllm/core/block/prefix_caching_block.py
vllm/core/block/prefix_caching_block.py
+472
-0
No files found.
tests/core/block/__init__.py
0 → 100644
View file @
14ccd94c
tests/core/block/e2e/conftest.py
0 → 100644
View file @
14ccd94c
import
contextlib
import
gc
import
pytest
import
ray
import
torch
from
vllm
import
LLM
from
vllm.model_executor.parallel_utils.parallel_state
import
(
destroy_model_parallel
)
from
vllm.model_executor.utils
import
set_random_seed
def cleanup():
    """Release the global state left behind by an ``LLM`` instance.

    Calls ``destroy_model_parallel()``, tries to destroy the torch
    distributed process group, then runs the garbage collector, empties the
    CUDA cache and shuts Ray down.
    """
    destroy_model_parallel()
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        # An AssertionError raised while destroying the process group is
        # deliberately ignored so the remaining teardown steps still run.
        pass
    gc.collect()
    torch.cuda.empty_cache()
    ray.shutdown()
@pytest.fixture
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                           baseline_llm_kwargs, seed):
    """Fixture returning a generator that yields the baseline ``LLM``.

    The baseline-specific keyword arguments are passed as the "distinct"
    arguments of ``create_llm_generator``, together with the common and
    per-test arguments and the seed.
    """
    generator = create_llm_generator(
        common_llm_kwargs=common_llm_kwargs,
        per_test_common_llm_kwargs=per_test_common_llm_kwargs,
        distinct_llm_kwargs=baseline_llm_kwargs,
        seed=seed,
    )
    return generator
@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    """Fixture returning a generator that yields the ``LLM`` under test.

    The test-specific keyword arguments are passed as the "distinct"
    arguments of ``create_llm_generator``, together with the common and
    per-test arguments and the seed.
    """
    generator = create_llm_generator(
        common_llm_kwargs=common_llm_kwargs,
        per_test_common_llm_kwargs=per_test_common_llm_kwargs,
        distinct_llm_kwargs=test_llm_kwargs,
        seed=seed,
    )
    return generator
def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                         distinct_llm_kwargs, seed):
    """Yield one ``LLM`` built from the merged keyword arguments.

    The three dicts are merged in the order given, so on a key conflict
    ``distinct_llm_kwargs`` wins over ``per_test_common_llm_kwargs``, which
    wins over ``common_llm_kwargs``. Nothing is constructed until the
    generator is first advanced. ``cleanup()`` only runs when the consumer
    advances the generator past the yielded instance.
    """
    merged_kwargs = dict(common_llm_kwargs)
    merged_kwargs.update(per_test_common_llm_kwargs)
    merged_kwargs.update(distinct_llm_kwargs)

    def _build_then_cleanup():
        engine = LLM(**merged_kwargs)

        set_random_seed(seed)

        yield engine
        # Drop this frame's reference before tearing down global state.
        del engine
        cleanup()

    for instance in _build_then_cleanup():
        yield instance
        # Release the outer reference before the inner generator resumes and
        # runs cleanup().
        del instance
tests/core/block/e2e/test_correctness.py
0 → 100644
View file @
14ccd94c
from
itertools
import
cycle
import
pytest
from
vllm
import
SamplingParams
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # A small model keeps the test fast.
        "model": "facebook/opt-125m",

        # Skip CUDA graph creation, again for speed.
        "enforce_eager": True,

        # Block budget: 5 * (64 + 1) blocks of 16 tokens. 64 blocks hold the
        # 1024 generated tokens of one sequence; the extra block per
        # sequence presumably covers the prompt.
        "block_size": 16,
        "forced_num_gpu_blocks": 5 * (64 + 1),
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
                                               test_llm_generator, batch_size):
    """Check that block manager v2 matches v1 token-for-token under preemption.

    Two LLMs are built with the same, deliberately limited number of GPU
    blocks. The limit is chosen so that, as the sequences in the batch grow,
    some of them must be preempted and removed from the cache.

    Identical output token ids from both LLMs give confidence that the v2
    block manager does not corrupt the KV cache.

    NOTE: A large number of generated tokens is used on purpose, so that any
    incorrect KV mapping has time to accumulate visible error.
    """
    max_new_tokens = 1024
    greedy_temperature = 0.0

    # The GPU block budget set in the parametrization above is smaller than
    # what the full batch needs once the sequences grow, which is what
    # forces preemption.
    base_prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Repeat the base prompts round-robin until the batch is full.
    prompt_source = cycle(base_prompts)
    prompts = [next(prompt_source) for _ in range(batch_size)]

    sampling_params = SamplingParams(
        max_tokens=max_new_tokens,
        ignore_eos=True,
        temperature=greedy_temperature,
    )

    print('Getting token ids from block manager v1')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids from block manager v2')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    # Compare per sequence first so a failure points at the offending one.
    for expected, actual in zip(baseline_token_ids, test_token_ids):
        assert expected == actual

    # The whole-list comparison also catches a difference in length.
    assert baseline_token_ids == test_token_ids
def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
    """Run generation on every LLM from the generator; return the token ids.

    For each yielded LLM, ``generate`` is called with ``use_tqdm=True`` and
    the token ids of the first output of every request are collected. The
    generator is iterated to exhaustion before returning, and the value
    returned is the list collected from the last LLM yielded.
    """
    for llm in llm_generator:
        request_outputs = llm.generate(prompts,
                                       sampling_params,
                                       use_tqdm=True)
        token_ids = []
        for request_output in request_outputs:
            token_ids.append(request_output.outputs[0].token_ids)
        # Drop the reference before the generator is advanced again.
        del llm

    return token_ids
tests/core/block/test_block_space_manager.py
0 → 100644
View file @
14ccd94c
import
pytest
from
vllm.core.block_manager_v2
import
BlockSpaceManagerV2
from
vllm.core.interfaces
import
AllocStatus
from
..utils
import
create_seq_group
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_gpu_blocks", [8, 40, 80])
@pytest.mark.parametrize("num_seqs_per_group", [1, 4])
@pytest.mark.parametrize("watermark", [0.0, 0.5])
def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int,
                                num_gpu_blocks: int, watermark: float):
    """Check ``can_allocate`` against the expected status per prompt size.

    For every prompt size from one block upwards, a sequence group is built
    and the status returned by the block manager is compared with the status
    derived from the required block count, the total GPU block count and the
    watermark.
    """
    block_manager = BlockSpaceManagerV2(
        block_size=block_size,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        watermark=watermark,
    )
    watermark_blocks = int(watermark * num_gpu_blocks)

    output_blocks_per_seq = 1

    # NOTE: This should be output_blocks_per_seq * num_seqs_per_group, but
    # the current implementation assumes all seqs are new prompts / don't have
    # different output lens.
    output_blocks = output_blocks_per_seq

    for prompt_blocks in range(1, num_gpu_blocks - output_blocks):
        output_lens = [block_size * output_blocks_per_seq
                       ] * num_seqs_per_group
        seq_group = create_seq_group(
            seq_prompt_lens=block_size * prompt_blocks,
            seq_output_lens=output_lens,
        )

        required_blocks = prompt_blocks + output_blocks
        assert required_blocks <= num_gpu_blocks

        status = block_manager.can_allocate(seq_group)

        if num_gpu_blocks - required_blocks < watermark_blocks:
            expected_status = AllocStatus.NEVER
        elif num_gpu_blocks >= required_blocks:
            expected_status = AllocStatus.OK
        else:
            expected_status = AllocStatus.LATER

        assert status == expected_status
tests/core/block/test_block_table.py
0 → 100644
View file @
14ccd94c
import
pytest
from
vllm.core.block.block_table
import
BlockTable
from
vllm.core.block.cpu_gpu_block_allocator
import
CpuGpuBlockAllocator
from
vllm.utils
import
Device
,
cdiv
,
chunk_list
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
def test_allocate_naive(block_size: int, sequence_len: int):
    """Check free-block accounting when allocating with the naive allocator.

    Five BlockTables are allocated for the same token sequence on one
    CpuGpuBlockAllocator. Before each allocation the number of free GPU
    blocks must equal the total minus the blocks taken by the tables
    allocated so far.
    """
    assert block_size > 1
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type="naive",
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        block_size=block_size,
    )

    token_ids = list(range(sequence_len))
    blocks_per_table = len(list(chunk_list(token_ids, block_size)))

    # Keep every table referenced for the duration of the test.
    block_tables = []
    expected_free = num_gpu_blocks
    for _ in range(5):
        assert allocator.get_num_free_blocks(
            device=Device.GPU) == expected_free

        table = BlockTable(
            block_size=block_size,
            block_allocator=allocator,
        )
        block_tables.append(table)
        table.allocate(token_ids=token_ids, device=Device.GPU)
        expected_free -= blocks_per_table
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
def test_allocate_prefix_caching(block_size: int, sequence_len: int):
    """Check free-block accounting with the prefix caching allocator.

    Five BlockTables are allocated for the same token sequence. Full blocks
    are expected to be shared between the tables, while a partially filled
    last block is expected to be allocated separately for every table. After
    each allocation the free GPU block count is compared with that
    expectation.
    """
    assert block_size > 1
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type="prefix_caching",
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        block_size=block_size,
    )

    token_ids = list(range(sequence_len))
    token_chunks = list(chunk_list(token_ids, block_size))

    # Only a partially filled last chunk needs a mutable block.
    if len(token_chunks[-1]) == block_size:
        mutable_per_alloc = 0
    else:
        mutable_per_alloc = 1
    immutable_per_alloc = len(token_chunks) - mutable_per_alloc

    block_tables = []
    for num_allocs in range(1, 6):
        table = BlockTable(
            block_size=block_size,
            block_allocator=allocator,
        )
        block_tables.append(table)
        table.allocate(token_ids=token_ids, device=Device.GPU)

        # Immutable blocks are counted once (shared); the mutable block is
        # counted once per table allocated so far.
        expected_used = immutable_per_alloc + mutable_per_alloc * num_allocs
        assert allocator.get_num_free_blocks(
            device=Device.GPU) == num_gpu_blocks - expected_used
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
@pytest.mark.parametrize("device", ["cpu", "gpu"])
def test_allocate_free(block_size: int, sequence_len: int,
                       allocator_type: str, device: str):
    """Check that repeated allocate/free cycles leave the allocator balanced.

    One BlockTable is allocated and freed five times with the same sequence
    on the requested device. After each allocation the free-block count must
    have dropped by the number of blocks the sequence needs and every
    physical block id must be set. After each free the count must be back at
    the total.
    """
    target_device = Device[device.upper()]

    num_device_blocks = 1024
    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_device_blocks,
        num_cpu_blocks=num_device_blocks,
        block_size=block_size,
    )

    token_ids = list(range(sequence_len))
    blocks_per_alloc = len(list(chunk_list(token_ids, block_size)))
    free_after_alloc = num_device_blocks - blocks_per_alloc

    block_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )

    for _ in range(5):
        block_table.allocate(token_ids=token_ids, device=target_device)
        assert allocator.get_num_free_blocks(
            target_device) == free_after_alloc
        for block_id in block_table.physical_block_ids:
            assert block_id is not None

        block_table.free()
        assert allocator.get_num_free_blocks(
            target_device) == num_device_blocks
@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("append_len", [1, 16, 129])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_append_token_ids_allocation(block_size: int, sequence_len: int,
                                     append_len: int, allocator_type: str):
    """Check the number of blocks held before and after appending tokens.

    A BlockTable is allocated with an initial sequence and further token ids
    are then appended. The number of physical block ids must match the
    number of chunks needed for the initial sequence before the append, and
    the number needed for the combined sequence afterwards.
    """
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        block_size=block_size,
    )

    initial_tokens = list(range(sequence_len))
    extra_tokens = list(range(append_len))

    block_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )

    blocks_before = len(list(chunk_list(initial_tokens, block_size)))
    blocks_after = len(
        list(chunk_list(initial_tokens + extra_tokens, block_size)))

    block_table.allocate(token_ids=initial_tokens, device=Device.GPU)
    assert len(block_table.physical_block_ids) == blocks_before

    block_table.append_token_ids(extra_tokens)
    assert len(block_table.physical_block_ids) == blocks_after
@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("num_empty_slots", [1, 16, 129])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_ensure_num_empty_slots_allocation(block_size: int, sequence_len: int,
                                           num_empty_slots: int,
                                           allocator_type: str):
    """Check block usage when reserving empty slots in a BlockTable.

    A BlockTable is allocated with an initial sequence, then
    ``ensure_num_empty_slots`` is called. The number of physical block ids
    must match the expected count before and after the reservation, and
    filling the reserved slots with tokens afterwards must not change the
    number of free GPU blocks.
    """
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        block_size=block_size,
    )

    initial_tokens = list(range(sequence_len))

    block_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )

    blocks_before = len(list(chunk_list(initial_tokens, block_size)))
    # Placeholder ids (-1) stand in for the empty slots when counting chunks.
    placeholder_tokens = [-1] * num_empty_slots
    blocks_after = len(
        list(chunk_list(initial_tokens + placeholder_tokens, block_size)))

    block_table.allocate(token_ids=initial_tokens, device=Device.GPU)
    assert len(block_table.physical_block_ids) == blocks_before

    # The empty slots must consume the expected number of additional blocks.
    block_table.ensure_num_empty_slots(num_empty_slots)
    assert len(block_table.physical_block_ids) == blocks_after

    # Filling the reserved slots must not consume any further blocks.
    free_before_fill = allocator.get_num_free_blocks(device=Device.GPU)
    block_table.append_token_ids(token_ids=list(range(num_empty_slots)))
    free_after_fill = allocator.get_num_free_blocks(device=Device.GPU)
    assert free_before_fill == free_after_fill
@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("sequence_len", [1, 9])
@pytest.mark.parametrize("append_len", [1, 16, 129])
@pytest.mark.parametrize("append_size", [1, 4, 129])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_append_token_ids_correct_content(block_size: int, sequence_len: int,
                                          append_len: int, allocator_type: str,
                                          append_size: int):
    """Check that appended token ids end up in the table in order.

    ``append_len`` tokens are appended in pieces of at most ``append_size``
    tokens. After every piece, and once more at the end, the token ids
    stored in the table must equal the initial tokens followed by everything
    appended so far.
    """
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=1024,
        block_size=block_size,
    )

    initial_tokens = list(range(sequence_len))
    extra_tokens = list(range(append_len))

    block_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )
    block_table.allocate(token_ids=initial_tokens, device=Device.GPU)

    # Running expectation of the table contents.
    expected_tokens = list(initial_tokens)
    for piece in chunk_list(extra_tokens, append_size):
        block_table.append_token_ids(piece)
        expected_tokens.extend(piece)

        assert block_table._get_all_token_ids() == expected_tokens

    assert block_table._get_all_token_ids() == initial_tokens + extra_tokens
@pytest.mark.parametrize("seq_len", [1, 9, 129])
@pytest.mark.parametrize("block_size", [1, 8])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_fork(seq_len: int, block_size: int, allocator_type: str):
    """Allocate a sequence, fork it, and check block accounting.

    1. After forking, the free block count is unchanged.
    2. The forked table has the same physical block ids and token ids as
        the original.
    3. Freeing the original table leaves the free block count unchanged.
    4. Freeing the forked table as well brings the free block count back
        to the total number of GPU blocks.
    """
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=0,
        block_size=block_size,
    )

    token_ids = list(range(seq_len))

    source_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )
    source_table.allocate(token_ids)

    free_before_fork = allocator.get_num_free_blocks(device=Device.GPU)

    forked_table = source_table.fork()

    # Both tables must map to the same blocks and hold the same tokens.
    assert (source_table.physical_block_ids ==
            forked_table.physical_block_ids)
    assert source_table._get_all_token_ids(
    ) == forked_table._get_all_token_ids()

    # Forking must not allocate anything.
    assert allocator.get_num_free_blocks(
        device=Device.GPU) == free_before_fork

    # Freeing the original must not release blocks the fork still uses.
    source_table.free()
    assert allocator.get_num_free_blocks(
        device=Device.GPU) == free_before_fork

    # The fork must be unaffected by the free of the original.
    for block_id in forked_table.physical_block_ids:
        assert block_id is not None

    # Freeing the fork releases the blocks; all GPU blocks are free again.
    forked_table.free()
    assert allocator.get_num_free_blocks(device=Device.GPU) == num_gpu_blocks
@pytest.mark.parametrize("block_size", [8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("append_len", [1, 16, 129])
@pytest.mark.parametrize("appender", ["forked", "original"])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_cow(block_size: int, sequence_len: int, append_len: int,
             allocator_type: str, appender: str):
    """Fork a sequence, append to one of the two tables, and check the CoW.

    ``appender`` selects which table (the fork or the original) receives the
    appended tokens; the other table must stay unchanged.
    """
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=0,
        block_size=block_size,
    )

    initial_tokens = list(range(sequence_len))
    extra_tokens = list(range(append_len))

    original_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )

    # Blocks held by the initial sequence.
    shared_blocks = cdiv(sequence_len, block_size)
    # Blocks of the combined sequence beyond the full blocks of the initial
    # sequence; the append is expected to allocate these.
    appended_blocks = cdiv(sequence_len + append_len,
                           block_size) - (sequence_len // block_size)

    original_table.allocate(token_ids=initial_tokens, device=Device.GPU)
    block_ids_at_fork = original_table.physical_block_ids

    forked_table = original_table.fork()

    # Forking alone must not allocate (copy on _write_).
    assert allocator.get_num_free_blocks(
        Device.GPU) == (num_gpu_blocks - shared_blocks)

    # Map the test parameter to (appending table, untouched table).
    roles = {
        "forked": (forked_table, original_table),
        "original": (original_table, forked_table),
    }
    if appender not in roles:
        raise ValueError(f"unknown test config {appender=}")
    appending_table, untouched_table = roles[appender]

    # Write tokens.
    appending_table.append_token_ids(extra_tokens)

    # The non-appending table keeps its mapping; the appending one changes.
    assert untouched_table.physical_block_ids == block_ids_at_fork
    assert appending_table.physical_block_ids != block_ids_at_fork

    # The append must have consumed exactly the expected extra blocks.
    assert allocator.get_num_free_blocks(
        Device.GPU) == num_gpu_blocks - (shared_blocks + appended_blocks)

    cows = allocator.clear_copy_on_writes()
    if sequence_len % block_size > 0:
        # The last block of the initial sequence is not full, so appending
        # writes into a shared block and must be recorded as a CoW.
        assert cows

        cow_index = sequence_len // block_size
        expected_src = untouched_table.physical_block_ids[cow_index]
        expected_dst = appending_table.physical_block_ids[cow_index]

        assert expected_src in cows
        assert expected_dst in cows[expected_src]
    else:
        # Every shared block is full, so nothing is copied.
        assert not cows

    untouched_table.free()
    appending_table.free()

    # After both frees, every block must be free again.
    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
@pytest.mark.parametrize("block_size", [8])
@pytest.mark.parametrize("sequence_len", [1, 16, 129])
@pytest.mark.parametrize("append_len", [1, 16, 129])
@pytest.mark.parametrize("lookahead_slots", [1, 16, 129])
@pytest.mark.parametrize("appender", ["forked", "original"])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_cow_lookahead_simple(block_size: int, sequence_len: int,
                              append_len: int, lookahead_slots: int,
                              allocator_type: str, appender: str):
    """Variant of test_cow that reserves lookahead slots before forking.

    The assertions are less rigorous than in test_cow due to the complexity
    of the property under test.
    """
    num_gpu_blocks = 1024

    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=0,
        block_size=block_size,
    )

    initial_tokens = list(range(sequence_len))
    extra_tokens = list(range(append_len))

    original_table = BlockTable(
        block_size=block_size,
        block_allocator=allocator,
    )

    original_table.allocate(token_ids=initial_tokens, device=Device.GPU)

    # Reserve the lookahead slots before the fork.
    original_table.ensure_num_empty_slots(lookahead_slots)
    block_ids_at_fork = original_table.physical_block_ids

    forked_table = original_table.fork()

    # Map the test parameter to (appending table, untouched table).
    roles = {
        "forked": (forked_table, original_table),
        "original": (original_table, forked_table),
    }
    if appender not in roles:
        raise ValueError(f"unknown test config {appender=}")
    appending_table, untouched_table = roles[appender]

    # Write tokens.
    appending_table.append_token_ids(extra_tokens)

    # The non-appending table keeps its mapping; the appending one changes.
    assert untouched_table.physical_block_ids == block_ids_at_fork
    assert appending_table.physical_block_ids != block_ids_at_fork

    cows = allocator.clear_copy_on_writes()

    # A copy-on-write is expected in every configuration of this test.
    assert cows

    if sequence_len % block_size > 0:
        # The last block of the initial sequence is not full, so the block
        # at that index must appear as a CoW source/destination pair.
        cow_index = sequence_len // block_size
        expected_src = untouched_table.physical_block_ids[cow_index]
        expected_dst = appending_table.physical_block_ids[cow_index]

        assert expected_src in cows
        assert expected_dst in cows[expected_src]

    untouched_table.free()
    appending_table.free()

    # After both frees, every block must be free again.
    assert allocator.get_num_free_blocks(Device.GPU) == num_gpu_blocks
tests/core/block/test_common.py
0 → 100644
View file @
14ccd94c
import
random
import
pytest
from
vllm.core.block.common
import
RefCounter
@pytest.mark.parametrize("seed", list(range(20)))
@pytest.mark.parametrize("num_incrs", [1, 100])
@pytest.mark.parametrize("num_blocks", [1024])
def test_incr(seed: int, num_incrs: int, num_blocks: int):
    """Check that ``RefCounter.incr`` returns the running count.

    A block id is picked at random (seeded) and incremented ``num_incrs``
    times; the n-th call must return n.
    """
    random.seed(seed)

    counter = RefCounter(all_block_indices=list(range(num_blocks)))

    block_id = random.randint(0, num_blocks - 1)
    for expected_count in range(1, num_incrs + 1):
        assert counter.incr(block_id) == expected_count
@pytest.mark.parametrize("seed", list(range(20)))
@pytest.mark.parametrize("num_incrs", [1, 100])
@pytest.mark.parametrize("num_blocks", [1024])
def test_incr_decr(seed: int, num_incrs: int, num_blocks: int):
    """Check that ``RefCounter`` counts up, back down, and rejects underflow.

    A block id is picked at random (seeded), incremented ``num_incrs`` times
    and decremented the same number of times, checking every returned value.
    One further decrement must raise ``AssertionError``.
    """
    random.seed(seed)

    counter = RefCounter(all_block_indices=list(range(num_blocks)))

    block_id = random.randint(0, num_blocks - 1)

    # Count up: the n-th increment returns n.
    for expected_count in range(1, num_incrs + 1):
        assert counter.incr(block_id) == expected_count

    # Count down: each decrement returns the remaining count, ending at 0.
    for expected_count in range(num_incrs - 1, -1, -1):
        assert counter.decr(block_id) == expected_count

    # The count is back at zero, so another decrement must be rejected.
    with pytest.raises(AssertionError):
        counter.decr(block_id)
tests/core/block/test_cpu_gpu_block_allocator.py
0 → 100644
View file @
14ccd94c
import
pytest
from
vllm.core.block.cpu_gpu_block_allocator
import
CpuGpuBlockAllocator
from
vllm.utils
import
Device
,
chunk_list
@pytest.mark.parametrize("num_cpu_blocks", [0, 512])
@pytest.mark.parametrize("num_gpu_blocks", [1024])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_allocate_mutable(num_cpu_blocks: int, num_gpu_blocks: int,
                          block_size: int, allocator_type: str):
    """Check per-device free counts while allocating/freeing mutable blocks.

    All CPU blocks are allocated, then all GPU blocks; afterwards the CPU
    blocks are freed, then the GPU blocks. After every step the free-block
    count of each device is checked, so an operation on one device must not
    affect the count of the other.
    """
    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=num_cpu_blocks,
        block_size=block_size,
    )

    def check_free(expected_cpu: int, expected_gpu: int) -> None:
        assert allocator.get_num_free_blocks(Device.CPU) == expected_cpu
        assert allocator.get_num_free_blocks(Device.GPU) == expected_gpu

    check_free(num_cpu_blocks, num_gpu_blocks)

    # Exhaust the CPU pool; the GPU pool must be untouched.
    cpu_blocks = []
    for _ in range(num_cpu_blocks):
        cpu_blocks.append(
            allocator.allocate_mutable(prev_block=None, device=Device.CPU))
    check_free(0, num_gpu_blocks)

    # Exhaust the GPU pool as well.
    gpu_blocks = []
    for _ in range(num_gpu_blocks):
        gpu_blocks.append(
            allocator.allocate_mutable(prev_block=None, device=Device.GPU))
    check_free(0, 0)

    # Return the CPU blocks; the GPU pool must stay exhausted.
    for block in cpu_blocks:
        allocator.free(block)
    check_free(num_cpu_blocks, 0)

    # Return the GPU blocks; both pools are full again.
    for block in gpu_blocks:
        allocator.free(block)
    check_free(num_cpu_blocks, num_gpu_blocks)
@pytest.mark.parametrize("num_cpu_blocks", [0, 512])
@pytest.mark.parametrize("num_gpu_blocks", [1024])
@pytest.mark.parametrize("block_size", [2])
@pytest.mark.parametrize("allocator_type", ["naive", "prefix_caching"])
def test_allocate_immutable(num_cpu_blocks: int, num_gpu_blocks: int,
                            block_size: int, allocator_type: str):
    """Check per-device free counts while allocating/freeing immutable blocks.

    Every block receives its own chunk of distinct token ids. All CPU blocks
    are allocated, then all GPU blocks; afterwards the CPU blocks are freed,
    then the GPU blocks. After every step the free-block count of each
    device is checked.
    """
    allocator = CpuGpuBlockAllocator.create(
        allocator_type=allocator_type,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=num_cpu_blocks,
        block_size=block_size,
    )

    # One distinct token id per slot across both devices: the first
    # num_gpu_blocks chunks go to the GPU, the remainder to the CPU.
    total_tokens = (num_cpu_blocks + num_gpu_blocks) * block_size
    unique_token_ids = list(range(total_tokens))
    gpu_token_count = num_gpu_blocks * block_size

    gpu_token_ids = chunk_list(unique_token_ids[:gpu_token_count], block_size)
    cpu_token_ids = chunk_list(unique_token_ids[gpu_token_count:], block_size)

    def check_free(expected_cpu: int, expected_gpu: int) -> None:
        assert allocator.get_num_free_blocks(Device.CPU) == expected_cpu
        assert allocator.get_num_free_blocks(Device.GPU) == expected_gpu

    check_free(num_cpu_blocks, num_gpu_blocks)

    # Exhaust the CPU pool; the GPU pool must be untouched.
    cpu_blocks = []
    for chunk in cpu_token_ids:
        cpu_blocks.append(
            allocator.allocate_immutable(prev_block=None,
                                         token_ids=chunk,
                                         device=Device.CPU))
    check_free(0, num_gpu_blocks)

    # Exhaust the GPU pool as well.
    gpu_blocks = []
    for chunk in gpu_token_ids:
        gpu_blocks.append(
            allocator.allocate_immutable(prev_block=None,
                                         token_ids=chunk,
                                         device=Device.GPU))
    check_free(0, 0)

    # Return the CPU blocks; the GPU pool must stay exhausted.
    for block in cpu_blocks:
        allocator.free(block)
    check_free(num_cpu_blocks, 0)

    # Return the GPU blocks; both pools are full again.
    for block in gpu_blocks:
        allocator.free(block)
    check_free(num_cpu_blocks, num_gpu_blocks)
tests/core/block/test_naive_block.py
0 → 100644
View file @
14ccd94c
from
typing
import
List
,
Optional
import
pytest
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
from
vllm.core.block.naive_block
import
NaiveBlock
,
NaiveBlockAllocator
class TestNaiveBlockAllocator:
    """Tests for ``NaiveBlockAllocator`` covering both allocation kinds."""

    @staticmethod
    def create_allocate_lambda(allocate_type: str,
                               allocator: NaiveBlockAllocator,
                               prev_block: Optional[Block],
                               token_ids: List[int]):
        """Return a zero-argument callable that allocates one block.

        ``"immutable"`` allocates via ``allocate_immutable`` with the given
        token ids, ``"mutable"`` via ``allocate_mutable``. Any other value
        raises ``ValueError``.
        """
        if allocate_type == "immutable":
            return lambda: allocator.allocate_immutable(
                prev_block=prev_block, token_ids=token_ids)
        if allocate_type == "mutable":
            return lambda: allocator.allocate_mutable(prev_block=prev_block)
        raise ValueError()

    @staticmethod
    @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"])
    @pytest.mark.parametrize("num_blocks", [1, 1024])
    @pytest.mark.parametrize("block_size", [1, 16])
    def test_allocate_ooms(allocate_type: str, num_blocks: int,
                           block_size: int):
        """Allocating one block more than the pool holds must raise."""
        allocator = NaiveBlockAllocator(create_block=NaiveBlock,
                                        num_blocks=num_blocks,
                                        block_size=block_size)
        allocate_one = TestNaiveBlockAllocator.create_allocate_lambda(
            allocate_type,
            allocator,
            prev_block=None,
            token_ids=list(range(block_size)))

        # Use up the whole pool.
        for _ in range(num_blocks):
            allocate_one()

        with pytest.raises(BlockAllocator.NoFreeBlocksError):
            allocate_one()

    @staticmethod
    @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"])
    @pytest.mark.parametrize("num_blocks", [1, 1024])
    @pytest.mark.parametrize("block_size", [1, 16])
    def test_free_prevents_oom(allocate_type: str, num_blocks: int,
                               block_size: int):
        """Freeing one block must allow exactly one further allocation.

        With the pool exhausted, a block is freed and re-allocated 100
        times. Each time the freed block must lose its id, the new block
        must receive that same id, and a second allocation must fail.
        """
        allocator = NaiveBlockAllocator(create_block=NaiveBlock,
                                        num_blocks=num_blocks,
                                        block_size=block_size)
        allocate_one = TestNaiveBlockAllocator.create_allocate_lambda(
            allocate_type,
            allocator,
            prev_block=None,
            token_ids=list(range(block_size)))

        held_blocks = [allocate_one() for _ in range(num_blocks)]

        with pytest.raises(BlockAllocator.NoFreeBlocksError):
            allocate_one()

        victim = held_blocks.pop()

        for _ in range(100):
            # Remember the id before free() clears it on the block.
            freed_id = victim.block_id
            allocator.free(victim)
            assert victim.block_id is None

            replacement = allocate_one()
            assert replacement.block_id == freed_id

            # The pool is exhausted again.
            with pytest.raises(BlockAllocator.NoFreeBlocksError):
                allocate_one()

            victim = replacement

    @staticmethod
    @pytest.mark.parametrize("allocate_type", ["immutable", "mutable"])
    @pytest.mark.parametrize("num_blocks", [1024])
    @pytest.mark.parametrize("block_size", [16])
    def test_get_num_free_blocks(allocate_type: str, num_blocks: int,
                                 block_size: int):
        """The free count must start at the pool size and grow by one per free.

        The pool is exhausted first; then, before the i-th block (0-based)
        is freed, exactly i blocks must be reported as free.
        """
        allocator = NaiveBlockAllocator(create_block=NaiveBlock,
                                        num_blocks=num_blocks,
                                        block_size=block_size)
        allocate_one = TestNaiveBlockAllocator.create_allocate_lambda(
            allocate_type,
            allocator,
            prev_block=None,
            token_ids=list(range(block_size)))

        assert allocator.get_num_free_blocks() == num_blocks

        held_blocks = [allocate_one() for _ in range(num_blocks)]

        for freed_so_far, block in enumerate(held_blocks):
            assert allocator.get_num_free_blocks() == freed_so_far
            allocator.free(block)
tests/core/block/test_prefix_caching_block.py
0 → 100644
View file @
14ccd94c
import
math
import
random
from
typing
import
List
,
Optional
from
unittest.mock
import
MagicMock
import
pytest
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
from
vllm.core.block.prefix_caching_block
import
(
PrefixCachingBlock
,
PrefixCachingBlockAllocator
)
class TestPrefixCachingBlock:
    """Content-hash checks for PrefixCachingBlock, alone and in chains."""

    @staticmethod
    @pytest.mark.parametrize("seed", list(range(10)))
    @pytest.mark.parametrize("block_size", [1, 16])
    @pytest.mark.parametrize("is_curr_block_full", [True, False])
    def test_first_block_has_correct_content_hash(seed: int, block_size: int,
                                                  is_curr_block_full: bool):
        """A block without a predecessor must expose a content hash exactly
        when it is full.
        """
        random.seed(seed)

        if is_curr_block_full:
            fill_count = block_size
        else:
            fill_count = random.randint(0, block_size - 1)
        token_ids = list(range(fill_count))

        allocator_mock = MagicMock(spec=PrefixCachingBlockAllocator)
        first_block = PrefixCachingBlock(
            prev_block=None,
            token_ids=token_ids,
            block_size=block_size,
            prefix_caching_allocator=allocator_mock)

        if not is_curr_block_full:
            # A partially filled block is not expected to have a hash.
            assert first_block.content_hash is None
            return

        # A full block is expected to hash as the first block of a sequence.
        actual_hash = first_block.content_hash
        expected_hash = PrefixCachingBlock.hash_block_tokens(
            is_first_block=True,
            prev_block_hash=None,
            cur_block_token_ids=token_ids)
        assert actual_hash == expected_hash

    @staticmethod
    @pytest.mark.parametrize("seed", list(range(10)))
    @pytest.mark.parametrize("block_size", [1, 16])
    @pytest.mark.parametrize("is_curr_block_full", [True, False])
    @pytest.mark.parametrize("prev_block_has_hash", [True, False])
    def test_nth_block_has_correct_content_hash(seed: int, block_size: int,
                                                is_curr_block_full: bool,
                                                prev_block_has_hash: bool):
        """A block with a predecessor must expose a content hash only when it
        is full and the predecessor has a hash too.
        """
        random.seed(seed)

        predecessor = MagicMock(spec=PrefixCachingBlock)
        prev_block_hash = random.randint(0, 1000)
        if prev_block_has_hash:
            predecessor.content_hash = prev_block_hash
        else:
            predecessor.content_hash = None

        if is_curr_block_full:
            fill_count = block_size
        else:
            fill_count = random.randint(0, block_size - 1)
        token_ids = list(range(fill_count))

        allocator_mock = MagicMock(spec=PrefixCachingBlockAllocator)
        nth_block = PrefixCachingBlock(
            prev_block=predecessor,
            token_ids=token_ids,
            block_size=block_size,
            prefix_caching_allocator=allocator_mock,
        )

        if not (is_curr_block_full and prev_block_has_hash):
            # Either the block is not full or its predecessor has no hash:
            # no hash is expected.
            assert nth_block.content_hash is None
            return

        actual_hash = nth_block.content_hash
        expected_hash = PrefixCachingBlock.hash_block_tokens(
            is_first_block=False,
            prev_block_hash=prev_block_hash,
            cur_block_token_ids=token_ids)
        assert actual_hash == expected_hash

    @staticmethod
    @pytest.mark.parametrize("block_size", [1, 2, 16])
    @pytest.mark.parametrize("num_tokens", list(range(3)))
    @pytest.mark.parametrize("num_empty_trailing_blocks", [0, 1, 10])
    def test_blocks_have_correct_hash_in_chain(block_size: int,
                                               num_tokens: int,
                                               num_empty_trailing_blocks: int):
        """Build two chains from the same tokens and require pairwise equal
        content hashes.
        """
        random.seed(0)
        token_ids = [random.randint(0, 50_000) for _ in range(num_tokens)]

        chain_kwargs = dict(
            block_size=block_size,
            token_ids=token_ids,
            num_empty_trailing_blocks=num_empty_trailing_blocks)
        first_chain = TestPrefixCachingBlock.create_chain(**chain_kwargs)
        second_chain = TestPrefixCachingBlock.create_chain(**chain_kwargs)

        for lhs_block, rhs_block in zip(first_chain, second_chain):
            assert lhs_block.content_hash == rhs_block.content_hash

        # An empty chain is only expected when there were no tokens at all.
        if not (first_chain and second_chain):
            assert first_chain == second_chain
            assert num_tokens == 0

    @staticmethod
    def create_chain(block_size: int,
                     token_ids: List[int],
                     num_empty_trailing_blocks=0) -> List[PrefixCachingBlock]:
        """Create a linked chain of blocks holding token_ids in order.

        Each block is created empty and linked to the previous one; its slice
        of token_ids (if any) is appended afterwards. Blocks beyond the end of
        token_ids stay empty.
        """
        total_blocks = (math.ceil(len(token_ids) / block_size) +
                        num_empty_trailing_blocks)
        if total_blocks == 0:
            return []

        allocator_mock = MagicMock(spec=PrefixCachingBlockAllocator)

        chain = []
        tail = None
        for index in range(total_blocks):
            tail = PrefixCachingBlock(
                prev_block=tail,
                token_ids=[],
                block_size=block_size,
                prefix_caching_allocator=allocator_mock,
            )

            start = index * block_size
            stop = (index + 1) * block_size
            chunk = token_ids[start:stop]
            # Trailing blocks get no tokens; skip the append for them.
            if chunk:
                tail.append_token_ids(chunk)

            chain.append(tail)

        return chain
class TestPrefixCachingBlockAllocator:
    """Tests for PrefixCachingBlockAllocator.

    Covers pool exhaustion, sharing of blocks with identical content, reuse of
    freed blocks and the free-block count.
    """

    @staticmethod
    def create_allocate_lambda(allocate_type: str, allocator: BlockAllocator,
                               prev_block: Optional[Block],
                               token_ids: List[int]):
        """Build a zero-argument callable that allocates one block.

        Args:
            allocate_type: Either "immutable" or "mutable".
            allocator: The allocator the callable allocates from.
            prev_block: Forwarded as the predecessor of every allocated block.
            token_ids: Forwarded to immutable allocations; not used for
                mutable allocations.

        Returns:
            A callable which performs one allocation each time it is called
            and returns the result of that allocation.

        Raises:
            ValueError: If allocate_type is neither "immutable" nor "mutable".
        """
        if allocate_type == "immutable":
            allocate_block = lambda: allocator.allocate_immutable(
                prev_block=prev_block, token_ids=token_ids)
        elif allocate_type == "mutable":
            allocate_block = lambda: allocator.allocate_mutable(
                prev_block=prev_block)
        else:
            # Name the offending value so a typo is obvious from the failure.
            raise ValueError(
                f"unknown allocate_type {allocate_type!r}; expected "
                "'immutable' or 'mutable'")

        return allocate_block

    @staticmethod
    @pytest.mark.parametrize("num_blocks", [1, 1024])
    @pytest.mark.parametrize("block_size", [1, 16])
    def test_allocate_mutable_ooms(num_blocks: int, block_size: int):
        """Allocate every block as mutable, then expect the next to fail."""
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)
        allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda(
            allocate_type="mutable",
            allocator=allocator,
            prev_block=None,
            token_ids=list(range(block_size)),
        )

        # Consume the whole pool; the returned blocks are not needed.
        for _ in range(num_blocks):
            allocate_block()

        with pytest.raises(BlockAllocator.NoFreeBlocksError):
            allocate_block()

    @staticmethod
    @pytest.mark.parametrize("num_blocks", [1, 1024])
    @pytest.mark.parametrize("block_size", [1, 16])
    def test_allocate_immutable_does_not_oom_single_hash(
            num_blocks: int, block_size: int):
        """Immutable blocks with identical content are expected to share one
        physical block, so allocating more than num_blocks must not fail.
        """
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)
        allocate_block = TestPrefixCachingBlockAllocator.create_allocate_lambda(
            allocate_type="immutable",
            allocator=allocator,
            prev_block=None,
            token_ids=list(range(block_size)),
        )

        blocks = [allocate_block() for _ in range(num_blocks)]

        # Expect no OOM. If these were mutable blocks, this would OOM.
        non_oom_block = allocate_block()

        # Expect all blocks to have same physical block index.
        for block in blocks:
            assert block.block_id == non_oom_block.block_id

    @staticmethod
    @pytest.mark.parametrize("num_blocks", [1, 1024])
    @pytest.mark.parametrize("block_size", [1, 16])
    def test_allocate_immutable_ooms_many_hash(num_blocks: int,
                                               block_size: int):
        """Consume all blocks using many different hashes/block content.

        Do this by creating a sequence that is very long.
        Expect next block to OOM.
        """
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)

        # Create token ids that will exhaust all blocks.
        token_ids = list(range(num_blocks * block_size))

        chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )

        # Expect allocation with unseen hash to fail.
        with pytest.raises(BlockAllocator.NoFreeBlocksError):
            allocator.allocate_immutable(prev_block=chain[-1],
                                         token_ids=list(range(block_size)))

        # Expect mutable allocation to fail.
        with pytest.raises(BlockAllocator.NoFreeBlocksError):
            allocator.allocate_mutable(prev_block=chain[-1])

        # Expect allocation of exact same chain to pass.
        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )

        # Expect physical block indices to be the same in both chains.
        assert chain and second_chain
        for first_chain_block, second_chain_block in zip(chain, second_chain):
            assert (first_chain_block.block_id == second_chain_block.block_id)

    @staticmethod
    @pytest.mark.parametrize("num_blocks", [1, 1024])
    @pytest.mark.parametrize("block_size", [1, 16])
    def test_free_prevents_oom(num_blocks: int, block_size: int):
        """Free one block of a full pool and expect exactly one mutable
        allocation to succeed, repeatedly.
        """
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)

        # Create token ids that will exhaust all blocks.
        token_ids = list(range(num_blocks * block_size))

        chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )

        # Expect mutable allocation to fail.
        with pytest.raises(BlockAllocator.NoFreeBlocksError):
            allocator.allocate_mutable(prev_block=None)

        block_to_free = chain[-1]

        # Expect free/allocate loop to succeed many times.
        for i in range(100):
            block_id = block_to_free.block_id
            allocator.free(block_to_free)
            # `i` is the assertion message: it identifies the failing round.
            assert block_to_free.block_id is None, i

            new_block = allocator.allocate_mutable(prev_block=None)
            assert new_block.block_id == block_id, i

            with pytest.raises(BlockAllocator.NoFreeBlocksError):
                allocator.allocate_mutable(prev_block=None)

            block_to_free = new_block

    @staticmethod
    @pytest.mark.parametrize("num_blocks", [1024])
    @pytest.mark.parametrize("block_size", [16])
    @pytest.mark.parametrize("seed", list(range(20)))
    def test_get_num_free_blocks(num_blocks: int, block_size: int, seed: int):
        """Consume a random number of blocks and check the free count grows by
        one with every freed block.
        """
        random.seed(seed)
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)
        num_blocks_to_consume = random.randint(1, num_blocks - 1)

        # Create token ids that fill exactly num_blocks_to_consume blocks
        # (always fewer than the whole pool).
        token_ids = list(range(num_blocks_to_consume * block_size))

        chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )

        # Free each block in chain, assert num free blocks includes new free
        # block.
        for i, block in enumerate(chain):
            assert allocator.get_num_free_blocks() == (num_blocks -
                                                       num_blocks_to_consume +
                                                       i)
            allocator.free(block)

        # The loop only checks the count *before* each free; verify the state
        # after the final free as well.
        assert allocator.get_num_free_blocks() == num_blocks

    @staticmethod
    @pytest.mark.parametrize("num_blocks", [1024])
    @pytest.mark.parametrize("block_size", [16])
    @pytest.mark.parametrize("seed", list(range(20)))
    def test_get_num_free_blocks_shared(num_blocks: int, block_size: int,
                                        seed: int):
        """Verify sharing occurs by allocating two sequences that share prefixes
        and incrementally freeing blocks.
        """
        random.seed(seed)
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)
        num_blocks_to_consume = random.randint(1, num_blocks - 1)

        # Create token ids that fill exactly num_blocks_to_consume blocks
        # (always fewer than the whole pool).
        token_ids = list(range(num_blocks_to_consume * block_size))

        first_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )
        second_chain = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=token_ids,
            allocator=allocator,
        )

        # Free each block in the first chain. Since all blocks are shared, the
        # free count should stay constant.
        for block in first_chain:
            assert allocator.get_num_free_blocks() == (num_blocks -
                                                       num_blocks_to_consume)
            allocator.free(block)

        # Free each block in the second chain. Since the refcount is now zero,
        # the free count should increment with each free.
        for i, block in enumerate(second_chain):
            assert allocator.get_num_free_blocks() == (num_blocks -
                                                       num_blocks_to_consume +
                                                       i)
            allocator.free(block)

        # Every block has been released by both chains at this point.
        assert allocator.get_num_free_blocks() == num_blocks

    @staticmethod
    def create_immutable_chain(
        block_size: int,
        token_ids: List[int],
        allocator: PrefixCachingBlockAllocator,
    ) -> List[PrefixCachingBlock]:
        """Helper method which creates a chain of immutable blocks.

        token_ids is split into consecutive slices of block_size tokens and one
        immutable block is allocated per slice, each linked to the previous.

        Returns:
            The allocated blocks in sequence order; empty if token_ids is
            empty.
        """
        blocks = []
        num_blocks = math.ceil(len(token_ids) / block_size)

        if num_blocks == 0:
            return []

        prev_block = None
        for block_number in range(0, num_blocks):
            block_token_ids = token_ids[block_number *
                                        block_size:(block_number + 1) *
                                        block_size]
            prev_block = allocator.allocate_immutable(
                prev_block=prev_block, token_ids=block_token_ids)
            blocks.append(prev_block)

        return blocks
tests/core/test_block_manager.py
View file @
14ccd94c
...
...
@@ -5,8 +5,9 @@ import pytest
from
vllm
import
SamplingParams
from
vllm.block
import
PhysicalTokenBlock
from
vllm.core.block_manager
import
(
AllocStatus
,
BlockSpaceManager
,
UncachedBlockAllocator
)
from
vllm.core.block_manager_v1
import
(
BlockSpaceManagerV1
,
UncachedBlockAllocator
)
from
vllm.core.interfaces
import
AllocStatus
from
vllm.sequence
import
Logprob
,
Sequence
,
SequenceGroup
,
SequenceStatus
from
vllm.utils
import
Device
...
...
@@ -63,10 +64,10 @@ def test_allocate():
block_size
=
4
num_cpu_blocks
=
4
num_gpu_blocks
=
4
block_manager
=
BlockSpaceManager
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
watermark
=
0
)
block_manager
=
BlockSpaceManager
V1
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
watermark
=
0
)
# Allocate same sequence group to all available gpu blocks.
for
i
in
range
(
num_gpu_blocks
):
...
...
@@ -77,10 +78,10 @@ def test_allocate():
# Allocate same sequence group to all available gpu blocks.
# Use watermark to reserve one gpu block.
block_manager
=
BlockSpaceManager
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
watermark
=
1
/
num_gpu_blocks
)
block_manager
=
BlockSpaceManager
V1
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
watermark
=
1
/
num_gpu_blocks
)
for
i
in
range
(
num_gpu_blocks
-
1
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
block_size
)
assert
block_manager
.
can_allocate
(
seq_group
)
...
...
@@ -92,10 +93,10 @@ def test_append_slot_single_seq():
block_size
=
4
num_cpu_blocks
=
4
num_gpu_blocks
=
4
block_manager
=
BlockSpaceManager
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
watermark
=
0
)
block_manager
=
BlockSpaceManager
V1
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
watermark
=
0
)
# Allocate single seq to gpu block.
prompt
,
seq_group
=
create_dummy_prompt
(
"1"
,
block_size
)
...
...
@@ -124,10 +125,10 @@ def test_append_slot_cow():
block_size
=
4
num_cpu_blocks
=
4
num_gpu_blocks
=
4
block_manager
=
BlockSpaceManager
(
block_size
=
block_size
,
num_cpu_blocks
=
num_cpu_blocks
,
num_gpu_blocks
=
num_gpu_blocks
,
watermark
=
0
)
block_manager
=
BlockSpaceManager
V1
(
block_size
=
block_size
,
num_cpu_blocks
=
num_cpu_blocks
,
num_gpu_blocks
=
num_gpu_blocks
,
watermark
=
0
)
# Allocate prompt to gpu block. There is one slot left in the block.
prompt
=
Sequence
(
seq_id
=
1
,
...
...
@@ -165,10 +166,10 @@ def test_fork():
block_size
=
4
num_cpu_blocks
=
4
num_gpu_blocks
=
4
block_manager
=
BlockSpaceManager
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
watermark
=
0
)
block_manager
=
BlockSpaceManager
V1
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
watermark
=
0
)
prompt
,
seq_group
=
create_dummy_prompt
(
"1"
,
block_size
-
1
,
...
...
@@ -192,10 +193,10 @@ def test_swap():
block_size
=
4
num_cpu_blocks
=
4
num_gpu_blocks
=
4
block_manager
=
BlockSpaceManager
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
watermark
=
0
)
block_manager
=
BlockSpaceManager
V1
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
watermark
=
0
)
prompt
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
block_size
-
1
)
prompt
.
status
=
SequenceStatus
.
WAITING
...
...
@@ -238,10 +239,10 @@ def test_free():
block_size
=
4
num_cpu_blocks
=
4
num_gpu_blocks
=
4
block_manager
=
BlockSpaceManager
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
watermark
=
0
)
block_manager
=
BlockSpaceManager
V1
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
watermark
=
0
)
prompt
,
seq_group
=
create_dummy_prompt
(
"1"
,
block_size
)
block_manager
.
allocate
(
seq_group
)
...
...
@@ -262,10 +263,10 @@ def test_reset():
block_size
=
4
num_cpu_blocks
=
4
num_gpu_blocks
=
4
block_manager
=
BlockSpaceManager
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
watermark
=
0
)
block_manager
=
BlockSpaceManager
V1
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
watermark
=
0
)
# Allocate same seq group on all available gpu blocks.
original_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
...
...
@@ -289,11 +290,11 @@ def test_sliding_window_multi_seq():
num_cpu_blocks
=
8
num_gpu_blocks
=
8
sliding_window
=
2
block_manager
=
BlockSpaceManager
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
sliding_window
=
sliding_window
,
watermark
=
0
)
block_manager
=
BlockSpaceManager
V1
(
block_size
,
num_cpu_blocks
,
num_gpu_blocks
,
sliding_window
=
sliding_window
,
watermark
=
0
)
assert
block_manager
.
get_num_free_gpu_blocks
()
==
num_gpu_blocks
...
...
tests/core/test_scheduler.py
View file @
14ccd94c
...
...
@@ -13,7 +13,7 @@ from .utils import create_dummy_prompt
def
test_scheduler_add_seq_group
():
block_size
=
4
scheduler_config
=
SchedulerConfig
(
100
,
64
,
1
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
cache_dtype
=
"auto"
)
cache_config
.
num_cpu_blocks
=
4
cache_config
.
num_gpu_blocks
=
4
scheduler
=
Scheduler
(
scheduler_config
,
cache_config
,
None
)
...
...
tests/core/utils.py
View file @
14ccd94c
...
...
@@ -2,7 +2,7 @@ import time
from
typing
import
Tuple
from
vllm
import
SamplingParams
from
vllm.sequence
import
Sequence
,
SequenceGroup
from
vllm.sequence
import
Logprob
,
Sequence
,
SequenceGroup
def
create_dummy_prompt
(
...
...
@@ -23,5 +23,42 @@ def create_dummy_prompt(
return
prompt
,
seq_group
def create_seq_group(
    seq_prompt_lens=1024,
    seq_output_lens=(128, ),
    request_id='0',
    seq_id_start=0,
) -> SequenceGroup:
    """Build a SequenceGroup of dummy sequences for tests.

    One Sequence is created per entry of seq_output_lens. Every sequence gets
    the same prompt of seq_prompt_lens zero tokens and then has token ids
    0..output_len-1 appended, each with a Logprob of 0.0. Sequence ids count
    up from seq_id_start.
    """
    assert len(seq_output_lens) > 0

    # The same list object is handed to every Sequence.
    shared_prompt = [0] * seq_prompt_lens

    sequences = []
    for seq_id, num_output_tokens in enumerate(seq_output_lens,
                                               start=seq_id_start):
        sequence = Sequence(
            seq_id=seq_id,
            prompt="",
            prompt_token_ids=shared_prompt,
            block_size=16,
        )
        for token_id in range(num_output_tokens):
            sequence.append_token_id(
                token_id=token_id,
                logprobs={token_id: Logprob(0.0)},
            )
        sequences.append(sequence)

    return SequenceGroup(
        request_id=request_id,
        seqs=sequences,
        sampling_params=SamplingParams(),
        arrival_time=time.time(),
    )
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
    """Return how many blocks of block_size tokens cover seq_len tokens.

    Computed with integer arithmetic; for a positive block_size this is
    seq_len / block_size rounded up.
    """
    padded_len = seq_len + (block_size - 1)
    num_blocks, _ = divmod(padded_len, block_size)
    return num_blocks
tests/prefix_caching/test_prefix_caching.py
View file @
14ccd94c
...
...
@@ -4,7 +4,7 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import
pytest
from
vllm.core.block_manager
import
CachedBlockAllocator
from
vllm.core.block_manager
_v1
import
CachedBlockAllocator
from
vllm.utils
import
Device
...
...
vllm/config.py
View file @
14ccd94c
...
...
@@ -324,6 +324,8 @@ class CacheConfig:
vLLM execution.
swap_space: Size of the CPU swap space per GPU (in GiB).
cache_dtype: Data type for kv cache storage.
forced_num_gpu_blocks: Number of GPU blocks to use. This overrides the
profiled num_gpu_blocks if specified. Does nothing if None.
"""
def
__init__
(
...
...
@@ -332,12 +334,14 @@ class CacheConfig:
gpu_memory_utilization
:
float
,
swap_space
:
int
,
cache_dtype
:
str
,
forced_num_gpu_blocks
:
Optional
[
int
]
=
None
,
sliding_window
:
Optional
[
int
]
=
None
,
enable_prefix_caching
:
bool
=
False
,
)
->
None
:
self
.
block_size
=
block_size
self
.
gpu_memory_utilization
=
gpu_memory_utilization
self
.
swap_space_bytes
=
swap_space
*
_GB
self
.
forced_num_gpu_blocks
=
forced_num_gpu_blocks
self
.
cache_dtype
=
cache_dtype
self
.
sliding_window
=
sliding_window
self
.
enable_prefix_caching
=
enable_prefix_caching
...
...
@@ -528,6 +532,7 @@ class SchedulerConfig:
and generated text).
delay_factor: Apply a delay (of delay factor multiplied by previous
prompt latency) before scheduling next prompt.
use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
"""
def
__init__
(
...
...
@@ -535,6 +540,7 @@ class SchedulerConfig:
max_num_batched_tokens
:
Optional
[
int
],
max_num_seqs
:
int
,
max_model_len
:
int
,
use_v2_block_manager
:
bool
=
False
,
delay_factor
:
float
=
0.0
,
)
->
None
:
if
max_num_batched_tokens
is
not
None
:
...
...
@@ -546,6 +552,7 @@ class SchedulerConfig:
self
.
max_num_seqs
=
max_num_seqs
self
.
max_model_len
=
max_model_len
self
.
delay_factor
=
delay_factor
self
.
use_v2_block_manager
=
use_v2_block_manager
self
.
_verify_args
()
def
_verify_args
(
self
)
->
None
:
...
...
vllm/core/block/block_table.py
0 → 100644
View file @
14ccd94c
from
typing
import
List
,
Optional
from
vllm.core.block.interfaces
import
Block
,
DeviceAwareBlockAllocator
from
vllm.utils
import
Device
,
cdiv
,
chunk_list
class
BlockTable
:
"""A class to manage blocks for a specific sequence.
The BlockTable maps a sequence of tokens to a list of blocks, where each
block represents a contiguous memory allocation for a portion of the
sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is
responsible for allocating and freeing memory for the blocks.
Args:
block_size (int): The maximum number of tokens that can be stored in a
single block.
block_allocator (DeviceAwareBlockAllocator): The block allocator used to
manage memory for the blocks.
_blocks (Optional[List[Block]], optional): An optional list of existing
blocks to initialize the BlockTable with. If not provided, an empty
BlockTable is created.
Attributes:
_block_size (int): The maximum number of tokens that can be stored in a
single block.
_allocator (DeviceAwareBlockAllocator): The block allocator used to
manage memory for the blocks.
_blocks (Optional[List[Block]]): The list of blocks managed by this
BlockTable.
_num_full_slots (int): The number of tokens currently stored in the
blocks.
"""
def
__init__
(
self
,
block_size
:
int
,
block_allocator
:
DeviceAwareBlockAllocator
,
_blocks
:
Optional
[
List
[
Block
]]
=
None
,
):
self
.
_block_size
=
block_size
self
.
_allocator
=
block_allocator
self
.
_blocks
:
Optional
[
List
[
Block
]]
=
_blocks
# Use helper method instead of directly calculating, as blocks
# may not be allocated.
self
.
_num_full_slots
=
len
(
self
.
_get_all_token_ids
())
@
staticmethod
def
get_num_required_blocks
(
token_ids
:
List
[
int
],
block_size
:
int
)
->
int
:
"""Calculates the minimum number of blocks required to store a given
sequence of token IDs.
This assumes worst-case scenario, where every block requires a new
allocation (e.g. ignoring prefix caching).
Args:
token_ids (List[int]): The sequence of token IDs to be stored.
block_size (int): The maximum number of tokens that can be stored in
a single block.
Returns:
int: The minimum number of blocks required to store the given
sequence of token IDs.
"""
return
cdiv
(
len
(
token_ids
),
block_size
)
def
allocate
(
self
,
token_ids
:
List
[
int
],
device
:
Device
=
Device
.
GPU
)
->
None
:
"""Allocates memory blocks for storing the given sequence of token IDs.
This method allocates the required number of blocks to store the given
sequence of token IDs.
Args:
token_ids (List[int]): The sequence of token IDs to be stored.
device (Device, optional): The device on which the blocks should be
allocated. Defaults to Device.GPU.
"""
assert
not
self
.
_is_allocated
assert
token_ids
self
.
_blocks
=
self
.
_allocate_blocks_for_token_ids
(
prev_block
=
None
,
token_ids
=
token_ids
,
device
=
device
)
self
.
_num_full_slots
=
len
(
token_ids
)
def
append_token_ids
(
self
,
token_ids
:
List
[
int
])
->
None
:
"""Appends a sequence of token IDs to the existing blocks in the
BlockTable.
This method appends the given sequence of token IDs to the existing
blocks in the BlockTable. If there is not enough space in the existing
blocks, new blocks are allocated using the `ensure_num_empty_slots`
method to accommodate the additional tokens.
The token IDs are divided into chunks of size `block_size` (except for
the first chunk, which may be smaller), and each chunk is appended to a
separate block.
Args:
token_ids (List[int]): The sequence of token IDs to be appended.
"""
assert
self
.
_is_allocated
self
.
ensure_num_empty_slots
(
num_empty_slots
=
len
(
token_ids
))
blocks
=
self
.
_blocks
[
self
.
_num_full_slots
//
self
.
_block_size
:]
first_chunk_size
=
self
.
_block_size
-
(
self
.
_num_full_slots
%
self
.
_block_size
)
token_blocks
=
[
token_ids
[:
first_chunk_size
]]
+
chunk_list
(
token_ids
[
first_chunk_size
:],
self
.
_block_size
)
for
block
,
token_block
in
zip
(
blocks
,
token_blocks
):
block
.
append_token_ids
(
token_block
)
self
.
_num_full_slots
+=
len
(
token_ids
)
def
ensure_num_empty_slots
(
self
,
num_empty_slots
:
int
)
->
None
:
"""Ensures that the BlockTable has at least the specified number of
empty slots available.
This method checks if the BlockTable has enough empty slots (i.e.,
available space) to accommodate the requested number of tokens. If not,
it allocates additional blocks on the GPU to ensure that the required
number of empty slots is available.
Args:
num_empty_slots (int): The minimum number of empty slots required.
"""
# Currently the block table only supports
# appending tokens to GPU blocks.
device
=
Device
.
GPU
assert
self
.
_is_allocated
if
self
.
_num_empty_slots
>=
num_empty_slots
:
return
slots_to_allocate
=
num_empty_slots
-
self
.
_num_empty_slots
blocks_to_allocate
=
cdiv
(
slots_to_allocate
,
self
.
_block_size
)
for
_
in
range
(
blocks_to_allocate
):
self
.
_blocks
.
append
(
self
.
_allocator
.
allocate_mutable
(
prev_block
=
self
.
_blocks
[
-
1
],
device
=
device
))
def
fork
(
self
)
->
"BlockTable"
:
"""Creates a new BlockTable instance with a copy of the blocks from the
current instance.
This method creates a new BlockTable instance with the same block size,
block allocator, and a copy of the blocks from the current instance. The
new BlockTable has its own independent set of blocks, but shares the
same underlying memory allocation with the original BlockTable.
Returns:
BlockTable: A new BlockTable instance with a copy of the blocks from
the current instance.
"""
assert
self
.
_is_allocated
forked_blocks
=
self
.
_allocator
.
fork
(
self
.
_blocks
[
-
1
])
return
BlockTable
(
block_size
=
self
.
_block_size
,
block_allocator
=
self
.
_allocator
,
_blocks
=
forked_blocks
,
)
def
free
(
self
)
->
None
:
"""Frees the memory occupied by the blocks in the BlockTable.
This method iterates over all the blocks in the `_blocks` list and calls
the `free` method of the `_allocator` object to release the memory
occupied by each block. After freeing all the blocks, the `_blocks` list
is set to `None`.
"""
assert
self
.
_is_allocated
for
block
in
self
.
_blocks
:
self
.
_allocator
.
free
(
block
)
self
.
_blocks
=
None
@
property
def
physical_block_ids
(
self
)
->
List
[
int
]:
"""Returns a list of physical block indices for the blocks in the
BlockTable.
This property returns a list of integers, where each integer represents
the physical block index of a corresponding block in the `_blocks` list.
The physical block index is a unique identifier for the memory location
occupied by the block.
Returns:
List[int]: A list of physical block indices for the blocks in the
BlockTable.
"""
assert
self
.
_is_allocated
return
[
block
.
block_id
for
block
in
self
.
_blocks
]
def
_allocate_blocks_for_token_ids
(
self
,
prev_block
:
Optional
[
Block
],
token_ids
:
List
[
int
],
device
:
Device
)
->
List
[
Block
]:
blocks
=
[]
for
block_token_ids
in
chunk_list
(
token_ids
,
self
.
_block_size
):
if
len
(
block_token_ids
)
==
self
.
_block_size
:
# If the block is full, create an immutable block.
prev_block
=
self
.
_allocator
.
allocate_immutable
(
prev_block
,
token_ids
=
block_token_ids
,
device
=
device
)
else
:
# Else, partially fill a mutable block with token ids.
prev_block
=
self
.
_allocator
.
allocate_mutable
(
prev_block
=
prev_block
,
device
=
device
)
prev_block
.
append_token_ids
(
block_token_ids
)
blocks
.
append
(
prev_block
)
return
blocks
def
_get_all_token_ids
(
self
)
->
List
[
int
]:
# NOTE: This function is O(seq_len); use sparingly.
token_ids
=
[]
if
not
self
.
_is_allocated
:
return
token_ids
for
block
in
self
.
_blocks
:
token_ids
.
extend
(
block
.
token_ids
)
return
token_ids
@
property
def
_is_allocated
(
self
)
->
bool
:
return
self
.
_blocks
is
not
None
@
property
def
_num_empty_slots
(
self
)
->
int
:
assert
self
.
_is_allocated
return
len
(
self
.
_blocks
)
*
self
.
_block_size
-
self
.
_num_full_slots
@property
def num_full_slots(self) -> int:
    """Number of tokens currently stored in the BlockTable.

    Returns:
        int: The value of the ``_num_full_slots`` counter.
    """
    stored = self._num_full_slots
    return stored
vllm/core/block/common.py
0 → 100644
View file @
14ccd94c
from
collections
import
defaultdict
from
typing
import
Dict
,
Iterable
,
List
,
Optional
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
# Type alias for a block index; used as the key type of the refcount table.
BlockId = int
# Type alias for a reference-count value.
RefCount = int
class RefCounter:
    """Reference counts for a fixed set of block indices.

    Every known block index starts with a count of zero. Counts are changed
    one step at a time through `incr` and `decr`, and read through `get`.
    All three methods assert that the index was part of the initial set.

    Args:
        all_block_indices (Iterable[BlockId]): The block indices to track.
            Duplicates are collapsed.
    """

    def __init__(self, all_block_indices: Iterable[BlockId]):
        unique_indices = set(all_block_indices)
        self._refcounts: Dict[BlockId, RefCount] = dict.fromkeys(
            unique_indices, 0)

    def incr(self, block_id: BlockId) -> RefCount:
        """Raise the count of ``block_id`` by one and return the new count."""
        assert block_id in self._refcounts
        current = self._refcounts[block_id]
        assert current >= 0

        updated = current + 1
        self._refcounts[block_id] = updated
        return updated

    def decr(self, block_id: BlockId) -> RefCount:
        """Lower the count of ``block_id`` by one and return the new count.

        Asserts that the count is positive before decrementing.
        """
        assert block_id in self._refcounts
        current = self._refcounts[block_id]
        assert current > 0

        updated = current - 1
        self._refcounts[block_id] = updated
        return updated

    def get(self, block_id: BlockId) -> RefCount:
        """Return the current count of ``block_id``."""
        assert block_id in self._refcounts
        return self._refcounts[block_id]

    def as_readonly(self) -> "ReadOnlyRefCounter":
        """Return a `ReadOnlyRefCounter` wrapping this counter."""
        return ReadOnlyRefCounter(self)
class ReadOnlyRefCounter:
    """Read-only wrapper around a RefCounter.

    `get` is forwarded to the wrapped counter; `incr` and `decr` always raise
    ValueError, so holders of this object cannot change any count.

    Args:
        refcounter (RefCounter): The counter to expose for reading.
    """

    def __init__(self, refcounter: RefCounter):
        self._refcounter = refcounter

    def incr(self, block_id: BlockId) -> RefCount:
        """Always raises ValueError; increments are not permitted."""
        raise ValueError("Incr not allowed")

    def decr(self, block_id: BlockId) -> RefCount:
        """Always raises ValueError; decrements are not permitted."""
        raise ValueError("Decr not allowed")

    def get(self, block_id: BlockId) -> RefCount:
        """Return the wrapped counter's count for ``block_id``."""
        wrapped = self._refcounter
        return wrapped.get(block_id)
class CopyOnWriteTracker:
    """Performs copy-on-write for shared blocks and records each copy.

    The tracker keeps a mapping from a source block index to the list of
    destination block indices that were allocated for it. It reads reference
    counts from the given refcounter and uses the given allocator to free the
    shared block and allocate its replacement.

    Args:
        refcounter (RefCounter): Source of block reference counts.
        allocator (BlockAllocator): Allocator used to free and allocate
            blocks.
    """

    def __init__(
        self,
        refcounter: RefCounter,
        allocator: BlockAllocator,
    ):
        self._copy_on_writes = defaultdict(list)
        self._refcounter = refcounter
        self._allocator = allocator

    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        """Give ``block`` a private block index if its current one is shared.

        A block whose reference count is greater than one is shared. In that
        case the block is freed through the allocator, a new mutable block is
        allocated with the same ``prev_block``, and the (source, destination)
        pair is recorded.

        Args:
            block (Block): The block to check for copy-on-write.

        Returns:
            Optional[BlockId]: None if the block has no block index; the new
                block index if a copy-on-write happened; otherwise the
                block's existing index.
        """
        src_id = block.block_id
        if src_id is None:
            return None

        refs = self._refcounter.get(src_id)
        assert refs != 0
        if refs <= 1:
            # Sole owner: nothing to copy.
            return src_id

        # Drop this block's reference to the shared index first, then request
        # a fresh mutable block linked to the same predecessor.
        self._allocator.free(block)
        replacement = self._allocator.allocate_mutable(
            prev_block=block.prev_block)
        dst_id = replacement.block_id

        # Track src/dst copy.
        self._copy_on_writes[src_id].append(dst_id)
        return dst_id

    def clear_cows(self) -> Dict[BlockId, List[BlockId]]:
        """Return the recorded copies and reset the record.

        Returns:
            Dict[BlockId, List[BlockId]]: Source block index mapped to the
                destination block indices recorded since the last call.
        """
        snapshot = dict(self._copy_on_writes)
        self._copy_on_writes.clear()
        return snapshot
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
    """Retrieves all the blocks in a sequence starting from the last block.

    The chain is followed through each block's ``prev_block`` attribute until
    a block without a predecessor is reached. The walk is iterative, so the
    length of the chain is not limited by the interpreter's recursion limit.

    Args:
        last_block (Block): The last block in the sequence. If None, an
            empty list is returned.

    Returns:
        List[Block]: A list of all the blocks in the sequence, in the order
            they appear (first block first, ``last_block`` last).
    """
    all_blocks = []
    block = last_block
    while block is not None:
        all_blocks.append(block)
        block = block.prev_block

    # The walk collected blocks last-to-first; callers expect first-to-last.
    all_blocks.reverse()
    return all_blocks
vllm/core/block/cpu_gpu_block_allocator.py
0 → 100644
View file @
14ccd94c
from
typing
import
Dict
,
List
,
Optional
from
vllm.core.block.interfaces
import
(
Block
,
BlockAllocator
,
DeviceAwareBlockAllocator
)
from
vllm.core.block.naive_block
import
NaiveBlock
,
NaiveBlockAllocator
from
vllm.core.block.prefix_caching_block
import
PrefixCachingBlockAllocator
from
vllm.utils
import
Device
class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
    """A block allocator that can allocate blocks on both CPU and GPU memory.

    This class implements the `DeviceAwareBlockAllocator` interface by holding
    one `BlockAllocator` per device. Requests that carry a ``device`` argument
    are routed to that device's allocator; requests that carry a block
    (`free`, `fork`) are routed to the allocator that owns the block's id.
    """

    @staticmethod
    def create(
        allocator_type: str,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        block_size: int,
    ) -> DeviceAwareBlockAllocator:
        """Creates a CpuGpuBlockAllocator instance with the specified
        configuration.

        Args:
            allocator_type (str): The type of block allocator to use for CPU
                and GPU blocks. Currently supported values are "naive" and
                "prefix_caching".
            num_gpu_blocks (int): The number of blocks to allocate for GPU
                memory.
            num_cpu_blocks (int): The number of blocks to allocate for CPU
                memory.
            block_size (int): The size of each block in number of tokens.

        Returns:
            DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the
                specified configuration.

        Raises:
            ValueError: If ``allocator_type`` is not a supported value.

        Notes:
            - The block IDs are assigned contiguously, with GPU block IDs coming
                before CPU block IDs.
        """
        block_ids = list(range(num_gpu_blocks + num_cpu_blocks))
        gpu_block_ids = block_ids[:num_gpu_blocks]
        cpu_block_ids = block_ids[num_gpu_blocks:]

        if allocator_type == "naive":
            gpu_allocator = NaiveBlockAllocator(
                create_block=NaiveBlock,
                num_blocks=num_gpu_blocks,
                block_size=block_size,
                block_ids=gpu_block_ids,
            )

            cpu_allocator = NaiveBlockAllocator(
                create_block=NaiveBlock,
                num_blocks=num_cpu_blocks,
                block_size=block_size,
                block_ids=cpu_block_ids,
            )
        elif allocator_type == "prefix_caching":
            gpu_allocator = PrefixCachingBlockAllocator(
                num_blocks=num_gpu_blocks,
                block_size=block_size,
                block_ids=gpu_block_ids,
            )

            cpu_allocator = PrefixCachingBlockAllocator(
                num_blocks=num_cpu_blocks,
                block_size=block_size,
                block_ids=cpu_block_ids,
            )
        else:
            raise ValueError(f"Unknown allocator type {allocator_type=}")

        return CpuGpuBlockAllocator(
            cpu_block_allocator=cpu_allocator,
            gpu_block_allocator=gpu_allocator,
        )

    def __init__(
        self,
        cpu_block_allocator: BlockAllocator,
        gpu_block_allocator: BlockAllocator,
    ):
        """Wrap one allocator per device.

        Args:
            cpu_block_allocator: Allocator serving `Device.CPU`.
            gpu_block_allocator: Allocator serving `Device.GPU`.

        The two allocators must not share any block id, because block ids are
        used to find the owning allocator in `free` and `fork`.
        """
        assert not (
            cpu_block_allocator.all_block_ids
            & gpu_block_allocator.all_block_ids
        ), "cpu and gpu block allocators can't have intersection of block ids"

        self._allocators = {
            Device.CPU: cpu_block_allocator,
            Device.GPU: gpu_block_allocator,
        }

        # Reverse lookup: block id -> allocator that owns it.
        self._block_ids_to_allocator = {}
        for allocator in self._allocators.values():
            for block_id in allocator.all_block_ids:
                self._block_ids_to_allocator[block_id] = allocator

    def allocate_mutable(self, prev_block: Optional[Block],
                         device: Device) -> Block:
        """Allocates a new mutable block on the specified device.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence.
                Used for prefix hashing.
            device (Device): The device on which to allocate the new block.

        Returns:
            Block: The newly allocated mutable block.
        """
        return self._allocators[device].allocate_mutable(prev_block)

    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int], device: Device) -> Block:
        """Allocates a new immutable block with the provided token IDs on the
        specified device.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence.
                Used for prefix hashing.
            token_ids (List[int]): The list of token IDs to be stored in the new
                block.
            device (Device): The device on which to allocate the new block.

        Returns:
            Block: The newly allocated immutable block containing the provided
                token IDs.
        """
        return self._allocators[device].allocate_immutable(
            prev_block, token_ids)

    def free(self, block: Block) -> None:
        """Frees the memory occupied by the given block.

        The owning allocator is found from ``block.block_id``.

        Args:
            block (Block): The block to be freed.
        """
        allocator = self._block_ids_to_allocator[block.block_id]
        return allocator.free(block)

    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
        memory as the original sequence.

        The owning allocator is found from ``last_block.block_id``.

        Args:
            last_block (Block): The last block in the original sequence.

        Returns:
            List[Block]: A new list of blocks that shares the same memory as the
                original sequence.
        """
        allocator = self._block_ids_to_allocator[last_block.block_id]
        return allocator.fork(last_block)

    def get_num_free_blocks(self, device: Device) -> int:
        """Returns the number of free blocks available on the specified device.

        Args:
            device (Device): The device for which to query the number of free
                blocks.

        Returns:
            int: The number of free blocks available on the specified device.
        """
        return self._allocators[device].get_num_free_blocks()

    def clear_copy_on_writes(self) -> Dict[int, List[int]]:
        """Clears the copy-on-write (CoW) state and returns the mapping of
        source to destination block IDs.

        Returns:
            Dict[int, List[int]]: A dictionary mapping source block IDs to lists
                of destination block IDs.
        """
        # CoW only supported on GPU
        device = Device.GPU
        return self._allocators[device].clear_copy_on_writes()

    def mark_blocks_as_computed(self) -> None:
        """Forward to the GPU allocator's ``mark_blocks_as_computed``."""
        # Prefix caching only supported on GPU.
        device = Device.GPU
        return self._allocators[device].mark_blocks_as_computed()

    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Forward to the GPU allocator's ``get_common_computed_block_ids``."""
        # Prefix caching only supported on GPU.
        device = Device.GPU
        return self._allocators[device].get_common_computed_block_ids(
            seq_block_ids)

    @property
    def all_block_ids(self) -> frozenset[int]:
        """All block ids managed by either wrapped allocator.

        Exposed as a property, matching the `BlockAllocator` interface and
        the way `__init__` reads ``all_block_ids`` from other allocators.
        """
        return frozenset(self._block_ids_to_allocator.keys())
vllm/core/block/interfaces.py
0 → 100644
View file @
14ccd94c
from
abc
import
ABC
,
abstractmethod
,
abstractproperty
from
typing
import
Dict
,
List
,
Optional
,
Protocol
from
vllm.utils
import
Device
class Block(ABC):
    """Abstract interface of a block of token ids within a sequence.

    Concrete blocks expose their token ids, their physical block id, their
    remaining capacity and a link to the previous block in the sequence.
    Allocators in this package also assign to ``block_id`` (for example
    setting it to None when a block is freed), so concrete blocks are
    expected to provide a setter for it.
    """

    @abstractmethod
    def append_token_ids(self, token_ids: List[int]) -> None:
        """Append ``token_ids`` to this block."""
        pass

    # NOTE: ``abc.abstractproperty`` is deprecated since Python 3.3; stacking
    # ``@property`` on ``@abstractmethod`` is the supported equivalent.
    @property
    @abstractmethod
    def block_id(self) -> Optional[int]:
        """Physical block id, or None when the block has no allocation."""
        pass

    @property
    @abstractmethod
    def token_ids(self) -> List[int]:
        """Token ids currently stored in this block."""
        pass

    @property
    @abstractmethod
    def num_empty_slots(self) -> int:
        """Number of token slots still available in this block."""
        pass

    @property
    @abstractmethod
    def is_full(self) -> bool:
        """Whether this block has no empty slots left."""
        pass

    @property
    @abstractmethod
    def prev_block(self) -> Optional["Block"]:
        """Previous block in the sequence, or None for the first block."""
        pass

    class Factory(Protocol):
        """Callable protocol for constructing a Block."""

        @abstractmethod
        def __call__(
            self,
            prev_block: Optional["Block"],
            token_ids: List[int],
            block_size: int,
            allocator: "BlockAllocator",
            block_id: Optional[int] = None,
        ) -> "Block":
            """Create a block with the given link, contents and allocation."""
            pass
class BlockAllocator(ABC):
    """Abstract interface of an allocator that hands out and frees Blocks."""

    @abstractmethod
    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
        """Allocate a block that can still be appended to."""
        pass

    @abstractmethod
    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int]) -> Block:
        """Allocate a block holding ``token_ids``."""
        pass

    @abstractmethod
    def free(self, block: Block) -> None:
        """Release ``block``."""
        pass

    @abstractmethod
    def fork(self, last_block: Block) -> List[Block]:
        """Create a new block sequence sharing memory with the sequence that
        ends in ``last_block``."""
        pass

    @abstractmethod
    def get_num_free_blocks(self) -> int:
        """Number of blocks currently available for allocation."""
        pass

    # NOTE: ``abc.abstractproperty`` is deprecated since Python 3.3; stacking
    # ``@property`` on ``@abstractmethod`` is the supported equivalent.
    @property
    @abstractmethod
    def all_block_ids(self) -> frozenset[int]:
        """All block ids managed by this allocator."""
        pass

    @abstractmethod
    def clear_copy_on_writes(self) -> Dict[int, List[int]]:
        """Return the recorded copy-on-write source -> destinations mapping
        and clear it."""
        pass

    @abstractmethod
    def mark_blocks_as_computed(self) -> None:
        """Mark blocks as computed (used by prefix caching)."""
        pass

    @abstractmethod
    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Return the computed block ids common to the given sequences."""
        pass

    class NoFreeBlocksError(ValueError):
        """Raised by allocators when no block is available for allocation."""
        pass
class DeviceAwareBlockAllocator(BlockAllocator):
    """BlockAllocator variant whose allocation and free-count queries take an
    explicit ``device`` argument."""

    @abstractmethod
    def allocate_mutable(self, prev_block: Optional[Block],
                         device: Device) -> Block:
        """Allocate a mutable block on ``device``."""

    @abstractmethod
    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int], device: Device) -> Block:
        """Allocate a block holding ``token_ids`` on ``device``."""

    @abstractmethod
    def get_num_free_blocks(self, device: Device) -> int:
        """Number of free blocks on ``device``."""
vllm/core/block/naive_block.py
0 → 100644
View file @
14ccd94c
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Set
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
RefCounter
,
get_all_blocks_recursively
)
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
# Type alias for a physical block index.
BlockId = int
# Type alias for a reference-count value.
Refcount = int
class NaiveBlockAllocator(BlockAllocator):
    """A simple block allocator that manages blocks of memory without prefix
    caching.

    Free block ids are kept in a set and every id has a reference count. An
    id leaves the free set when it is allocated and is put back when its
    reference count returns to zero.

    Args:
        create_block (Block.Factory): A factory function for creating new
            blocks. This is used when a NaiveBlockAllocator is composed within
            a prefix caching allocator -- the naive block allocator must
            construct prefix caching blocks (but shouldn't know anything else
            about them).
        num_blocks (int): The total number of blocks to manage.
        block_size (int): The size of each block in tokens.
        block_ids (Optional[Iterable[int]], optional): An optional iterable of
            block IDs. If not provided, block IDs will be assigned sequentially
            from 0 to num_blocks - 1.
    """

    def __init__(
        self,
        create_block: Block.Factory,
        num_blocks: int,
        block_size: int,
        block_ids: Optional[Iterable[int]] = None,
    ):
        if block_ids is None:
            block_ids = range(num_blocks)

        # ``block_ids`` is only required to be an Iterable, so it may be a
        # one-shot iterator. Materialise it once before building the two
        # collections below; iterating it twice would leave the second empty.
        block_ids = list(block_ids)

        self._free_block_indices: Set[BlockId] = set(block_ids)
        self._all_block_indices = frozenset(block_ids)
        assert len(self._all_block_indices) == num_blocks

        self._refcounter = RefCounter(
            all_block_indices=self._free_block_indices)
        self._create_block = create_block
        self._block_size = block_size

        # The tracker only needs to read refcounts; all changes go through
        # this allocator.
        self._cow_tracker = CopyOnWriteTracker(
            refcounter=self._refcounter.as_readonly(),
            allocator=self,
        )

    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int]) -> Block:
        """Allocates a new immutable block with the given token IDs, linked to
        the previous block.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence. If
                None, then the block to be allocated is the first block in the
                sequence.
            token_ids (List[int]): The token IDs to be stored in the new block.

        Returns:
            Block: The newly allocated immutable block.
        """
        block = self.allocate_mutable(prev_block=prev_block)
        block.append_token_ids(token_ids)
        return block

    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
        """Allocates a new mutable block, linked to the previous block.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence. If
                None, then the block to be allocated is the first block in the
                sequence.

        Returns:
            Block: The newly allocated mutable block.

        Raises:
            BlockAllocator.NoFreeBlocksError: If no block id is free.
        """
        block_id = self._allocate_new_block_id()
        return self._create_block(
            prev_block=prev_block,
            token_ids=[],
            block_id=block_id,
            block_size=self._block_size,
            allocator=self,
        )

    def free(self, block: Block) -> None:
        """Drop one reference to the block's id and detach the block from it.

        Args:
            block (Block): The block to free. Its ``block_id`` is set to None.
        """
        self._free_block_id(block.block_id)

        # Mark the block as having no allocation.
        block.block_id = None

    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
        memory as the original sequence.

        Args:
            last_block (Block): The last block in the original sequence.

        Returns:
            List[Block]: The new sequence of blocks that shares the same memory
                as the original sequence.
        """
        source_blocks = get_all_blocks_recursively(last_block)

        forked_blocks = []
        prev_block = None
        for block in source_blocks:

            # Increment refcount for each block.
            refcount = self._refcounter.incr(block.block_id)
            assert refcount != 1, "can't fork free'd block"

            forked_blocks.append(
                self._create_block(
                    prev_block=prev_block,
                    token_ids=block.token_ids,
                    block_id=block.block_id,
                    block_size=self._block_size,
                    allocator=self,
                ))
            prev_block = forked_blocks[-1]

        return forked_blocks

    def get_num_free_blocks(self) -> int:
        """Number of block ids currently in the free set."""
        return len(self._free_block_indices)

    def _allocate_new_block_id(self) -> BlockId:
        """Take one id out of the free set and give it a refcount of one.

        Raises:
            BlockAllocator.NoFreeBlocksError: If the free set is empty.
        """
        if not self._free_block_indices:
            raise BlockAllocator.NoFreeBlocksError()

        block_id = next(iter(self._free_block_indices))
        self._refcounter.incr(block_id)
        self._free_block_indices.remove(block_id)
        return block_id

    def _free_block_id(self, block_id: BlockId) -> None:
        """Decrement the refcount of ``block_id``; return it to the free set
        once the count reaches zero."""
        refcount = self._refcounter.decr(block_id)
        if refcount == 0:
            self._free_block_indices.add(block_id)

    @property
    def refcounter(self):
        """The RefCounter used by this allocator."""
        return self._refcounter

    @property
    def all_block_ids(self):
        """Frozen set of every block id managed by this allocator."""
        return self._all_block_indices

    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        """Performs a copy-on-write operation on the given block if it is not
        appendable.

        Args:
            block (Block): The block to check for copy-on-write.

        Returns:
            Optional[BlockId]: The block index of the new block if a copy-on
                -write operation was performed, or the original block index if
                no copy-on-write was necessary.
        """
        return self._cow_tracker.cow_block_if_not_appendable(block)

    def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]:
        """Returns the copy-on-write source->destination mapping and clears it.

        Returns:
            Dict[BlockId, List[BlockId]]: A dictionary mapping source
                block indices to lists of destination block indices.
        """
        return self._cow_tracker.clear_cows()

    def mark_blocks_as_computed(self) -> None:
        """Mark blocks as computed, used in prefix caching.

        Since the naive allocator does not implement prefix caching, we do
        nothing.
        """
        pass

    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Determine blocks that can be skipped in prefill.

        Since the naive allocator does not support prefix caching, always return
        an empty list.
        """
        return []
class NaiveBlock(Block):
    """An implementation of the Block class that does not support prefix
    caching.

    The NaiveBlock class represents a block of token IDs with a fixed size. It
    provides methods for appending token IDs to the block and manages copy-on
    -write operations when necessary.

    Args:
        prev_block (Optional[Block]): The previous block in the sequence, or
            None for the first block.
        token_ids (List[int]): The initial token IDs to be stored in the block.
        block_size (int): The maximum number of token IDs that can be stored in
            the block.
        allocator (BlockAllocator): The block allocator associated with this
            block. It must provide ``cow_block_if_not_appendable``, which is
            called on every append to an allocated block.
        block_id (Optional[int], optional): The physical block index
            of this block. Defaults to None, which means no allocation has been
            made.
        _cow_target (Optional[Block], optional): The copy-on-write target block.
            If not provided, it defaults to self.
    """

    def __init__(self,
                 prev_block: Optional[Block],
                 token_ids: List[int],
                 block_size: int,
                 allocator: BlockAllocator,
                 block_id: Optional[int] = None,
                 _cow_target: Optional[Block] = None):
        self._token_ids = []
        self._block_size = block_size
        self._prev_block = prev_block
        self._block_id = block_id
        self._allocator = allocator
        # The block handed to the allocator for copy-on-write; a wrapping
        # block passes itself here so the allocator operates on the wrapper.
        self._cow_target = _cow_target if _cow_target is not None else self

        self._append_token_ids_no_cow(token_ids)

    def append_token_ids(self, token_ids: List[int]) -> None:
        """Appends the given token IDs to the block, instructing the allocator
        to perform a copy-on-write if necessary.

        Args:
            token_ids (List[int]): The token IDs to be appended to the block.
        """
        self._append_token_ids_no_cow(token_ids)

        # Only allocated blocks can be shared, so only they need the check.
        if self._block_id is not None:
            self._block_id = (self._allocator.cow_block_if_not_appendable(
                self._cow_target))

    def _append_token_ids_no_cow(self, token_ids: List[int]) -> None:
        """Append ``token_ids`` without any copy-on-write handling.

        Asserts that the block has room for all of ``token_ids``.
        """
        assert self.num_empty_slots >= len(token_ids)
        self._token_ids.extend(token_ids)

    @property
    def block_id(self) -> Optional[int]:
        """Physical block index, or None when there is no allocation."""
        return self._block_id

    @block_id.setter
    def block_id(self, value: Optional[int]) -> None:
        self._block_id = value

    @property
    def is_full(self) -> bool:
        """True when the block has no empty slots."""
        return self.num_empty_slots == 0

    @property
    def num_empty_slots(self) -> int:
        """Block size minus the number of stored token ids."""
        return self._block_size - len(self._token_ids)

    @property
    def token_ids(self) -> List[int]:
        """The list of token ids stored in this block (not a copy)."""
        return self._token_ids

    # Exposed as a property like the other accessors: code that wraps a
    # NaiveBlock reads ``block.block_size`` as an attribute and expects an int.
    @property
    def block_size(self) -> int:
        """Maximum number of token ids this block can hold."""
        return self._block_size

    @property
    def prev_block(self) -> Optional["Block"]:
        """Previous block in the sequence, or None."""
        return self._prev_block
vllm/core/block/prefix_caching_block.py
0 → 100644
View file @
14ccd94c
"""Token blocks."""
from
itertools
import
takewhile
from
os.path
import
commonprefix
from
typing
import
Dict
,
Iterable
,
List
,
Optional
from
vllm.core.block.common
import
(
CopyOnWriteTracker
,
get_all_blocks_recursively
)
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
from
vllm.core.block.naive_block
import
NaiveBlock
,
NaiveBlockAllocator
# Type alias for a block content hash; used as the key of the cached-block maps.
PrefixHash = int
# Type alias for a physical block index.
BlockId = int
class PrefixCachingBlockAllocator(BlockAllocator):
    """A block allocator that implements prefix caching.

    The PrefixCachingBlockAllocator maintains a cache of blocks based on their
    content hash. It reuses blocks with the same content hash to avoid redundant
    memory allocation. The allocator also supports copy-on-write operations.

    Args:
        num_blocks (int): The total number of blocks to manage.
        block_size (int): The size of each block in tokens.
        block_ids(Optional[Iterable[int]], optional): An optional iterable of
            block IDs. If not provided, block IDs will be assigned sequentially
            from 0 to num_blocks - 1.
    """

    # TODO last access time / evictor integration

    def __init__(
        self,
        num_blocks: int,
        block_size: int,
        block_ids: Optional[Iterable[int]] = None,
    ):
        # A mapping of prefix hash to block index. All blocks which have a
        # prefix hash will be in this dict, even if they have refcount 0.
        self._cached_blocks: Dict[PrefixHash, BlockId] = {}

        # A mapping of prefix hash to block index. All blocks which have a
        # prefix hash AND refcount 0 will be in this dict. Thus, it is a subset
        # of self._cached_blocks.
        self._unused_cached_blocks: Dict[PrefixHash, BlockId] = {}

        # An allocator for blocks that do not have prefix hashes.
        self._hashless_allocator = NaiveBlockAllocator(
            create_block=self._create_block,
            num_blocks=num_blocks,
            block_size=block_size,
            block_ids=block_ids,
        )

        self._block_size = block_size

        # We share the refcounter between allocators. This allows us to promote
        # blocks originally allocated in the hashless allocator to immutable
        # blocks.
        self._refcounter = self._hashless_allocator.refcounter

        self._cow_tracker = CopyOnWriteTracker(
            refcounter=self._refcounter.as_readonly(),
            allocator=self,
        )

    # Implements Block.Factory.
    def _create_block(
        self,
        prev_block: Optional[Block],
        token_ids: List[int],
        block_size: int,
        allocator: BlockAllocator,
        block_id: Optional[int] = None,
    ) -> Block:
        """Build a PrefixCachingBlock bound to this allocator.

        The ``allocator`` argument is ignored: blocks are always bound to
        ``self``, even when the hashless allocator invokes this factory.
        """
        # Bind block to self.
        allocator = self

        return PrefixCachingBlock(
            prev_block=prev_block,
            token_ids=token_ids,
            block_size=block_size,
            block_id=block_id,
            prefix_caching_allocator=allocator,
        )

    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int]) -> Block:
        """Allocates an immutable block with the given token IDs, reusing cached
        blocks if possible.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence.
            token_ids (List[int]): The token IDs to be stored in the block.

        Returns:
            Block: The allocated immutable block.
        """
        assert_prefix_caching_block_or_none(prev_block)

        # Build an unallocated block first, only to obtain the content hash.
        block = self._create_block(
            prev_block=prev_block,
            token_ids=token_ids,
            block_size=self._block_size,
            allocator=self,
        )
        assert block.content_hash is not None

        cached_block_id = self._cached_blocks.get(block.content_hash, None)
        if cached_block_id is not None:
            # Cache hit: point the new block at the existing physical block.
            block.block_id = cached_block_id
            self._incr_refcount_cached_block(block.content_hash,
                                             block.block_id)
            return block

        block = self.allocate_mutable(prev_block)
        block.append_token_ids(token_ids)
        assert block.content_hash is not None
        # TODO computed bit

        return block

    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
        """Allocates a mutable block. If there are no free blocks, this will
        evict unused cached blocks.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence,
                or None for the first block.

        Returns:
            Block: The allocated mutable block.

        Raises:
            BlockAllocator.NoFreeBlocksError: If neither the hashless
                allocator nor the unused cached blocks can supply a block.
        """
        assert_prefix_caching_block_or_none(prev_block)

        try:
            return self._hashless_allocator.allocate_mutable(
                prev_block=prev_block)
        except BlockAllocator.NoFreeBlocksError:
            # We must check the unused cached blocks before raising OOM.
            pass

        if self._unused_cached_blocks:
            # TODO policy for selecting block to remove
            content_hash_to_evict = next(iter(self._unused_cached_blocks))

            # Clear content hash mapping; the block will be overwritten.
            del self._cached_blocks[content_hash_to_evict]

            block_id = self._unused_cached_blocks.pop(content_hash_to_evict)
            refcount = self._refcounter.incr(block_id)
            assert refcount == 1
            block = self._create_block(
                prev_block=prev_block,
                token_ids=[],
                block_size=self._block_size,
                allocator=self,
                block_id=block_id,
            )
            assert block.content_hash is None
            return block

        # No block available in hashless allocator, nor in unused cache blocks.
        raise BlockAllocator.NoFreeBlocksError()

    def _incr_refcount_cached_block(self, content_hash: int,
                                    block_id: BlockId) -> None:
        """Increment the refcount of a cached block; if it was unused (count
        goes from 0 to 1), remove it from the unused cached blocks."""
        refcount = self._refcounter.incr(block_id)
        if refcount == 1:
            assert content_hash in self._unused_cached_blocks
            del self._unused_cached_blocks[content_hash]

    def free(self, block: Block) -> None:
        """Decrement the refcount of the block. If the decremented refcount is
        zero, store the block in the freelist.

        If the block has a content hash (meaning it is immutable), then we will
        keep the block around in case future allocations require it.
        """
        assert (block.block_id
                is not None), "freeing unallocated block is undefined"

        self._free_block_id_for_block(block.block_id, block)
        block.block_id = None

    def _free_block_id_for_block(self, block_id: BlockId,
                                 block: Block) -> None:
        """Release one reference to ``block_id`` on behalf of ``block``.

        Blocks without a content hash are freed through the hashless
        allocator. Blocks with a content hash have their refcount decremented
        and, at zero, are recorded in the unused cached blocks.
        """
        assert isinstance(block, PrefixCachingBlock)

        if block.content_hash is None:
            return self._hashless_allocator.free(block)

        refcount = self._refcounter.decr(block_id)

        # If no longer used, add the block to the unused cached blocks.
        if refcount == 0:
            assert block.content_hash not in self._unused_cached_blocks
            assert block.content_hash in self._cached_blocks
            self._unused_cached_blocks[block.content_hash] = block_id

    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
        memory as the original sequence.

        Args:
            last_block (Block): The last block in the original sequence.

        Returns:
            List[Block]: The new sequence of blocks that shares the same memory
                as the original sequence.
        """
        source_blocks = get_all_blocks_recursively(last_block)

        forked_blocks = []
        prev_block = None
        for block in source_blocks:
            refcount = self._refcounter.incr(block.block_id)
            assert refcount != 1, "can't fork free'd block"

            forked_blocks.append(
                self._create_block(
                    prev_block=prev_block,
                    token_ids=block.token_ids,
                    block_id=block.block_id,
                    block_size=self._block_size,
                    allocator=self,
                ))
            prev_block = forked_blocks[-1]

        return forked_blocks

    def get_num_free_blocks(self) -> int:
        """Number of blocks that an allocation could currently use."""
        # The number of free blocks is the number of hashless free blocks
        # plus the number of hashful blocks that are unused.
        return self._hashless_allocator.get_num_free_blocks() + len(
            self._unused_cached_blocks)

    @property
    def all_block_ids(self) -> frozenset[int]:
        """All block ids managed by the underlying hashless allocator."""
        return self._hashless_allocator.all_block_ids

    def promote_to_immutable_block(self,
                                   block: "PrefixCachingBlock") -> BlockId:
        """Once a mutable block is full, it can be promoted to an immutable
        block. This means that its content can be referenced by future blocks
        having the same prefix.

        Note that if we already have a cached block with the same content, we
        will replace the newly-promoted block's mapping with the existing cached
        block.

        Args:
            block (PrefixCachingBlock): The mutable block to be promoted.

        Returns:
            BlockId: Either the original block index, or the block index of
                the previously cached block matching the same content.
        """
        assert block.content_hash is not None
        assert block.block_id is not None
        assert self._refcounter.get(block.block_id) > 0

        # If the content hash does not have a corresponding cached block,
        # set this block as the cached block.
        if block.content_hash not in self._cached_blocks:
            self._cached_blocks[block.content_hash] = block.block_id
        else:
            # NOTE(review): if the duplicate's refcount reaches zero here, its
            # id is stored in _unused_cached_blocks under the shared hash
            # while _cached_blocks keeps pointing at the other id -- confirm
            # this is the intended bookkeeping.
            self._free_block_id_for_block(block.block_id, block)
            self._incr_refcount_cached_block(
                block.content_hash, self._cached_blocks[block.content_hash])

        return self._cached_blocks[block.content_hash]

    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        """Performs a copy-on-write operation on the given block if it is not
        appendable.

        Args:
            block (Block): The block to check for copy-on-write.

        Returns:
            Optional[BlockId]: The block index of the new block if a copy-on
                -write operation was performed, or the original block index if
                no copy-on-write was necessary.
        """
        return self._cow_tracker.cow_block_if_not_appendable(block)

    def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]:
        """Returns the copy-on-write source->destination mapping and clears it.

        Returns:
            Dict[BlockId, List[BlockId]]: A dictionary mapping source
                block indices to lists of destination block indices.
        """
        return self._cow_tracker.clear_cows()

    def mark_blocks_as_computed(self) -> None:
        """Mark blocks as computed, used in prefix caching."""
        # TODO Track computed blocks.
        pass

    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Return the block ids that are common for a given sequence group.

        Used in prefill (can skip prefill of some blocks).

        Args:
            seq_block_ids (List[List[int]]): One list of block ids per
                sequence.

        Returns:
            List[int]: The longest common prefix of the sequences' leading
                computed block ids. Sequences with no leading computed block
                are left out of the comparison; if none remain, the result is
                an empty list. Until computed blocks are tracked (see TODO),
                no block counts as computed and the result is always empty.
        """

        # TODO: Track computed blocks.
        def computed(block_id: int) -> bool:
            return False

        # NOTE We exclude the last block to avoid the case where the entire
        # prompt is cached. This would cause erroneous behavior in model
        # runner.
        # takewhile() yields a lazy iterator; materialise each prefix so it
        # can be tested for emptiness and compared element by element.
        ids_list = [
            list(takewhile(computed, seq[:-1])) for seq in seq_block_ids
        ]
        non_empty = [ids for ids in ids_list if ids]

        # commonprefix() returns '' for an empty input, which is not a list.
        if not non_empty:
            return []
        return commonprefix(non_empty)
class
PrefixCachingBlock
(
Block
):
"""A block implementation that supports prefix caching.
The PrefixCachingBlock class represents a block of token IDs with prefix
caching capabilities. It wraps a NaiveBlock internally and provides
additional functionality for content hashing and promoting immutable blocks
with the prefix caching allocator.
Args:
prev_block (Optional[PrefixCachingBlock]): The previous block in the
sequence.
token_ids (List[int]): The initial token IDs to be stored in the block.
block_size (int): The maximum number of token IDs that can be stored in
the block.
prefix_caching_allocator (PrefixCachingBlockAllocator): The prefix
caching block allocator associated with this block.
block_id (Optional[int], optional): The physical block index
of this block. Defaults to None.
"""
def
__init__
(
self
,
prev_block
:
Optional
[
"PrefixCachingBlock"
],
token_ids
:
List
[
int
],
block_size
:
int
,
prefix_caching_allocator
:
PrefixCachingBlockAllocator
,
block_id
:
Optional
[
int
]
=
None
,
):
assert_prefix_caching_block_or_none
(
prev_block
)
self
.
_prev_block
=
prev_block
self
.
_cached_content_hash
:
Optional
[
int
]
=
None
self
.
_prefix_caching_allocator
=
prefix_caching_allocator
self
.
_block
=
NaiveBlock
(
prev_block
=
prev_block
,
token_ids
=
token_ids
,
block_size
=
block_size
,
block_id
=
block_id
,
allocator
=
prefix_caching_allocator
,
_cow_target
=
self
,
)
def
append_token_ids
(
self
,
token_ids
:
List
[
int
])
->
None
:
"""Appends the given token IDs to the block and registers the block as
immutable if the block becomes full.
Internally, the naive block handles CoW.
Args:
token_ids (List[int]): The token IDs to be appended to the block.
"""
assert
token_ids
# naive block handles CoW.
self
.
_block
.
append_token_ids
(
token_ids
)
# If the content hash is present, then the block can be made immutable.
# Register ourselves with the allocator, potentially replacing the
# physical block index.
if
self
.
content_hash
is
not
None
:
self
.
block_id
=
(
self
.
_prefix_caching_allocator
.
promote_to_immutable_block
(
self
))
@property
def block_id(self) -> Optional[int]:
    """Block id currently stored on the wrapped block (may be None)."""
    wrapped = self._block
    return wrapped.block_id

@block_id.setter
def block_id(self, value) -> None:
    """Store ``value`` as the block id on the wrapped block."""
    wrapped = self._block
    wrapped.block_id = value
@property
def is_full(self) -> bool:
    """Whether the wrapped block reports itself as full."""
    wrapped = self._block
    return wrapped.is_full
@property
def num_empty_slots(self) -> int:
    """Number of empty slots, as reported by the wrapped block."""
    wrapped = self._block
    return wrapped.num_empty_slots
@property
def block_size(self) -> int:
    """Block size, as reported by the wrapped block."""
    wrapped = self._block
    return wrapped.block_size
@property
def token_ids(self) -> List[int]:
    """Token IDs held by the wrapped block."""
    wrapped = self._block
    return wrapped.token_ids
@property
def prev_block(self) -> Optional[Block]:
    """The previous block passed at construction time, or None."""
    parent = self._prev_block
    return parent
@property
def content_hash(self) -> Optional[int]:
    """Content-based hash of this block, or None if it is not yet defined.

    The hash is defined only when this block is full and, if there is a
    previous block, that block already has a content hash of its own.
    Once computed, the value is cached on the instance and reused.
    """
    # Fast path: reuse a previously computed hash.
    cached = self._cached_content_hash
    if cached is not None:
        return cached

    # A block that is not full has no content hash.
    if not self.is_full:
        return None

    parent = self._prev_block
    if parent is None:
        # First block in the sequence: nothing to chain from.
        is_first_block = True
        prev_block_hash = None
    else:
        is_first_block = False
        prev_block_hash = parent.content_hash
        # The previous block exists but has no hash yet, so this block
        # cannot have one either.
        if prev_block_hash is None:
            return None

    computed = PrefixCachingBlock.hash_block_tokens(
        is_first_block,
        prev_block_hash,
        cur_block_token_ids=self.token_ids,
    )
    self._cached_content_hash = computed
    return computed
@staticmethod
def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
                      cur_block_token_ids: List[int]) -> int:
    """Hash a block's tokens together with the hash of what precedes it.

    The result is the built-in ``hash`` of a tuple made of the first-block
    flag, the previous block's hash, and the current block's token ids.
    It is used for prefix caching.

    NOTE: Content-based hashing does not yet support LoRA.

    Parameters:
    - is_first_block (bool): Whether the block is the first in the
        sequence.
    - prev_block_hash (Optional[int]): Hash of the previous block. Must be
        None exactly when ``is_first_block`` is true.
    - cur_block_token_ids (List[int]): Token ids in the current block. The
        current block is assumed to be full.

    Returns:
    - int: The computed hash value for the block.
    """
    # A first block has no predecessor hash, and vice versa.
    lacks_prev_hash = prev_block_hash is None
    assert lacks_prev_hash == is_first_block

    hash_key = (is_first_block, prev_block_hash) + tuple(cur_block_token_ids)
    return hash(hash_key)
def assert_prefix_caching_block_or_none(block: Optional[Block]):
    """Assert that ``block`` is either None or a PrefixCachingBlock.

    Args:
        block: The block to check; None is accepted without any check.
    """
    if block is not None:
        assert isinstance(block, PrefixCachingBlock)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment