Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d4f39859
Unverified
Commit
d4f39859
authored
May 27, 2024
by
Michał Moskal
Committed by
GitHub
May 28, 2024
Browse files
[Core] Sliding window for block manager v2 (#4545)
Co-authored-by:
Ruth Evans
<
ruthevans@Ruths-MacBook-Pro.local
>
parent
890aa93d
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
457 additions
and
45 deletions
+457
-45
tests/core/block/e2e/conftest.py
tests/core/block/e2e/conftest.py
+26
-0
tests/core/block/e2e/test_correctness.py
tests/core/block/e2e/test_correctness.py
+2
-9
tests/core/block/e2e/test_correctness_sliding_window.py
tests/core/block/e2e/test_correctness_sliding_window.py
+168
-0
tests/core/block/test_block_manager_v2.py
tests/core/block/test_block_manager_v2.py
+69
-0
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+5
-1
vllm/core/block/block_table.py
vllm/core/block/block_table.py
+32
-2
vllm/core/block/cpu_gpu_block_allocator.py
vllm/core/block/cpu_gpu_block_allocator.py
+74
-0
vllm/core/block/interfaces.py
vllm/core/block/interfaces.py
+9
-0
vllm/core/block_manager_v2.py
vllm/core/block_manager_v2.py
+17
-7
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+2
-1
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+4
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+49
-24
No files found.
tests/core/block/e2e/conftest.py
View file @
d4f39859
from
typing
import
Callable
,
Iterable
,
Optional
import
pytest
import
pytest
from
vllm
import
LLM
from
vllm
import
LLM
...
@@ -40,3 +42,27 @@ def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
...
@@ -40,3 +42,27 @@ def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
for
llm
in
generator_inner
():
for
llm
in
generator_inner
():
yield
llm
yield
llm
del
llm
del
llm
def
get_text_from_llm_generator
(
llm_generator
:
Iterable
[
LLM
],
prompts
,
sampling_params
,
llm_cb
:
Optional
[
Callable
[[
LLM
],
None
]]
=
None
):
for
llm
in
llm_generator
:
if
llm_cb
:
llm_cb
(
llm
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
text
=
[
output
.
outputs
[
0
].
text
for
output
in
outputs
]
del
llm
return
text
def
get_token_ids_from_llm_generator
(
llm_generator
,
prompts
,
sampling_params
):
for
llm
in
llm_generator
:
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
token_ids
=
[
output
.
outputs
[
0
].
token_ids
for
output
in
outputs
]
del
llm
return
token_ids
tests/core/block/e2e/test_correctness.py
View file @
d4f39859
...
@@ -4,6 +4,8 @@ import pytest
...
@@ -4,6 +4,8 @@ import pytest
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
.conftest
import
get_token_ids_from_llm_generator
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
"common_llm_kwargs"
,
...
@@ -444,12 +446,3 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
...
@@ -444,12 +446,3 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
assert
expected_token_ids
==
actual_token_ids
assert
expected_token_ids
==
actual_token_ids
assert
baseline_token_ids
==
test_token_ids
assert
baseline_token_ids
==
test_token_ids
def
get_token_ids_from_llm_generator
(
llm_generator
,
prompts
,
sampling_params
):
for
llm
in
llm_generator
:
outputs
=
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
token_ids
=
[
output
.
outputs
[
0
].
token_ids
for
output
in
outputs
]
del
llm
return
token_ids
tests/core/block/e2e/test_correctness_sliding_window.py
0 → 100644
View file @
d4f39859
import
random
from
typing
import
List
import
pytest
from
vllm
import
LLM
,
SamplingParams
from
.conftest
import
get_text_from_llm_generator
# relatively small model with 4k sliding window
MODEL
=
"bigcode/starcoder2-3b"
BLOCK_SIZE
=
16
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
MODEL
,
# skip cuda graph creation for fast test.
"enforce_eager"
:
True
,
"block_size"
:
BLOCK_SIZE
,
# needed due to https://github.com/vllm-project/vllm/issues/1908#issuecomment-2101122008
"num_gpu_blocks_override"
:
100000
//
BLOCK_SIZE
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{
"use_v2_block_manager"
:
False
}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"use_v2_block_manager"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_sliding_window_retrival
(
baseline_llm_generator
,
test_llm_generator
,
batch_size
,
seed
):
"""
The test does a bunch of assignments "x1 = 10
\n
x2 = 33
\n
..." and then
asks for value of one of them (which is outside the sliding window).
If we tell it upfront which we are going to be looking for, then
it answers correctly (mostly).
Additionally, we compare the results of the v1 and v2 managers.
"""
sampling_params
=
SamplingParams
(
max_tokens
=
1024
,
ignore_eos
=
True
,
temperature
=
0.0
,
)
prompts
,
answer
,
indices
=
prep_prompts
(
batch_size
)
print
(
'Getting token ids from block manager v1'
)
baseline_texts
=
get_text_from_llm_generator
(
baseline_llm_generator
,
prompts
,
sampling_params
,
llm_cb
=
check_window
(
prompts
))
check_answers
(
indices
,
answer
,
baseline_texts
)
print
(
'Getting token ids from block manager v2'
)
test_texts
=
get_text_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
)
check_answers
(
indices
,
answer
,
test_texts
)
cmp
=
[
expected_text
==
actual_text
for
expected_text
,
actual_text
in
zip
(
baseline_texts
,
test_texts
)
]
print
(
cmp
)
# make sure it's mostly OK; this is possibly because https://github.com/vllm-project/vllm/pull/4768
# however, https://github.com/vllm-project/vllm/issues/3385#issuecomment-1995924290
# states that xformers and flash_attn have different ideas about the window
# size anyways
assert
sum
(
cmp
)
>
0.7
*
len
(
cmp
)
@
pytest
.
mark
.
parametrize
(
"common_llm_kwargs"
,
[{
"model"
:
MODEL
,
# skip cuda graph creation for fast test.
"enforce_eager"
:
True
,
"block_size"
:
BLOCK_SIZE
,
"num_gpu_blocks_override"
:
100000
//
BLOCK_SIZE
,
}])
@
pytest
.
mark
.
parametrize
(
"per_test_common_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"test_llm_kwargs"
,
[{
"use_v2_block_manager"
:
True
,
"enable_chunked_prefill"
:
True
}])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"seed"
,
[
1
])
def
test_sliding_window_chunked_prefill
(
test_llm_generator
,
batch_size
,
seed
):
"""
This is similar to test_sliding_window_retrival, however, it doesn't
compare against the v1 block manager since v1 doesn't support
chunked prefill with sliding window.
The results with and without chunked prefill are not the same due to
numerical instabilities.
"""
sampling_params
=
SamplingParams
(
max_tokens
=
10
,
ignore_eos
=
True
,
temperature
=
0.0
,
)
prompts
,
answer
,
indices
=
prep_prompts
(
batch_size
)
# We don't compare with the baseline model here, since the results
# slightly different due to different tailing in attention.
test_texts
=
get_text_from_llm_generator
(
test_llm_generator
,
prompts
,
sampling_params
,
llm_cb
=
check_window
(
prompts
))
check_answers
(
indices
,
answer
,
test_texts
)
def
prep_prompts
(
batch_size
:
int
):
"""
Generate prompts which a bunch of assignments,
then asking for the value of one of them.
The prompt is just under 10k tokens; sliding window is 4k
so the answer is outside sliding window, but should still be correct.
"""
prompts
:
List
[
str
]
=
[]
answer
:
List
[
int
]
=
[]
indices
:
List
[
int
]
=
[]
random
.
seed
(
1
)
for
_
in
range
(
batch_size
):
idx
=
random
.
randint
(
30
,
90
)
indices
.
append
(
idx
)
prompt
=
"```python
\n
# We set a number of variables, "
+
\
f
"x
{
idx
}
will be important later
\n
"
ln
=
random
.
randint
(
800
,
1100
)
for
k
in
range
(
30
,
ln
):
v
=
random
.
randint
(
10
,
99
)
if
k
==
idx
:
answer
.
append
(
v
)
prompt
+=
f
"x
{
k
}
=
{
v
}
\n
"
prompt
+=
f
"# Now, we check the value of x
{
idx
}
:
\n
"
prompt
+=
f
"assert x
{
idx
}
== "
prompts
.
append
(
prompt
)
return
prompts
,
answer
,
indices
def
check_answers
(
indices
:
List
[
int
],
answer
:
List
[
int
],
outputs
:
List
[
str
]):
answer2
=
[
int
(
text
[
0
:
2
].
strip
())
for
text
in
outputs
]
print
(
list
(
zip
(
indices
,
zip
(
answer
,
answer2
))))
numok
=
0
for
a1
,
a2
in
zip
(
answer
,
answer2
):
if
a1
==
a2
:
numok
+=
1
frac_ok
=
numok
/
len
(
answer
)
print
(
f
"Num OK:
{
numok
}
/
{
len
(
answer
)
}
{
frac_ok
}
"
)
assert
frac_ok
>
0.7
def
check_window
(
prompts
:
List
[
str
]):
def
inner
(
llm
:
LLM
):
sliding_window
=
llm
.
llm_engine
.
model_config
.
get_sliding_window
()
assert
sliding_window
and
sliding_window
>
0
assert
any
(
len
(
llm
.
get_tokenizer
().
tokenize
(
prompt
))
>
sliding_window
for
prompt
in
prompts
)
return
inner
tests/core/block/test_block_manager_v2.py
View file @
d4f39859
...
@@ -101,3 +101,72 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append,
...
@@ -101,3 +101,72 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append,
range
(
prompt_len
+
num_slots_to_append
+
num_lookahead_slots
)),
range
(
prompt_len
+
num_slots_to_append
+
num_lookahead_slots
)),
block_size
))
-
len
(
chunk_list
(
list
(
range
(
prompt_len
)),
block_size
))
block_size
))
-
len
(
chunk_list
(
list
(
range
(
prompt_len
)),
block_size
))
assert
num_consumed_blocks
==
expected_consumed_blocks
assert
num_consumed_blocks
==
expected_consumed_blocks
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
8
,
16
])
@
pytest
.
mark
.
parametrize
(
"prompt_len"
,
[
10
,
300
,
1000
])
@
pytest
.
mark
.
parametrize
(
"num_slots_to_append"
,
[
50
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
20
,
32
,
200
,
512
])
def
test_sliding_window
(
block_size
,
prompt_len
,
num_slots_to_append
,
sliding_window
):
"""Verify append_slots consumes the correct number of blocks from the block
table.
"""
num_gpu_blocks
=
1024
watermark
=
0.1
block_manager
=
BlockSpaceManagerV2
(
block_size
=
block_size
,
num_gpu_blocks
=
num_gpu_blocks
,
num_cpu_blocks
=
0
,
watermark
=
watermark
,
sliding_window
=
sliding_window
,
)
def
check_used
(
min_n
,
max_n
=
None
):
if
max_n
is
None
:
max_n
=
min_n
used
=
num_gpu_blocks
-
block_manager
.
get_num_free_gpu_blocks
()
#print("check", min_n, used, max_n)
assert
min_n
<=
used
assert
used
<=
max_n
def
num_blocks
(
num_tokens
):
return
(
num_tokens
+
block_size
-
1
)
//
block_size
check_used
(
0
)
seq_group
=
create_seq_group
(
seq_prompt_len
=
prompt_len
,
seq_output_lens
=
[
0
],
)
check_used
(
0
)
# Allocate seq
assert
block_manager
.
can_allocate
(
seq_group
)
block_manager
.
allocate
(
seq_group
)
check_used
(
num_blocks
(
prompt_len
))
# Seq seq to RUNNING
seq
=
seq_group
.
get_seqs
()[
0
]
seq
.
status
=
SequenceStatus
.
RUNNING
seq
.
data
.
update_num_computed_tokens
(
prompt_len
)
check_used
(
num_blocks
(
prompt_len
))
# this is how we compute it in BlockSpaceManagerV2.__init__
sliding_blocks
=
(
sliding_window
//
block_size
)
+
2
# plus one block for null block
sliding_blocks
+=
1
# Append tokens to the sequeqnce
for
token_id
in
range
(
num_slots_to_append
):
seq
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
0.0
)})
seq
.
data
.
update_num_computed_tokens
(
1
)
block_manager
.
append_slots
(
seq
,
num_lookahead_slots
=
0
)
if
prompt_len
<
sliding_window
+
10
:
check_used
(
0
,
sliding_blocks
+
1
)
else
:
check_used
(
sliding_blocks
,
sliding_blocks
+
1
)
vllm/attention/ops/prefix_prefill.py
View file @
d4f39859
...
@@ -697,6 +697,10 @@ if triton.__version__ >= "2.1.0":
...
@@ -697,6 +697,10 @@ if triton.__version__ >= "2.1.0":
grid
=
(
batch
,
head
,
triton
.
cdiv
(
max_input_len
,
BLOCK
))
# batch, head,
grid
=
(
batch
,
head
,
triton
.
cdiv
(
max_input_len
,
BLOCK
))
# batch, head,
# 0 means "disable"
if
sliding_window
is
None
or
sliding_window
<=
0
:
sliding_window
=
0
num_warps
=
8
if
Lk
<=
64
else
8
num_warps
=
8
if
Lk
<=
64
else
8
if
alibi_slopes
is
not
None
:
if
alibi_slopes
is
not
None
:
_fwd_kernel_alibi
[
grid
](
_fwd_kernel_alibi
[
grid
](
...
@@ -794,7 +798,7 @@ if triton.__version__ >= "2.1.0":
...
@@ -794,7 +798,7 @@ if triton.__version__ >= "2.1.0":
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL
=
Lk
,
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_DMODEL_PADDED
=
Lk_padded
,
BLOCK_N
=
BLOCK
,
BLOCK_N
=
BLOCK
,
SLIDING_WINDOW
=
sliding_window
if
sliding_window
is
not
None
else
0
,
SLIDING_WINDOW
=
sliding_window
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
num_stages
=
1
,
num_stages
=
1
,
)
)
...
...
vllm/core/block/block_table.py
View file @
d4f39859
...
@@ -20,6 +20,10 @@ class BlockTable:
...
@@ -20,6 +20,10 @@ class BlockTable:
_blocks (Optional[List[Block]], optional): An optional list of existing
_blocks (Optional[List[Block]], optional): An optional list of existing
blocks to initialize the BlockTable with. If not provided, an empty
blocks to initialize the BlockTable with. If not provided, an empty
BlockTable is created.
BlockTable is created.
max_block_sliding_window (Optional[int], optional): The number of
blocks to keep around for each sequance. If None, all blocks
are kept (eg., when sliding window is not used).
It should at least fit the sliding window size of the model.
Attributes:
Attributes:
_block_size (int): The maximum number of tokens that can be stored in a
_block_size (int): The maximum number of tokens that can be stored in a
...
@@ -37,6 +41,7 @@ class BlockTable:
...
@@ -37,6 +41,7 @@ class BlockTable:
block_size
:
int
,
block_size
:
int
,
block_allocator
:
DeviceAwareBlockAllocator
,
block_allocator
:
DeviceAwareBlockAllocator
,
_blocks
:
Optional
[
List
[
Block
]]
=
None
,
_blocks
:
Optional
[
List
[
Block
]]
=
None
,
max_block_sliding_window
:
Optional
[
int
]
=
None
,
):
):
self
.
_block_size
=
block_size
self
.
_block_size
=
block_size
self
.
_allocator
=
block_allocator
self
.
_allocator
=
block_allocator
...
@@ -44,6 +49,7 @@ class BlockTable:
...
@@ -44,6 +49,7 @@ class BlockTable:
_blocks
=
[]
_blocks
=
[]
self
.
_blocks
:
List
[
Block
]
=
_blocks
self
.
_blocks
:
List
[
Block
]
=
_blocks
self
.
_max_block_sliding_window
=
max_block_sliding_window
# Use helper method instead of directly calculating, as blocks
# Use helper method instead of directly calculating, as blocks
# may not be allocated.
# may not be allocated.
self
.
_num_full_slots
=
len
(
self
.
_get_all_token_ids
())
self
.
_num_full_slots
=
len
(
self
.
_get_all_token_ids
())
...
@@ -89,7 +95,8 @@ class BlockTable:
...
@@ -89,7 +95,8 @@ class BlockTable:
def
append_token_ids
(
self
,
def
append_token_ids
(
self
,
token_ids
:
List
[
int
],
token_ids
:
List
[
int
],
num_lookahead_slots
:
int
=
0
)
->
None
:
num_lookahead_slots
:
int
=
0
,
num_computed_slots
:
Optional
[
int
]
=
None
)
->
None
:
"""Appends a sequence of token IDs to the existing blocks in the
"""Appends a sequence of token IDs to the existing blocks in the
BlockTable.
BlockTable.
...
@@ -104,13 +111,35 @@ class BlockTable:
...
@@ -104,13 +111,35 @@ class BlockTable:
Args:
Args:
token_ids (List[int]): The sequence of token IDs to be appended.
token_ids (List[int]): The sequence of token IDs to be appended.
num_computed_slots (Optional[int]): The number of KV cache slots
that are already filled (computed).
When sliding window is enabled, this is used to compute how many
blocks to drop at the front of the sequence.
Without sliding window, None can be passed.
Without chunked prefill, it should be the same as
_num_full_slots.
"""
"""
assert
self
.
_is_allocated
assert
self
.
_is_allocated
,
"no blocks have been allocated"
assert
len
(
self
.
_blocks
)
>
0
assert
len
(
self
.
_blocks
)
>
0
# Drop blocks that are no longer needed due to sliding window
if
self
.
_max_block_sliding_window
is
not
None
:
null_block
=
self
.
_allocator
.
allocate_or_get_null_block
()
assert
num_computed_slots
is
not
None
end_block_idx
=
(
num_computed_slots
//
self
.
_block_size
)
-
self
.
_max_block_sliding_window
for
idx
in
range
(
0
,
end_block_idx
):
b
=
self
.
_blocks
[
idx
]
if
b
is
not
null_block
:
self
.
_allocator
.
free
(
b
)
self
.
_blocks
[
idx
]
=
null_block
# Ensure there are enough empty slots for the new tokens plus
# lookahead slots
self
.
ensure_num_empty_slots
(
num_empty_slots
=
len
(
token_ids
)
+
self
.
ensure_num_empty_slots
(
num_empty_slots
=
len
(
token_ids
)
+
num_lookahead_slots
)
num_lookahead_slots
)
# Update the blocks with the new tokens
blocks
=
self
.
_blocks
[
self
.
_num_full_slots
//
self
.
_block_size
:]
blocks
=
self
.
_blocks
[
self
.
_num_full_slots
//
self
.
_block_size
:]
token_blocks
=
self
.
_chunk_token_blocks_for_append
(
token_ids
)
token_blocks
=
self
.
_chunk_token_blocks_for_append
(
token_ids
)
...
@@ -168,6 +197,7 @@ class BlockTable:
...
@@ -168,6 +197,7 @@ class BlockTable:
block_size
=
self
.
_block_size
,
block_size
=
self
.
_block_size
,
block_allocator
=
self
.
_allocator
,
block_allocator
=
self
.
_allocator
,
_blocks
=
forked_blocks
,
_blocks
=
forked_blocks
,
max_block_sliding_window
=
self
.
_max_block_sliding_window
,
)
)
def
free
(
self
)
->
None
:
def
free
(
self
)
->
None
:
...
...
vllm/core/block/cpu_gpu_block_allocator.py
View file @
d4f39859
...
@@ -105,11 +105,19 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -105,11 +105,19 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Device
.
GPU
:
gpu_block_allocator
,
Device
.
GPU
:
gpu_block_allocator
,
}
}
self
.
_null_block
:
Optional
[
Block
]
=
None
self
.
_block_ids_to_allocator
:
Dict
[
int
,
BlockAllocator
]
=
{}
self
.
_block_ids_to_allocator
:
Dict
[
int
,
BlockAllocator
]
=
{}
for
_
,
allocator
in
self
.
_allocators
.
items
():
for
_
,
allocator
in
self
.
_allocators
.
items
():
for
block_id
in
allocator
.
all_block_ids
:
for
block_id
in
allocator
.
all_block_ids
:
self
.
_block_ids_to_allocator
[
block_id
]
=
allocator
self
.
_block_ids_to_allocator
[
block_id
]
=
allocator
def
allocate_or_get_null_block
(
self
)
->
Block
:
if
self
.
_null_block
is
None
:
self
.
_null_block
=
NullBlock
(
self
.
allocate_mutable
(
None
,
Device
.
GPU
))
return
self
.
_null_block
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
def
allocate_mutable
(
self
,
prev_block
:
Optional
[
Block
],
device
:
Device
)
->
Block
:
device
:
Device
)
->
Block
:
"""Allocates a new mutable block on the specified device.
"""Allocates a new mutable block on the specified device.
...
@@ -149,6 +157,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -149,6 +157,9 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Args:
Args:
block (Block): The block to be freed.
block (Block): The block to be freed.
"""
"""
# Null block should never be freed
if
isinstance
(
block
,
NullBlock
):
return
block_id
=
block
.
block_id
block_id
=
block
.
block_id
assert
block_id
is
not
None
assert
block_id
is
not
None
allocator
=
self
.
_block_ids_to_allocator
[
block_id
]
allocator
=
self
.
_block_ids_to_allocator
[
block_id
]
...
@@ -165,6 +176,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -165,6 +176,8 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
List[Block]: A new list of blocks that shares the same memory as the
List[Block]: A new list of blocks that shares the same memory as the
original sequence.
original sequence.
"""
"""
# do not attempt to fork the null block
assert
not
isinstance
(
last_block
,
NullBlock
)
block_id
=
last_block
.
block_id
block_id
=
last_block
.
block_id
assert
block_id
is
not
None
assert
block_id
is
not
None
allocator
=
self
.
_block_ids_to_allocator
[
block_id
]
allocator
=
self
.
_block_ids_to_allocator
[
block_id
]
...
@@ -226,3 +239,64 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -226,3 +239,64 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
BlockId
]:
def
cow_block_if_not_appendable
(
self
,
block
:
Block
)
->
Optional
[
BlockId
]:
raise
NotImplementedError
raise
NotImplementedError
class
NullBlock
(
Block
):
"""
Null blocks are used as a placeholders for KV cache blocks that have
been dropped due to sliding window.
This implementation just wraps an ordinary block and prevents it from
being modified. It also allows for testing if a block is NullBlock
via isinstance().
"""
def
__init__
(
self
,
proxy
:
Block
):
super
().
__init__
()
self
.
_proxy
=
proxy
def
append_token_ids
(
self
,
token_ids
:
List
[
BlockId
]):
raise
ValueError
(
"null block should not be modified"
)
@
property
def
block_id
(
self
):
return
self
.
_proxy
.
block_id
@
block_id
.
setter
def
block_id
(
self
,
value
:
Optional
[
BlockId
]):
raise
ValueError
(
"null block should not be modified"
)
@
property
def
token_ids
(
self
)
->
List
[
BlockId
]:
return
self
.
_proxy
.
token_ids
@
property
def
num_empty_slots
(
self
)
->
BlockId
:
return
self
.
_proxy
.
num_empty_slots
@
property
def
is_full
(
self
):
return
self
.
_proxy
.
is_full
@
property
def
prev_block
(
self
):
return
self
.
_proxy
.
prev_block
@
property
def
computed
(
self
):
return
self
.
_proxy
.
computed
@
computed
.
setter
def
computed
(
self
,
value
):
self
.
_proxy
.
computed
=
value
@
property
def
last_accessed
(
self
)
->
float
:
return
self
.
_proxy
.
last_accessed
@
last_accessed
.
setter
def
last_accessed
(
self
,
last_accessed_ts
:
float
):
self
.
_proxy
.
last_accessed
=
last_accessed_ts
@
property
def
content_hash
(
self
):
return
self
.
_proxy
.
content_hash
vllm/core/block/interfaces.py
View file @
d4f39859
...
@@ -203,3 +203,12 @@ class DeviceAwareBlockAllocator(ABC):
...
@@ -203,3 +203,12 @@ class DeviceAwareBlockAllocator(ABC):
def
get_common_computed_block_ids
(
def
get_common_computed_block_ids
(
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
pass
pass
@
abstractmethod
def
allocate_or_get_null_block
(
self
)
->
Block
:
"""
Null blocks are used as a placeholders for KV cache blocks that have
been dropped due to sliding window.
There is at most one null block per allocator.
"""
pass
vllm/core/block_manager_v2.py
View file @
d4f39859
...
@@ -66,9 +66,18 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -66,9 +66,18 @@ class BlockSpaceManagerV2(BlockSpaceManager):
self
.
num_total_gpu_blocks
=
num_gpu_blocks
self
.
num_total_gpu_blocks
=
num_gpu_blocks
self
.
num_total_cpu_blocks
=
num_cpu_blocks
self
.
num_total_cpu_blocks
=
num_cpu_blocks
assert
sliding_window
is
None
,
"Sliding window not yet supported"
self
.
sliding_window
=
sliding_window
# max_block_sliding_window is the max number of blocks that need to be
self
.
block_sliding_window
=
None
# allocated
self
.
max_block_sliding_window
=
None
if
sliding_window
is
not
None
:
# +1 here because // rounds down
num_blocks
=
sliding_window
//
block_size
+
1
# +1 here because the last block may not be full,
# and so the sequence stretches one more block at the beginning
# For example, if sliding_window is 3 and block_size is 4,
# we may need 2 blocks when the second block only holds 1 token.
self
.
max_block_sliding_window
=
num_blocks
+
1
self
.
watermark
=
watermark
self
.
watermark
=
watermark
assert
watermark
>=
0.0
assert
watermark
>=
0.0
...
@@ -96,10 +105,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -96,10 +105,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
block_size
=
self
.
block_size
,
block_size
=
self
.
block_size
,
)
)
assert
self
.
block_sliding_window
is
None
if
self
.
max_block_sliding_window
is
not
None
:
if
self
.
block_sliding_window
is
not
None
:
num_required_blocks
=
min
(
num_required_blocks
,
num_required_blocks
=
min
(
num_required_blocks
,
self
.
block_sliding_window
)
self
.
max_
block_sliding_window
)
num_free_gpu_blocks
=
self
.
block_allocator
.
get_num_free_blocks
(
num_free_gpu_blocks
=
self
.
block_allocator
.
get_num_free_blocks
(
device
=
Device
.
GPU
)
device
=
Device
.
GPU
)
...
@@ -125,8 +133,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -125,8 +133,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
block_table
=
BlockTable
(
block_table
=
BlockTable
(
block_size
=
self
.
block_size
,
block_size
=
self
.
block_size
,
block_allocator
=
self
.
block_allocator
,
block_allocator
=
self
.
block_allocator
,
max_block_sliding_window
=
self
.
max_block_sliding_window
,
)
)
assert
self
.
block_sliding_window
is
None
block_table
.
allocate
(
seq
.
get_token_ids
())
block_table
.
allocate
(
seq
.
get_token_ids
())
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
...
@@ -174,6 +183,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -174,6 +183,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
block_table
.
append_token_ids
(
block_table
.
append_token_ids
(
token_ids
=
block_table
.
get_unseen_token_ids
(
seq
.
get_token_ids
()),
token_ids
=
block_table
.
get_unseen_token_ids
(
seq
.
get_token_ids
()),
num_lookahead_slots
=
num_lookahead_slots
,
num_lookahead_slots
=
num_lookahead_slots
,
num_computed_slots
=
seq
.
data
.
get_num_computed_tokens
(),
)
)
# Return any new copy-on-writes.
# Return any new copy-on-writes.
...
...
vllm/engine/arg_utils.py
View file @
d4f39859
...
@@ -648,7 +648,8 @@ class EngineArgs:
...
@@ -648,7 +648,8 @@ class EngineArgs:
guided_decoding_backend
=
self
.
guided_decoding_backend
)
guided_decoding_backend
=
self
.
guided_decoding_backend
)
if
(
model_config
.
get_sliding_window
()
is
not
None
if
(
model_config
.
get_sliding_window
()
is
not
None
and
scheduler_config
.
chunked_prefill_enabled
):
and
scheduler_config
.
chunked_prefill_enabled
and
not
scheduler_config
.
use_v2_block_manager
):
raise
ValueError
(
raise
ValueError
(
"Chunked prefill is not supported with sliding window. "
"Chunked prefill is not supported with sliding window. "
"Set --disable-sliding-window to disable sliding window."
)
"Set --disable-sliding-window to disable sliding window."
)
...
...
vllm/worker/cache_engine.py
View file @
d4f39859
...
@@ -68,8 +68,11 @@ class CacheEngine:
...
@@ -68,8 +68,11 @@ class CacheEngine:
pin_memory
=
is_pin_memory_available
()
if
device
==
"cpu"
else
False
pin_memory
=
is_pin_memory_available
()
if
device
==
"cpu"
else
False
kv_cache
:
List
[
torch
.
Tensor
]
=
[]
kv_cache
:
List
[
torch
.
Tensor
]
=
[]
for
_
in
range
(
self
.
num_layers
):
for
_
in
range
(
self
.
num_layers
):
# null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out.
# We zero-out everything for simplicity.
kv_cache
.
append
(
kv_cache
.
append
(
torch
.
empty
(
kv_cache_shape
,
torch
.
zeros
(
kv_cache_shape
,
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
device
=
device
))
device
=
device
))
...
...
vllm/worker/model_runner.py
View file @
d4f39859
...
@@ -269,6 +269,12 @@ class ModelRunner:
...
@@ -269,6 +269,12 @@ class ModelRunner:
if
len
(
seq_group_metadata_list
)
==
0
:
if
len
(
seq_group_metadata_list
)
==
0
:
return
ModelInput
.
empty
(
self
.
device
)
return
ModelInput
.
empty
(
self
.
device
)
if
self
.
sliding_window
is
not
None
:
sliding_window_blocks
=
(
self
.
sliding_window
+
self
.
block_size
-
1
)
//
self
.
block_size
block_aligned_sliding_window
=
\
sliding_window_blocks
*
self
.
block_size
for
seq_group_metadata
in
seq_group_metadata_list
:
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
is_prompt
=
seq_group_metadata
.
is_prompt
is_prompt
=
seq_group_metadata
.
is_prompt
...
@@ -309,6 +315,30 @@ class ModelRunner:
...
@@ -309,6 +315,30 @@ class ModelRunner:
and
self
.
sliding_window
is
None
and
self
.
sliding_window
is
None
and
is_prompt
)
and
is_prompt
)
# These are seq_len/context_len capped to the sliding window.
# They are passed to decode kernel.
# We still need original seq_len/context_len to compute slot
# mapping (and input position) below.
curr_sliding_window_blocks
=
None
sliding_seq_len
=
seq_len
sliding_context_len
=
context_len
# TODO(sang): This is a hack to make sliding window work with
# paged attn. We can remove it if we make paged attn kernel
# to properly handle slinding window attn.
if
(
self
.
sliding_window
is
not
None
and
not
is_prompt
):
curr_sliding_window_blocks
=
sliding_window_blocks
if
self
.
scheduler_config
.
use_v2_block_manager
:
# number of elements in last block
suff_len
=
seq_len
%
self
.
block_size
sliding_seq_len
=
min
(
seq_len
,
block_aligned_sliding_window
+
suff_len
)
if
suff_len
>
0
:
curr_sliding_window_blocks
+=
1
else
:
sliding_seq_len
=
min
(
seq_len
,
self
.
sliding_window
)
sliding_context_len
=
sliding_seq_len
-
1
# TODO(sang): Combine chunked prefill and prefix caching by
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
# NOTE: This only works for oooooooxxx style attention.
...
@@ -316,6 +346,13 @@ class ModelRunner:
...
@@ -316,6 +346,13 @@ class ModelRunner:
assert
computed_block_nums
is
not
None
assert
computed_block_nums
is
not
None
context_len
=
len
(
computed_block_nums
)
*
self
.
block_size
context_len
=
len
(
computed_block_nums
)
*
self
.
block_size
tokens
=
tokens
[
context_len
:]
tokens
=
tokens
[
context_len
:]
# need to think what to set it to when we have both sliding
# window and prefix caching...
assert
self
.
sliding_window
is
None
,
\
"Prefix caching is not supported with sliding window"
sliding_context_len
=
context_len
if
self
.
attn_backend
.
get_name
()
==
"flash-attn"
:
if
self
.
attn_backend
.
get_name
()
==
"flash-attn"
:
# NOTE(woosuk): For flash-attn, the block table should
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
# include the entries for the incoming prefill tokens.
...
@@ -329,14 +366,9 @@ class ModelRunner:
...
@@ -329,14 +366,9 @@ class ModelRunner:
if
seq_group_metadata
.
block_tables
is
not
None
:
if
seq_group_metadata
.
block_tables
is
not
None
:
# chunked prefill or decode
# chunked prefill or decode
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
if
self
.
sliding_window
is
not
None
:
if
curr_sliding_window_blocks
is
not
None
:
# chunked prefill doesn't support sliding window.
block_table
=
block_table
[
assert
(
not
self
.
scheduler_config
.
-
curr_sliding_window_blocks
:]
chunked_prefill_enabled
)
sliding_window_blocks
=
(
self
.
sliding_window
//
self
.
block_size
)
block_table
=
block_table
[
-
sliding_window_blocks
:]
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
paged_kv_indices
.
extend
(
block_table
)
paged_kv_indices
.
extend
(
block_table
)
paged_kv_indptr
.
append
(
paged_kv_indptr
[
-
1
]
+
paged_kv_indptr
.
append
(
paged_kv_indptr
[
-
1
]
+
...
@@ -354,16 +386,9 @@ class ModelRunner:
...
@@ -354,16 +386,9 @@ class ModelRunner:
block_table
=
[]
block_table
=
[]
block_tables
.
append
(
block_table
)
block_tables
.
append
(
block_table
)
# TODO(sang): This is a hack to make sliding window work with
seq_lens
.
append
(
sliding_seq_len
)
# paged attn. We can remove it if we make paged attn kernel
context_lens
.
append
(
sliding_context_len
)
# to properly handle slinding window attn.
query_len
=
sliding_seq_len
-
sliding_context_len
if
(
self
.
sliding_window
is
not
None
and
not
is_prompt
):
seq_len
=
min
(
seq_len
,
self
.
sliding_window
)
context_len
=
seq_len
-
1
seq_lens
.
append
(
seq_len
)
context_lens
.
append
(
context_len
)
query_len
=
seq_len
-
context_len
query_lens
.
append
(
query_len
)
query_lens
.
append
(
query_len
)
input_tokens
.
extend
(
tokens
)
input_tokens
.
extend
(
tokens
)
input_positions
.
extend
(
list
(
range
(
context_len
,
seq_len
)))
input_positions
.
extend
(
list
(
range
(
context_len
,
seq_len
)))
...
@@ -380,16 +405,15 @@ class ModelRunner:
...
@@ -380,16 +405,15 @@ class ModelRunner:
"seq_len: {}, context_len: {}, query_len: {}"
.
format
(
"seq_len: {}, context_len: {}, query_len: {}"
.
format
(
seq_len
,
context_len
,
query_len
))
seq_len
,
context_len
,
query_len
))
num_decode_tokens
+=
query_len
num_decode_tokens
+=
query_len
decode_seq_lens
.
append
(
seq_len
)
decode_seq_lens
.
append
(
sliding_
seq_len
)
if
lora_id
>
0
:
if
lora_id
>
0
:
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
lora_requests
.
add
(
seq_group_metadata
.
lora_request
)
lora_index_mapping
+=
[
lora_id
]
*
(
seq_len
-
context
_len
)
lora_index_mapping
+=
[
lora_id
]
*
query
_len
lora_prompt_mapping
.
extend
(
lora_prompt_mapping
.
extend
(
[
lora_id
]
*
[
lora_id
]
*
(
seq_len
-
(
query_len
if
seq_group_metadata
.
sampling_params
context_len
if
seq_group_metadata
.
sampling_params
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
and
seq_group_metadata
.
sampling_params
.
prompt_logprobs
else
1
))
else
1
))
...
@@ -417,9 +441,10 @@ class ModelRunner:
...
@@ -417,9 +441,10 @@ class ModelRunner:
start_idx
=
0
start_idx
=
0
if
self
.
sliding_window
is
not
None
:
if
self
.
sliding_window
is
not
None
:
if
is_prompt
:
if
is_prompt
:
assert
context_len
==
0
,
(
assert
self
.
scheduler_config
.
use_v2_block_manager
\
or
context_len
==
0
,
(
"Prefix caching is currently not supported with "
"Prefix caching is currently not supported with "
"sliding window attention"
)
"sliding window attention
in V1 block manager
"
)
# It is an optimization. When it is decoding, it is always
# It is an optimization. When it is decoding, it is always
# 0. When prefill, we use it to not write slots to kv cache
# 0. When prefill, we use it to not write slots to kv cache
# to save memory.
# to save memory.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment