Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
20cfcdec
Unverified
Commit
20cfcdec
authored
May 08, 2024
by
youkaichao
Committed by
GitHub
May 08, 2024
Browse files
[Core][Optimization] change python dict to pytorch tensor for blocks to swap (#4659)
parent
ad932a22
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
117 additions
and
102 deletions
+117
-102
csrc/cache.h
csrc/cache.h
+2
-2
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+11
-5
csrc/cpu/cache.cpp
csrc/cpu/cache.cpp
+2
-2
tests/core/test_block_manager.py
tests/core/test_block_manager.py
+2
-2
tests/core/test_chunked_prefill_scheduler.py
tests/core/test_chunked_prefill_scheduler.py
+16
-16
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+17
-17
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+9
-4
tests/worker/test_swap.py
tests/worker/test_swap.py
+12
-12
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+1
-1
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+2
-2
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+1
-1
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+2
-2
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+2
-2
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+2
-2
vllm/core/block_manager_v1.py
vllm/core/block_manager_v1.py
+8
-4
vllm/core/block_manager_v2.py
vllm/core/block_manager_v2.py
+2
-2
vllm/core/interfaces.py
vllm/core/interfaces.py
+3
-3
vllm/core/scheduler.py
vllm/core/scheduler.py
+16
-16
vllm/sequence.py
vllm/sequence.py
+4
-4
vllm/worker/cache_engine.py
vllm/worker/cache_engine.py
+3
-3
No files found.
csrc/cache.h
View file @
20cfcdec
...
@@ -8,12 +8,12 @@
...
@@ -8,12 +8,12 @@
void
swap_blocks
(
void
swap_blocks
(
torch
::
Tensor
&
src
,
torch
::
Tensor
&
src
,
torch
::
Tensor
&
dst
,
torch
::
Tensor
&
dst
,
const
std
::
map
<
int64_t
,
int64_t
>
&
block_mapping
);
const
torch
::
Tensor
&
block_mapping
);
void
copy_blocks
(
void
copy_blocks
(
std
::
vector
<
torch
::
Tensor
>&
key_caches
,
std
::
vector
<
torch
::
Tensor
>&
key_caches
,
std
::
vector
<
torch
::
Tensor
>&
value_caches
,
std
::
vector
<
torch
::
Tensor
>&
value_caches
,
torch
::
Tensor
&
block_mapping
);
const
torch
::
Tensor
&
block_mapping
);
void
reshape_and_cache
(
void
reshape_and_cache
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
key
,
...
...
csrc/cache_kernels.cu
View file @
20cfcdec
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
void
swap_blocks
(
void
swap_blocks
(
torch
::
Tensor
&
src
,
torch
::
Tensor
&
src
,
torch
::
Tensor
&
dst
,
torch
::
Tensor
&
dst
,
const
std
::
map
<
int64_t
,
int64_t
>
&
block_mapping
)
{
const
torch
::
Tensor
&
block_mapping
)
{
torch
::
Device
src_device
=
src
.
device
();
torch
::
Device
src_device
=
src
.
device
();
torch
::
Device
dst_device
=
dst
.
device
();
torch
::
Device
dst_device
=
dst
.
device
();
cudaMemcpyKind
memcpy_type
;
cudaMemcpyKind
memcpy_type
;
...
@@ -40,6 +40,11 @@ void swap_blocks(
...
@@ -40,6 +40,11 @@ void swap_blocks(
TORCH_CHECK
(
false
,
"Invalid device combination"
);
TORCH_CHECK
(
false
,
"Invalid device combination"
);
}
}
// NOTE(youkaichao): keep in mind that `block_mapping` should be
// a cpu tensor, otherwise every `item` call will require a gpu-cpu
// synchronization.
TORCH_CHECK
(
block_mapping
.
device
().
is_cpu
(),
"block_mapping must be on CPU"
);
char
*
src_ptr
=
static_cast
<
char
*>
(
src
.
data_ptr
());
char
*
src_ptr
=
static_cast
<
char
*>
(
src
.
data_ptr
());
char
*
dst_ptr
=
static_cast
<
char
*>
(
dst
.
data_ptr
());
char
*
dst_ptr
=
static_cast
<
char
*>
(
dst
.
data_ptr
());
...
@@ -47,9 +52,10 @@ void swap_blocks(
...
@@ -47,9 +52,10 @@ void swap_blocks(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
src_device
.
is_cuda
()
?
src_device
:
dst_device
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
src_device
.
is_cuda
()
?
src_device
:
dst_device
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
// NOTE(woosuk): This can be slow if the number of blocks is large.
// NOTE(woosuk): This can be slow if the number of blocks is large.
for
(
const
auto
&
pair
:
block_mapping
)
{
const
int64_t
num_blocks
=
block_mapping
.
size
(
0
);
int64_t
src_block_number
=
pair
.
first
;
for
(
size_t
i
=
0
;
i
<
num_blocks
;
i
++
)
{
int64_t
dst_block_number
=
pair
.
second
;
int64_t
src_block_number
=
block_mapping
[
i
][
0
].
item
<
int64_t
>
();
int64_t
dst_block_number
=
block_mapping
[
i
][
1
].
item
<
int64_t
>
();
int64_t
src_offset
=
src_block_number
*
block_size_in_bytes
;
int64_t
src_offset
=
src_block_number
*
block_size_in_bytes
;
int64_t
dst_offset
=
dst_block_number
*
block_size_in_bytes
;
int64_t
dst_offset
=
dst_block_number
*
block_size_in_bytes
;
cudaMemcpyAsync
(
cudaMemcpyAsync
(
...
@@ -97,7 +103,7 @@ __global__ void copy_blocks_kernel(
...
@@ -97,7 +103,7 @@ __global__ void copy_blocks_kernel(
void
copy_blocks
(
void
copy_blocks
(
std
::
vector
<
torch
::
Tensor
>&
key_caches
,
std
::
vector
<
torch
::
Tensor
>&
key_caches
,
std
::
vector
<
torch
::
Tensor
>&
value_caches
,
std
::
vector
<
torch
::
Tensor
>&
value_caches
,
torch
::
Tensor
&
block_mapping
)
{
const
torch
::
Tensor
&
block_mapping
)
{
int
num_layers
=
key_caches
.
size
();
int
num_layers
=
key_caches
.
size
();
TORCH_CHECK
(
num_layers
==
value_caches
.
size
());
TORCH_CHECK
(
num_layers
==
value_caches
.
size
());
if
(
num_layers
==
0
)
{
if
(
num_layers
==
0
)
{
...
...
csrc/cpu/cache.cpp
View file @
20cfcdec
...
@@ -83,7 +83,7 @@ void reshape_and_cache_cpu_impl(
...
@@ -83,7 +83,7 @@ void reshape_and_cache_cpu_impl(
void
copy_blocks
(
std
::
vector
<
torch
::
Tensor
>
&
key_caches
,
void
copy_blocks
(
std
::
vector
<
torch
::
Tensor
>
&
key_caches
,
std
::
vector
<
torch
::
Tensor
>
&
value_caches
,
std
::
vector
<
torch
::
Tensor
>
&
value_caches
,
torch
::
Tensor
&
block_mapping
)
{
const
torch
::
Tensor
&
block_mapping
)
{
int
num_layers
=
key_caches
.
size
();
int
num_layers
=
key_caches
.
size
();
TORCH_CHECK
(
num_layers
==
value_caches
.
size
());
TORCH_CHECK
(
num_layers
==
value_caches
.
size
());
if
(
num_layers
==
0
)
{
if
(
num_layers
==
0
)
{
...
@@ -128,6 +128,6 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
...
@@ -128,6 +128,6 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
}
}
void
swap_blocks
(
torch
::
Tensor
&
src
,
torch
::
Tensor
&
dst
,
void
swap_blocks
(
torch
::
Tensor
&
src
,
torch
::
Tensor
&
dst
,
const
std
::
map
<
int64_t
,
int64_t
>
&
block_mapping
)
{
const
torch
::
Tensor
&
block_mapping
)
{
TORCH_CHECK
(
false
,
"swap_blocks is unsupported on CPU."
)
TORCH_CHECK
(
false
,
"swap_blocks is unsupported on CPU."
)
}
}
tests/core/test_block_manager.py
View file @
20cfcdec
...
@@ -219,7 +219,7 @@ def test_swap():
...
@@ -219,7 +219,7 @@ def test_swap():
before_cpu_blocks
=
block_manager
.
get_num_free_cpu_blocks
()
before_cpu_blocks
=
block_manager
.
get_num_free_cpu_blocks
()
before_gpu_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
before_gpu_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
mapping
=
block_manager
.
swap_out
(
seq_group
)
mapping
=
block_manager
.
swap_out
(
seq_group
)
assert
list
(
mapping
.
keys
())
==
gpu_blocks
assert
[
x
[
0
]
for
x
in
mapping
]
==
gpu_blocks
after_cpu_blocks
=
block_manager
.
get_num_free_cpu_blocks
()
after_cpu_blocks
=
block_manager
.
get_num_free_cpu_blocks
()
after_gpu_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
after_gpu_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
assert
before_cpu_blocks
==
after_cpu_blocks
+
len
(
gpu_blocks
)
assert
before_cpu_blocks
==
after_cpu_blocks
+
len
(
gpu_blocks
)
...
@@ -232,7 +232,7 @@ def test_swap():
...
@@ -232,7 +232,7 @@ def test_swap():
before_cpu_blocks
=
block_manager
.
get_num_free_cpu_blocks
()
before_cpu_blocks
=
block_manager
.
get_num_free_cpu_blocks
()
before_gpu_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
before_gpu_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
mapping
=
block_manager
.
swap_in
(
seq_group
)
mapping
=
block_manager
.
swap_in
(
seq_group
)
assert
list
(
mapping
.
keys
())
==
cpu_blocks
assert
[
x
[
0
]
for
x
in
mapping
]
==
cpu_blocks
after_cpu_blocks
=
block_manager
.
get_num_free_cpu_blocks
()
after_cpu_blocks
=
block_manager
.
get_num_free_cpu_blocks
()
after_gpu_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
after_gpu_blocks
=
block_manager
.
get_num_free_gpu_blocks
()
assert
before_cpu_blocks
+
len
(
cpu_blocks
)
==
after_cpu_blocks
assert
before_cpu_blocks
+
len
(
cpu_blocks
)
==
after_cpu_blocks
...
...
tests/core/test_chunked_prefill_scheduler.py
View file @
20cfcdec
...
@@ -355,8 +355,8 @@ def test_swap():
...
@@ -355,8 +355,8 @@ def test_swap():
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
0
assert
len
(
out
.
scheduled_seq_groups
)
==
0
assert
out
.
num_batched_tokens
==
0
assert
out
.
num_batched_tokens
==
0
assert
out
.
blocks_to_swap_out
!=
{}
assert
out
.
blocks_to_swap_out
!=
[]
assert
out
.
blocks_to_swap_in
==
{}
assert
out
.
blocks_to_swap_in
==
[]
# Add 1 more task. Swap should be prioritized over new prefill.
# Add 1 more task. Swap should be prioritized over new prefill.
_
,
seq_group
=
create_dummy_prompt
(
"2"
,
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
"2"
,
prompt_length
=
60
)
...
@@ -365,8 +365,8 @@ def test_swap():
...
@@ -365,8 +365,8 @@ def test_swap():
assert
len
(
out
.
scheduled_seq_groups
)
==
1
assert
len
(
out
.
scheduled_seq_groups
)
==
1
# 3 decodes. It is swapped in.
# 3 decodes. It is swapped in.
assert
out
.
num_batched_tokens
==
30
assert
out
.
num_batched_tokens
==
30
assert
out
.
blocks_to_swap_in
!=
{}
assert
out
.
blocks_to_swap_in
!=
[]
assert
out
.
blocks_to_swap_out
==
{}
assert
out
.
blocks_to_swap_out
==
[]
def
test_running_prefill_prioritized_over_swap
():
def
test_running_prefill_prioritized_over_swap
():
...
@@ -406,8 +406,8 @@ def test_running_prefill_prioritized_over_swap():
...
@@ -406,8 +406,8 @@ def test_running_prefill_prioritized_over_swap():
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
0
assert
len
(
out
.
scheduled_seq_groups
)
==
0
assert
out
.
num_batched_tokens
==
0
assert
out
.
num_batched_tokens
==
0
assert
out
.
blocks_to_swap_out
!=
{}
assert
out
.
blocks_to_swap_out
!=
[]
assert
out
.
blocks_to_swap_in
==
{}
assert
out
.
blocks_to_swap_in
==
[]
# Add 1 more task. Swap is not possible, so prefill is running.
# Add 1 more task. Swap is not possible, so prefill is running.
scheduler
.
block_manager
.
can_swap_in
=
MagicMock
()
scheduler
.
block_manager
.
can_swap_in
=
MagicMock
()
...
@@ -419,8 +419,8 @@ def test_running_prefill_prioritized_over_swap():
...
@@ -419,8 +419,8 @@ def test_running_prefill_prioritized_over_swap():
assert
len
(
out
.
scheduled_seq_groups
)
==
1
assert
len
(
out
.
scheduled_seq_groups
)
==
1
# 3 decodes. It is swapped in.
# 3 decodes. It is swapped in.
assert
out
.
num_batched_tokens
==
30
assert
out
.
num_batched_tokens
==
30
assert
out
.
blocks_to_swap_in
==
{}
assert
out
.
blocks_to_swap_in
==
[]
assert
out
.
blocks_to_swap_out
==
{}
assert
out
.
blocks_to_swap_out
==
[]
assert
out
.
scheduled_seq_groups
[
0
].
seq_group
==
seq_group2
assert
out
.
scheduled_seq_groups
[
0
].
seq_group
==
seq_group2
# Now although swap is possible, running prefill is prioritized.
# Now although swap is possible, running prefill is prioritized.
...
@@ -429,8 +429,8 @@ def test_running_prefill_prioritized_over_swap():
...
@@ -429,8 +429,8 @@ def test_running_prefill_prioritized_over_swap():
assert
len
(
out
.
scheduled_seq_groups
)
==
1
assert
len
(
out
.
scheduled_seq_groups
)
==
1
# 3 decodes. It is swapped in.
# 3 decodes. It is swapped in.
assert
out
.
num_batched_tokens
==
30
assert
out
.
num_batched_tokens
==
30
assert
out
.
blocks_to_swap_in
==
{}
assert
out
.
blocks_to_swap_in
==
[]
assert
out
.
blocks_to_swap_out
==
{}
assert
out
.
blocks_to_swap_out
==
[]
assert
not
seq_group2
.
is_prefill
()
assert
not
seq_group2
.
is_prefill
()
assert
out
.
scheduled_seq_groups
[
0
].
seq_group
==
seq_group2
assert
out
.
scheduled_seq_groups
[
0
].
seq_group
==
seq_group2
append_new_token
(
seq_group2
,
1
)
append_new_token
(
seq_group2
,
1
)
...
@@ -440,8 +440,8 @@ def test_running_prefill_prioritized_over_swap():
...
@@ -440,8 +440,8 @@ def test_running_prefill_prioritized_over_swap():
assert
len
(
out
.
scheduled_seq_groups
)
==
1
assert
len
(
out
.
scheduled_seq_groups
)
==
1
# 3 decodes. It is swapped in.
# 3 decodes. It is swapped in.
assert
out
.
num_batched_tokens
==
1
assert
out
.
num_batched_tokens
==
1
assert
out
.
blocks_to_swap_in
==
{}
assert
out
.
blocks_to_swap_in
==
[]
assert
out
.
blocks_to_swap_out
==
{}
assert
out
.
blocks_to_swap_out
==
[]
assert
not
seq_group2
.
is_prefill
()
assert
not
seq_group2
.
is_prefill
()
assert
out
.
scheduled_seq_groups
[
0
].
seq_group
==
seq_group2
assert
out
.
scheduled_seq_groups
[
0
].
seq_group
==
seq_group2
append_new_token
(
seq_group2
,
1
)
append_new_token
(
seq_group2
,
1
)
...
@@ -451,8 +451,8 @@ def test_running_prefill_prioritized_over_swap():
...
@@ -451,8 +451,8 @@ def test_running_prefill_prioritized_over_swap():
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
1
assert
len
(
out
.
scheduled_seq_groups
)
==
1
assert
out
.
num_batched_tokens
==
30
assert
out
.
num_batched_tokens
==
30
assert
out
.
blocks_to_swap_in
!=
{}
assert
out
.
blocks_to_swap_in
!=
[]
assert
out
.
blocks_to_swap_out
==
{}
assert
out
.
blocks_to_swap_out
==
[]
def
test_chunked_prefill_preempt
():
def
test_chunked_prefill_preempt
():
...
@@ -493,8 +493,8 @@ def test_chunked_prefill_preempt():
...
@@ -493,8 +493,8 @@ def test_chunked_prefill_preempt():
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
0
assert
len
(
out
.
scheduled_seq_groups
)
==
0
assert
out
.
num_batched_tokens
==
0
assert
out
.
num_batched_tokens
==
0
assert
out
.
blocks_to_swap_out
==
{}
assert
out
.
blocks_to_swap_out
==
[]
assert
out
.
blocks_to_swap_in
==
{}
assert
out
.
blocks_to_swap_in
==
[]
# Make sure we can reschedule preempted request.
# Make sure we can reschedule preempted request.
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
_
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
...
...
tests/core/test_scheduler.py
View file @
20cfcdec
...
@@ -293,8 +293,8 @@ def test_swapped_out_prioritized():
...
@@ -293,8 +293,8 @@ def test_swapped_out_prioritized():
seq_group_meta
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
seq_group_meta
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
2
assert
len
(
out
.
scheduled_seq_groups
)
==
2
assert
out
.
num_batched_tokens
==
2
assert
out
.
num_batched_tokens
==
2
assert
out
.
blocks_to_swap_out
!=
{}
assert
out
.
blocks_to_swap_out
!=
[]
assert
out
.
blocks_to_swap_in
==
{}
assert
out
.
blocks_to_swap_in
==
[]
append_new_token
(
out
,
1
)
append_new_token
(
out
,
1
)
# Add 1 more task. Swap should be prioritized over prefill.
# Add 1 more task. Swap should be prioritized over prefill.
...
@@ -305,8 +305,8 @@ def test_swapped_out_prioritized():
...
@@ -305,8 +305,8 @@ def test_swapped_out_prioritized():
assert
len
(
out
.
scheduled_seq_groups
)
==
3
assert
len
(
out
.
scheduled_seq_groups
)
==
3
# 3 decodes. It is swapped in.
# 3 decodes. It is swapped in.
assert
out
.
num_batched_tokens
==
3
assert
out
.
num_batched_tokens
==
3
assert
out
.
blocks_to_swap_in
!=
{}
assert
out
.
blocks_to_swap_in
!=
[]
assert
out
.
blocks_to_swap_out
==
{}
assert
out
.
blocks_to_swap_out
==
[]
def
initialize_scheduler
(
*
,
def
initialize_scheduler
(
*
,
...
@@ -566,7 +566,7 @@ def test_decode_schedule_preempted():
...
@@ -566,7 +566,7 @@ def test_decode_schedule_preempted():
# NOTE: When enable_chunk is False, num_seqs budget is not updated.
# NOTE: When enable_chunk is False, num_seqs budget is not updated.
# assert budget.num_curr_seqs == 1
# assert budget.num_curr_seqs == 1
# Both should be preempted, not swapped.
# Both should be preempted, not swapped.
assert
output
.
blocks_to_swap_out
==
{}
assert
output
.
blocks_to_swap_out
==
[]
# Nothing is copied.
# Nothing is copied.
assert
output
.
blocks_to_copy
==
[]
assert
output
.
blocks_to_copy
==
[]
...
@@ -599,7 +599,7 @@ def test_decode_swap_beam_search():
...
@@ -599,7 +599,7 @@ def test_decode_swap_beam_search():
scheduler
.
block_manager
.
can_append_slots
.
side_effect
=
(
scheduler
.
block_manager
.
can_append_slots
.
side_effect
=
(
cannot_append_second_group
)
cannot_append_second_group
)
scheduler
.
block_manager
.
swap_out
=
MagicMock
()
scheduler
.
block_manager
.
swap_out
=
MagicMock
()
expected_swap_mapping
=
{
"5"
:
"7"
}
expected_swap_mapping
=
[(
"5"
,
"7"
)]
scheduler
.
block_manager
.
swap_out
.
return_value
=
expected_swap_mapping
scheduler
.
block_manager
.
swap_out
.
return_value
=
expected_swap_mapping
remainig_running
,
output
=
scheduler
.
_schedule_running
(
remainig_running
,
output
=
scheduler
.
_schedule_running
(
...
@@ -647,7 +647,7 @@ def test_schedule_decode_blocks_to_copy_update():
...
@@ -647,7 +647,7 @@ def test_schedule_decode_blocks_to_copy_update():
assert
len
(
output
.
preempted
)
==
0
assert
len
(
output
.
preempted
)
==
0
assert
len
(
output
.
swapped_out
)
==
0
assert
len
(
output
.
swapped_out
)
==
0
# Nothing is preempted.
# Nothing is preempted.
assert
output
.
blocks_to_swap_out
==
{}
assert
output
.
blocks_to_swap_out
==
[]
# Since append_slot returns the source -> dist mapping, it should
# Since append_slot returns the source -> dist mapping, it should
# applied.
# applied.
assert
output
.
blocks_to_copy
==
[(
2
,
3
)]
assert
output
.
blocks_to_copy
==
[(
2
,
3
)]
...
@@ -658,7 +658,7 @@ def test_schedule_swapped_simple():
...
@@ -658,7 +658,7 @@ def test_schedule_swapped_simple():
swapped
=
deque
()
swapped
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
blocks_to_swap_out
=
{}
blocks_to_swap_out
=
[]
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
...
@@ -674,9 +674,9 @@ def test_schedule_swapped_simple():
...
@@ -674,9 +674,9 @@ def test_schedule_swapped_simple():
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
decode_seq_groups
)
==
1
assert
len
(
output
.
prefill_seq_groups
)
==
0
assert
len
(
output
.
prefill_seq_groups
)
==
0
# swap in is the reverse of swap out
# swap in is the reverse of swap out
blocks_to_swap_in_reverse
=
{}
blocks_to_swap_in_reverse
=
[]
for
swapin
,
swapout
in
output
.
blocks_to_swap_in
.
items
()
:
for
swapin
,
swapout
in
output
.
blocks_to_swap_in
:
blocks_to_swap_in_reverse
[
swapout
]
=
swapin
blocks_to_swap_in_reverse
.
append
((
swapout
,
swapin
))
assert
blocks_to_swap_out
==
blocks_to_swap_in_reverse
assert
blocks_to_swap_out
==
blocks_to_swap_in_reverse
...
@@ -685,7 +685,7 @@ def test_schedule_swapped_max_token_budget():
...
@@ -685,7 +685,7 @@ def test_schedule_swapped_max_token_budget():
swapped
=
deque
()
swapped
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
blocks_to_swap_out
=
{}
blocks_to_swap_out
=
[]
for
_
in
range
(
2
):
for
_
in
range
(
2
):
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
...
@@ -719,7 +719,7 @@ def test_schedule_swapped_max_seqs():
...
@@ -719,7 +719,7 @@ def test_schedule_swapped_max_seqs():
swapped
=
deque
()
swapped
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
blocks_to_swap_out
=
{}
blocks_to_swap_out
=
[]
for
i
in
range
(
4
):
for
i
in
range
(
4
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
...
@@ -752,7 +752,7 @@ def test_schedule_swapped_max_loras():
...
@@ -752,7 +752,7 @@ def test_schedule_swapped_max_loras():
swapped
=
deque
()
swapped
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
set
()
curr_loras
=
set
()
blocks_to_swap_out
=
{}
blocks_to_swap_out
=
[]
for
i
in
range
(
2
):
for
i
in
range
(
2
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
prompt_length
=
60
,
prompt_length
=
60
,
...
@@ -781,7 +781,7 @@ def test_schedule_swapped_cannot_swap_in():
...
@@ -781,7 +781,7 @@ def test_schedule_swapped_cannot_swap_in():
swapped
=
deque
()
swapped
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
blocks_to_swap_out
=
{}
blocks_to_swap_out
=
[]
for
_
in
range
(
2
):
for
_
in
range
(
2
):
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
...
@@ -808,7 +808,7 @@ def test_infeasible_swap():
...
@@ -808,7 +808,7 @@ def test_infeasible_swap():
swapped
=
deque
()
swapped
=
deque
()
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
policy
=
PolicyFactory
.
get_policy
(
policy_name
=
"fcfs"
)
curr_loras
=
None
curr_loras
=
None
blocks_to_swap_out
=
{}
blocks_to_swap_out
=
[]
for
_
in
range
(
2
):
for
_
in
range
(
2
):
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
...
@@ -839,7 +839,7 @@ def test_schedule_swapped_blocks_to_copy():
...
@@ -839,7 +839,7 @@ def test_schedule_swapped_blocks_to_copy():
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
_
,
seq_group
=
create_dummy_prompt
(
"1"
,
prompt_length
=
60
,
best_of
=
2
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
scheduler
.
_allocate_and_set_running
(
seq_group
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
append_new_token_seq_group
(
60
,
seq_group
,
1
)
blocks_to_swap_out
=
{}
blocks_to_swap_out
=
[]
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
scheduler
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
swapped
.
append
(
seq_group
)
swapped
.
append
(
seq_group
)
...
...
tests/kernels/test_cache.py
View file @
20cfcdec
...
@@ -315,7 +315,10 @@ def test_swap_blocks(
...
@@ -315,7 +315,10 @@ def test_swap_blocks(
else
:
else
:
dst_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
dst_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
block_mapping
=
dict
(
zip
(
src_blocks
,
dst_blocks
))
block_mapping
=
list
(
zip
(
src_blocks
,
dst_blocks
))
block_mapping_tensor
=
torch
.
tensor
(
block_mapping
,
dtype
=
torch
.
int64
,
device
=
"cpu"
).
view
(
-
1
,
2
)
# Create the KV caches on the first device.
# Create the KV caches on the first device.
src_key_caches
,
src_value_caches
=
kv_cache_factory
(
src_key_caches
,
src_value_caches
=
kv_cache_factory
(
...
@@ -331,10 +334,12 @@ def test_swap_blocks(
...
@@ -331,10 +334,12 @@ def test_swap_blocks(
src_value_caches_clone
=
src_value_caches
[
0
].
clone
()
src_value_caches_clone
=
src_value_caches
[
0
].
clone
()
# Call the swap_blocks kernel.
# Call the swap_blocks kernel.
ops
.
swap_blocks
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping
)
ops
.
swap_blocks
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
ops
.
swap_blocks
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping
)
block_mapping_tensor
)
ops
.
swap_blocks
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping_tensor
)
for
src
,
dst
in
block_mapping
.
items
()
:
for
src
,
dst
in
block_mapping
:
assert
torch
.
allclose
(
src_key_caches_clone
[
src
].
cpu
(),
assert
torch
.
allclose
(
src_key_caches_clone
[
src
].
cpu
(),
dist_key_caches
[
0
][
dst
].
cpu
())
dist_key_caches
[
0
][
dst
].
cpu
())
assert
torch
.
allclose
(
src_value_caches_clone
[
src
].
cpu
(),
assert
torch
.
allclose
(
src_value_caches_clone
[
src
].
cpu
(),
...
...
tests/worker/test_swap.py
View file @
20cfcdec
...
@@ -54,10 +54,10 @@ def test_swap() -> None:
...
@@ -54,10 +54,10 @@ def test_swap() -> None:
a
.
cuda
(),
b
.
cuda
(),
rtol
=
0.0
,
atol
=
0.0
)
a
.
cuda
(),
b
.
cuda
(),
rtol
=
0.0
,
atol
=
0.0
)
# Test swap out.
# Test swap out.
blocks_to_swap_out
=
{
3
:
72
,
56
:
35
,
84
:
34
}
blocks_to_swap_out
=
[(
3
,
72
)
,
(
56
,
35
)
,
(
84
,
34
)]
execute_model_req
=
ExecuteModelRequest
(
execute_model_req
=
ExecuteModelRequest
(
seq_group_metadata_list
=
[],
seq_group_metadata_list
=
[],
blocks_to_swap_in
=
{}
,
blocks_to_swap_in
=
[]
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_swap_out
=
blocks_to_swap_out
,
blocks_to_copy
=
[],
blocks_to_copy
=
[],
)
)
...
@@ -66,24 +66,24 @@ def test_swap() -> None:
...
@@ -66,24 +66,24 @@ def test_swap() -> None:
for
i
in
range
(
num_layers
):
for
i
in
range
(
num_layers
):
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
for
src
,
dst
in
blocks_to_swap_out
.
items
()
:
for
src
,
dst
in
blocks_to_swap_out
:
assert
allclose
(
gpu_key_cache
[
src
],
cpu_key_cache
[
dst
])
assert
allclose
(
gpu_key_cache
[
src
],
cpu_key_cache
[
dst
])
assert
allclose
(
gpu_value_cache
[
src
],
cpu_value_cache
[
dst
])
assert
allclose
(
gpu_value_cache
[
src
],
cpu_value_cache
[
dst
])
# Test swap in.
# Test swap in.
execute_model_req
.
blocks_to_swap_out
=
{}
execute_model_req
.
blocks_to_swap_out
=
[]
execute_model_req
.
blocks_to_swap_in
=
{
execute_model_req
.
blocks_to_swap_in
=
[
19
:
45
,
(
19
,
45
)
,
67
:
23
,
(
67
,
23
)
,
12
:
78
,
(
12
,
78
)
,
40
:
99
,
(
40
,
99
)
,
1
:
71
(
1
,
71
),
}
]
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
worker
.
execute_model
(
execute_model_req
=
execute_model_req
)
for
i
in
range
(
num_layers
):
for
i
in
range
(
num_layers
):
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
gpu_key_cache
,
gpu_value_cache
=
gpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
cpu_key_cache
,
cpu_value_cache
=
cpu_cache
[
i
]
for
src
,
dst
in
execute_model_req
.
blocks_to_swap_in
.
items
()
:
for
src
,
dst
in
execute_model_req
.
blocks_to_swap_in
:
assert
allclose
(
gpu_key_cache
[
dst
],
cpu_key_cache
[
src
])
assert
allclose
(
gpu_key_cache
[
dst
],
cpu_key_cache
[
src
])
assert
allclose
(
gpu_value_cache
[
dst
],
cpu_value_cache
[
src
])
assert
allclose
(
gpu_value_cache
[
dst
],
cpu_value_cache
[
src
])
vllm/attention/backends/abstract.py
View file @
20cfcdec
...
@@ -39,7 +39,7 @@ class AttentionBackend(ABC):
...
@@ -39,7 +39,7 @@ class AttentionBackend(ABC):
def
swap_blocks
(
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
]
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/attention/backends/flash_attn.py
View file @
20cfcdec
...
@@ -5,7 +5,7 @@ XFormers backend. The duplicated code will be removed once we use flash-attn or
...
@@ -5,7 +5,7 @@ XFormers backend. The duplicated code will be removed once we use flash-attn or
flashinfer for all the attention operations.
flashinfer for all the attention operations.
"""
"""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
from
flash_attn
import
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
...
@@ -45,7 +45,7 @@ class FlashAttentionBackend(AttentionBackend):
...
@@ -45,7 +45,7 @@ class FlashAttentionBackend(AttentionBackend):
def
swap_blocks
(
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
]
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
...
...
vllm/attention/backends/flashinfer.py
View file @
20cfcdec
...
@@ -39,7 +39,7 @@ class FlashInferBackend(AttentionBackend):
...
@@ -39,7 +39,7 @@ class FlashInferBackend(AttentionBackend):
def
swap_blocks
(
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
]
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
20cfcdec
"""Attention layer ROCm GPUs."""
"""Attention layer ROCm GPUs."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
...
@@ -43,7 +43,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
...
@@ -43,7 +43,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
def
swap_blocks
(
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
]
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
...
...
vllm/attention/backends/torch_sdpa.py
View file @
20cfcdec
""" Attention layer with torch scaled_dot_product_attention
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
and PagedAttention."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch
from
torch.nn.functional
import
scaled_dot_product_attention
from
torch.nn.functional
import
scaled_dot_product_attention
...
@@ -41,7 +41,7 @@ class TorchSDPABackend(AttentionBackend):
...
@@ -41,7 +41,7 @@ class TorchSDPABackend(AttentionBackend):
def
swap_blocks
(
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
]
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
PagedAttention
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
...
...
vllm/attention/ops/paged_attn.py
View file @
20cfcdec
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -196,7 +196,7 @@ class PagedAttention:
...
@@ -196,7 +196,7 @@ class PagedAttention:
def
swap_blocks
(
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
]
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
)
->
None
:
src_key_cache
=
src_kv_cache
[
0
]
src_key_cache
=
src_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
...
...
vllm/core/block_manager_v1.py
View file @
20cfcdec
...
@@ -473,11 +473,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -473,11 +473,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def
swap_in
(
self
,
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
=
0
)
->
Dict
[
int
,
int
]:
num_lookahead_slots
:
int
=
0
)
->
List
[
Tuple
[
int
,
int
]
]
:
assert
(
num_lookahead_slots
==
0
assert
(
num_lookahead_slots
==
0
),
"BlockSpaceManagerV1 does not support lookahead allocation"
),
"BlockSpaceManagerV1 does not support lookahead allocation"
# CPU block -> GPU block.
# CPU block -> GPU block.
# dict is efficient in lookup `if cpu_block in mapping`
mapping
:
Dict
[
PhysicalTokenBlock
,
PhysicalTokenBlock
]
=
{}
mapping
:
Dict
[
PhysicalTokenBlock
,
PhysicalTokenBlock
]
=
{}
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
SWAPPED
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
SWAPPED
):
new_block_table
:
BlockTable
=
[]
new_block_table
:
BlockTable
=
[]
...
@@ -500,14 +501,16 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -500,14 +501,16 @@ class BlockSpaceManagerV1(BlockSpaceManager):
cpu_block
.
block_number
:
gpu_block
.
block_number
cpu_block
.
block_number
:
gpu_block
.
block_number
for
cpu_block
,
gpu_block
in
mapping
.
items
()
for
cpu_block
,
gpu_block
in
mapping
.
items
()
}
}
return
block_number_mapping
# convert to list of tuples once here
return
list
(
block_number_mapping
.
items
())
def
can_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
blocks
=
self
.
_get_physical_blocks
(
seq_group
)
blocks
=
self
.
_get_physical_blocks
(
seq_group
)
return
len
(
blocks
)
<=
self
.
cpu_allocator
.
get_num_free_blocks
()
return
len
(
blocks
)
<=
self
.
cpu_allocator
.
get_num_free_blocks
()
def
swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
Dict
[
int
,
int
]:
def
swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
Tuple
[
int
,
int
]
]
:
# GPU block -> CPU block.
# GPU block -> CPU block.
# dict is efficient in lookup `if gpu_block in mapping`
mapping
:
Dict
[
PhysicalTokenBlock
,
PhysicalTokenBlock
]
=
{}
mapping
:
Dict
[
PhysicalTokenBlock
,
PhysicalTokenBlock
]
=
{}
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
new_block_table
:
BlockTable
=
[]
new_block_table
:
BlockTable
=
[]
...
@@ -530,7 +533,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -530,7 +533,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
gpu_block
.
block_number
:
cpu_block
.
block_number
gpu_block
.
block_number
:
cpu_block
.
block_number
for
gpu_block
,
cpu_block
in
mapping
.
items
()
for
gpu_block
,
cpu_block
in
mapping
.
items
()
}
}
return
block_number_mapping
# convert to list of tuples once here
return
list
(
block_number_mapping
.
items
())
def
_free_block_table
(
self
,
block_table
:
BlockTable
)
->
None
:
def
_free_block_table
(
self
,
block_table
:
BlockTable
)
->
None
:
# when using a sliding window, each seq will only use up
# when using a sliding window, each seq will only use up
...
...
vllm/core/block_manager_v2.py
View file @
20cfcdec
...
@@ -243,13 +243,13 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -243,13 +243,13 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return
AllocStatus
.
LATER
return
AllocStatus
.
LATER
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
Dict
[
int
,
int
]:
num_lookahead_slots
:
int
)
->
List
[
Tuple
[
int
,
int
]
]
:
raise
NotImplementedError
raise
NotImplementedError
def
can_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
def
can_swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
bool
:
return
False
return
False
def
swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
Dict
[
int
,
int
]:
def
swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
Tuple
[
int
,
int
]
]
:
raise
NotImplementedError
raise
NotImplementedError
def
get_num_free_gpu_blocks
(
self
)
->
int
:
def
get_num_free_gpu_blocks
(
self
)
->
int
:
...
...
vllm/core/interfaces.py
View file @
20cfcdec
import
enum
import
enum
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Dict
,
List
from
typing
import
List
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Tuple
from
typing
import
Tuple
...
@@ -69,7 +69,7 @@ class BlockSpaceManager(ABC):
...
@@ -69,7 +69,7 @@ class BlockSpaceManager(ABC):
@
abstractmethod
@
abstractmethod
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
def
swap_in
(
self
,
seq_group
:
SequenceGroup
,
num_lookahead_slots
:
int
)
->
Dict
[
int
,
int
]:
num_lookahead_slots
:
int
)
->
List
[
Tuple
[
int
,
int
]
]
:
pass
pass
@
abstractmethod
@
abstractmethod
...
@@ -77,7 +77,7 @@ class BlockSpaceManager(ABC):
...
@@ -77,7 +77,7 @@ class BlockSpaceManager(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
Dict
[
int
,
int
]:
def
swap_out
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
Tuple
[
int
,
int
]
]
:
pass
pass
@
abstractmethod
@
abstractmethod
...
...
vllm/core/scheduler.py
View file @
20cfcdec
...
@@ -117,10 +117,10 @@ class SchedulerOutputs:
...
@@ -117,10 +117,10 @@ class SchedulerOutputs:
num_prefill_groups
:
int
num_prefill_groups
:
int
# Total number of batched tokens.
# Total number of batched tokens.
num_batched_tokens
:
int
num_batched_tokens
:
int
# Blocks to swap in.
Dic
t of CPU -> GPU block number.
# Blocks to swap in.
Lis
t of CPU -> GPU block number.
blocks_to_swap_in
:
Dict
[
int
,
int
]
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]
]
# Blocks to swap out.
Dic
t of GPU -> CPU block number.
# Blocks to swap out.
Lis
t of GPU -> CPU block number.
blocks_to_swap_out
:
Dict
[
int
,
int
]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]
]
# Blocks to copy. Source to dest block.
# Blocks to copy. Source to dest block.
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
# Sequence groups that are going to be ignored.
# Sequence groups that are going to be ignored.
...
@@ -174,7 +174,7 @@ class SchedulerRunningOutputs:
...
@@ -174,7 +174,7 @@ class SchedulerRunningOutputs:
# Sequences that are swapped out.
# Sequences that are swapped out.
swapped_out
:
List
[
SequenceGroup
]
swapped_out
:
List
[
SequenceGroup
]
# The blocks to swap out.
# The blocks to swap out.
blocks_to_swap_out
:
Dict
[
int
,
int
]
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]
]
# The blocks to copy.
# The blocks to copy.
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
# The number of slots for lookahead decoding.
# The number of slots for lookahead decoding.
...
@@ -187,7 +187,7 @@ class SchedulerRunningOutputs:
...
@@ -187,7 +187,7 @@ class SchedulerRunningOutputs:
prefill_seq_groups
=
[],
prefill_seq_groups
=
[],
preempted
=
[],
preempted
=
[],
swapped_out
=
[],
swapped_out
=
[],
blocks_to_swap_out
=
{}
,
blocks_to_swap_out
=
[]
,
blocks_to_copy
=
[],
blocks_to_copy
=
[],
num_lookahead_slots
=
0
,
num_lookahead_slots
=
0
,
)
)
...
@@ -206,7 +206,7 @@ class SchedulerSwappedInOutputs:
...
@@ -206,7 +206,7 @@ class SchedulerSwappedInOutputs:
# phase. I.e., it means the prefill has been chunked.
# phase. I.e., it means the prefill has been chunked.
prefill_seq_groups
:
List
[
SequenceGroup
]
prefill_seq_groups
:
List
[
SequenceGroup
]
# The blocks to swap in.
# The blocks to swap in.
blocks_to_swap_in
:
Dict
[
int
,
int
]
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]
]
# The blocks to copy.
# The blocks to copy.
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
# The number of slots for lookahead decoding.
# The number of slots for lookahead decoding.
...
@@ -219,7 +219,7 @@ class SchedulerSwappedInOutputs:
...
@@ -219,7 +219,7 @@ class SchedulerSwappedInOutputs:
return
SchedulerSwappedInOutputs
(
return
SchedulerSwappedInOutputs
(
decode_seq_groups
=
[],
decode_seq_groups
=
[],
prefill_seq_groups
=
[],
prefill_seq_groups
=
[],
blocks_to_swap_in
=
{}
,
blocks_to_swap_in
=
[]
,
blocks_to_copy
=
[],
blocks_to_copy
=
[],
num_lookahead_slots
=
0
,
num_lookahead_slots
=
0
,
infeasible_seq_groups
=
[],
infeasible_seq_groups
=
[],
...
@@ -392,7 +392,7 @@ class Scheduler:
...
@@ -392,7 +392,7 @@ class Scheduler:
scheduling and SchedulerRunningOutputs.
scheduling and SchedulerRunningOutputs.
"""
"""
# Blocks that need to be swapped or copied before model execution.
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out
:
Dict
[
int
,
int
]
=
{}
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]
]
=
[]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
...
@@ -509,7 +509,7 @@ class Scheduler:
...
@@ -509,7 +509,7 @@ class Scheduler:
SchedulerSwappedInOutputs.
SchedulerSwappedInOutputs.
"""
"""
# Blocks that need to be swapped or copied before model execution.
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in
:
Dict
[
int
,
int
]
=
{}
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]
]
=
[]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
[]
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
decode_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
prefill_seq_groups
:
List
[
ScheduledSequenceGroup
]
=
[]
...
@@ -1032,7 +1032,7 @@ class Scheduler:
...
@@ -1032,7 +1032,7 @@ class Scheduler:
def
_preempt
(
def
_preempt
(
self
,
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]
]
,
preemption_mode
:
Optional
[
PreemptionMode
]
=
None
,
preemption_mode
:
Optional
[
PreemptionMode
]
=
None
,
)
->
PreemptionMode
:
)
->
PreemptionMode
:
# If preemption mode is not specified, we determine the mode as follows:
# If preemption mode is not specified, we determine the mode as follows:
...
@@ -1073,24 +1073,24 @@ class Scheduler:
...
@@ -1073,24 +1073,24 @@ class Scheduler:
def
_preempt_by_swap
(
def
_preempt_by_swap
(
self
,
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]
]
,
)
->
None
:
)
->
None
:
self
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
self
.
_swap_out
(
seq_group
,
blocks_to_swap_out
)
def
_swap_in
(
def
_swap_in
(
self
,
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
blocks_to_swap_in
:
Dict
[
int
,
int
],
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]
]
,
)
->
None
:
)
->
None
:
mapping
=
self
.
block_manager
.
swap_in
(
seq_group
)
mapping
=
self
.
block_manager
.
swap_in
(
seq_group
)
blocks_to_swap_in
.
update
(
mapping
)
blocks_to_swap_in
.
extend
(
mapping
)
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
SWAPPED
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
SWAPPED
):
seq
.
status
=
SequenceStatus
.
RUNNING
seq
.
status
=
SequenceStatus
.
RUNNING
def
_swap_out
(
def
_swap_out
(
self
,
self
,
seq_group
:
SequenceGroup
,
seq_group
:
SequenceGroup
,
blocks_to_swap_out
:
Dict
[
int
,
int
],
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]
]
,
)
->
None
:
)
->
None
:
if
not
self
.
block_manager
.
can_swap_out
(
seq_group
):
if
not
self
.
block_manager
.
can_swap_out
(
seq_group
):
# FIXME(woosuk): Abort the sequence group instead of aborting the
# FIXME(woosuk): Abort the sequence group instead of aborting the
...
@@ -1099,7 +1099,7 @@ class Scheduler:
...
@@ -1099,7 +1099,7 @@ class Scheduler:
"Aborted due to the lack of CPU swap space. Please increase "
"Aborted due to the lack of CPU swap space. Please increase "
"the swap space to avoid this error."
)
"the swap space to avoid this error."
)
mapping
=
self
.
block_manager
.
swap_out
(
seq_group
)
mapping
=
self
.
block_manager
.
swap_out
(
seq_group
)
blocks_to_swap_out
.
update
(
mapping
)
blocks_to_swap_out
.
extend
(
mapping
)
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
for
seq
in
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
):
seq
.
status
=
SequenceStatus
.
SWAPPED
seq
.
status
=
SequenceStatus
.
SWAPPED
...
...
vllm/sequence.py
View file @
20cfcdec
...
@@ -741,10 +741,10 @@ class ExecuteModelRequest:
...
@@ -741,10 +741,10 @@ class ExecuteModelRequest:
"""The model execution request."""
"""The model execution request."""
# The sequence group metadata list.
# The sequence group metadata list.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
# Blocks to swap in.
Dic
t of CPU -> GPU block number.
# Blocks to swap in.
Lis
t of CPU -> GPU block number.
blocks_to_swap_in
:
Dict
[
int
,
int
]
=
field
(
default_factory
=
dic
t
)
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]
]
=
field
(
default_factory
=
lis
t
)
# Blocks to swap out.
Dic
t of GPU -> CPU block number.
# Blocks to swap out.
Lis
t of GPU -> CPU block number.
blocks_to_swap_out
:
Dict
[
int
,
int
]
=
field
(
default_factory
=
dic
t
)
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]
]
=
field
(
default_factory
=
lis
t
)
# Blocks to copy. Source to dest block.
# Blocks to copy. Source to dest block.
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
field
(
default_factory
=
list
)
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
field
(
default_factory
=
list
)
# The number of slots for lookahead decoding.
# The number of slots for lookahead decoding.
...
...
vllm/worker/cache_engine.py
View file @
20cfcdec
"""CacheEngine class for managing the KV cache."""
"""CacheEngine class for managing the KV cache."""
from
typing
import
Dict
,
List
from
typing
import
List
import
torch
import
torch
...
@@ -67,12 +67,12 @@ class CacheEngine:
...
@@ -67,12 +67,12 @@ class CacheEngine:
device
=
device
))
device
=
device
))
return
kv_cache
return
kv_cache
def
swap_in
(
self
,
src_to_dst
:
Dict
[
int
,
int
]
)
->
None
:
def
swap_in
(
self
,
src_to_dst
:
torch
.
Tensor
)
->
None
:
for
i
in
range
(
self
.
num_layers
):
for
i
in
range
(
self
.
num_layers
):
self
.
attn_backend
.
swap_blocks
(
self
.
cpu_cache
[
i
],
self
.
gpu_cache
[
i
],
self
.
attn_backend
.
swap_blocks
(
self
.
cpu_cache
[
i
],
self
.
gpu_cache
[
i
],
src_to_dst
)
src_to_dst
)
def
swap_out
(
self
,
src_to_dst
:
Dict
[
int
,
int
]
)
->
None
:
def
swap_out
(
self
,
src_to_dst
:
torch
.
Tensor
)
->
None
:
for
i
in
range
(
self
.
num_layers
):
for
i
in
range
(
self
.
num_layers
):
self
.
attn_backend
.
swap_blocks
(
self
.
gpu_cache
[
i
],
self
.
cpu_cache
[
i
],
self
.
attn_backend
.
swap_blocks
(
self
.
gpu_cache
[
i
],
self
.
cpu_cache
[
i
],
src_to_dst
)
src_to_dst
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment