Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7ac67ea5
Unverified
Commit
7ac67ea5
authored
Sep 19, 2025
by
Or Ozeri
Committed by
GitHub
Sep 19, 2025
Browse files
[KV offload][3/N] Add worker-side CPU support (#21448)
Signed-off-by:
Or Ozeri
<
oro@il.ibm.com
>
parent
ce75e153
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
348 additions
and
0 deletions
+348
-0
tests/v1/kv_offload/test_cpu_gpu.py
tests/v1/kv_offload/test_cpu_gpu.py
+177
-0
vllm/v1/kv_offload/worker/cpu_gpu.py
vllm/v1/kv_offload/worker/cpu_gpu.py
+171
-0
No files found.
tests/v1/kv_offload/test_cpu_gpu.py
0 → 100644
View file @
7ac67ea5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
random
import
time
import
pytest
import
torch
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.v1.attention.backends.flashinfer
import
FlashInferBackend
from
vllm.v1.attention.backends.mla.flashattn_mla
import
FlashAttnMLABackend
from
vllm.v1.kv_offload.mediums
import
CPULoadStoreSpec
,
GPULoadStoreSpec
from
vllm.v1.kv_offload.worker.cpu_gpu
import
CpuGpuOffloadingHandler
# Parameter grids for test_transfer; each list is fed to one
# pytest.mark.parametrize decorator below.

# cache geometry
NUM_GPU_BLOCKS = [64]
NUM_CPU_BLOCKS = [256]
GPU_BLOCK_SIZES = [16]
# number of GPU blocks that make up a single CPU block
GPU_BLOCKS_PER_CPU_BLOCK = [1, 3]

# model shape
HEAD_SIZES = [64]
NUM_HEADS = [8]
NUM_LAYERS = [4]
DTYPES = [torch.bfloat16]

# execution environment
SEEDS = [0]
CUDA_DEVICES = ['cuda:0']

# number of CPU blocks taking part in each transfer
NUM_MAPPINGS = [3]
@pytest.mark.parametrize("gpu_to_cpu", [True, False])
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("gpu_block_size", GPU_BLOCK_SIZES)
@pytest.mark.parametrize("gpu_blocks_per_cpu_block", GPU_BLOCKS_PER_CPU_BLOCK)
@pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS)
@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_transfer(
    gpu_to_cpu: bool,
    num_mappings: int,
    head_size: int,
    num_heads: int,
    gpu_block_size: int,
    gpu_blocks_per_cpu_block: int,
    num_gpu_blocks: int,
    num_cpu_blocks: int,
    num_layers: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    """
    Transfer a random set of blocks through CpuGpuOffloadingHandler and
    check the result block by block.

    The test runs in both directions (GPU->CPU when gpu_to_cpu is True,
    CPU->GPU otherwise). After the transfer it verifies that:
      * the source tensors are unchanged,
      * every destination block that was part of the mapping equals its
        source block,
      * every other destination block still holds its original content.
    """
    current_platform.seed_everything(seed)

    # create per-layer GPU KV caches, cycling through the backends so that
    # the layers do not all share one KV cache shape
    attn_backends_list = [
        FlashAttentionBackend, FlashInferBackend, FlashAttnMLABackend
    ]

    gpu_caches = {}
    attn_backends = {}
    for i in range(num_layers):
        layer_name = f'layer {i}'

        attn_backend = attn_backends_list[i % len(attn_backends_list)]
        attn_backends[layer_name] = attn_backend

        gpu_cache_shape = attn_backend.get_kv_cache_shape(
            num_gpu_blocks, gpu_block_size, num_heads, head_size)
        gpu_caches[layer_name] = torch.rand(gpu_cache_shape,
                                            dtype=dtype,
                                            device=device)

    # create handler
    cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
    handler = CpuGpuOffloadingHandler(attn_backends=attn_backends,
                                      gpu_block_size=gpu_block_size,
                                      cpu_block_size=cpu_block_size,
                                      num_cpu_blocks=num_cpu_blocks,
                                      gpu_caches=gpu_caches)

    # select block mappings
    gpu_blocks = random.sample(range(num_gpu_blocks),
                               num_mappings * gpu_blocks_per_cpu_block)
    cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings)

    # convert cpu blocks to gpu block size
    cpu_blocks_in_gpu_block_size = []
    for cpu_block in cpu_blocks:
        base_block_id = cpu_block * gpu_blocks_per_cpu_block
        for i in range(gpu_blocks_per_cpu_block):
            cpu_blocks_in_gpu_block_size.append(i + base_block_id)

    # maybe skip a GPU block to test writing to the middle of a CPU block
    if gpu_to_cpu:
        gpu_blocks = gpu_blocks[gpu_blocks_per_cpu_block - 1:]
        cpu_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size[
            gpu_blocks_per_cpu_block - 1:]

    # set transfer direction
    if gpu_to_cpu:
        src_kv_caches = handler.gpu_tensors
        dst_kv_caches = handler.cpu_tensors
        src_spec_class = GPULoadStoreSpec
        dst_spec_class = CPULoadStoreSpec
        src_blocks = gpu_blocks
        dst_blocks = cpu_blocks
        src_blocks_in_gpu_block_size = gpu_blocks
        dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
        dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block
    else:
        src_kv_caches = handler.cpu_tensors
        dst_kv_caches = handler.gpu_tensors
        src_spec_class = CPULoadStoreSpec
        dst_spec_class = GPULoadStoreSpec
        src_blocks = cpu_blocks
        dst_blocks = gpu_blocks
        src_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
        dst_blocks_in_gpu_block_size = gpu_blocks
        dst_size_in_gpu_blocks = num_gpu_blocks

    # build dst -> src mapping
    dst_to_src = {}
    for src_block, dst_block in zip(src_blocks_in_gpu_block_size,
                                    dst_blocks_in_gpu_block_size):
        dst_to_src[dst_block] = src_block

    # build transfer specs
    src_spec = src_spec_class(src_blocks)
    dst_spec = dst_spec_class(dst_blocks)

    # clone src and dst tensors before transfer
    orig_src_caches = [x.clone() for x in src_kv_caches]
    orig_dst_caches = [x.clone() for x in dst_kv_caches]

    # call transfer function
    assert handler.transfer_async(1, (src_spec, dst_spec))
    assert set(handler.transfer_events.keys()) == {1}

    # wait for transfer to complete
    # (monotonic clock, so the deadline is immune to wall-clock adjustments)
    deadline = time.monotonic() + 10
    while time.monotonic() < deadline:
        finished = handler.get_finished()
        if finished:
            assert finished == [(1, True)]
            break
        time.sleep(0.1)
    else:
        # Without this, a transfer that never finishes would fall through
        # to the data checks and fail there with a misleading mismatch.
        pytest.fail("transfer did not complete within 10 seconds")

    # verify src tensors did not change
    for orig_tensor, tensor in zip(orig_src_caches, src_kv_caches):
        assert torch.equal(orig_tensor, tensor)

    # verify dst tensors
    for dst_block in range(dst_size_in_gpu_blocks):
        # None means this dst block was not a transfer target
        src_block_candidate = dst_to_src.get(dst_block)
        for src_cache, dst_cache, orig_dst_cache, kv_dim in zip(
                src_kv_caches, dst_kv_caches, orig_dst_caches,
                handler.kv_dim_before_num_blocks):
            if kv_dim:
                # iterate over key, value
                for i in range(2):
                    if src_block_candidate is not None:
                        expected_value = src_cache[i][src_block_candidate]
                    else:
                        expected_value = orig_dst_cache[i][dst_block]
                    torch.testing.assert_close(dst_cache[i][dst_block].cpu(),
                                               expected_value.cpu())
            else:
                if src_block_candidate is not None:
                    expected_value = src_cache[src_block_candidate]
                else:
                    expected_value = orig_dst_cache[dst_block]
                torch.testing.assert_close(dst_cache[dst_block].cpu(),
                                           expected_value.cpu())
vllm/v1/kv_offload/worker/cpu_gpu.py
0 → 100644
View file @
7ac67ea5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
numpy
as
np
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.attention
import
AttentionBackend
from
vllm.logger
import
init_logger
from
vllm.utils
import
is_pin_memory_available
from
vllm.v1.kv_offload.mediums
import
CPULoadStoreSpec
,
GPULoadStoreSpec
from
vllm.v1.kv_offload.worker.worker
import
(
OffloadingHandler
,
TransferResult
,
TransferSpec
)
# Module-level logger, obtained from vLLM's init_logger for this module.
logger = init_logger(__name__)
def expand_block_ids(block_ids: np.ndarray,
                     block_size_factor: int,
                     output: np.ndarray,
                     skip_count: int = 0):
    """
    Expand an array of block IDs into the IDs of their sub-blocks,
    assuming each block is composed of block_size_factor consecutive
    sub-blocks. The result is written to the beginning of the output array
    (nothing is returned).

    The first skip_count sub-blocks of the first block are omitted.
    skip_count must satisfy 0 <= skip_count < block_size_factor.

    For example, if block_ids = [0, 1, 3] and block_size_factor = 4,
    then it yields [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]
    since 0 maps to [0, 1, 2, 3]
    1 maps to [4, 5, 6, 7]
    and 3 maps to [12, 13, 14, 15]

    With skip_count = 2 the same input yields
    [2, 3, 4, 5, 6, 7, 12, 13, 14, 15].

    Args:
        block_ids: 1-D array of block IDs to expand.
        block_size_factor: number of sub-blocks per block.
        output: 1-D array receiving the sub-block IDs. It must hold at
            least len(block_ids) * block_size_factor - skip_count items;
            any items beyond that are left untouched.
        skip_count: number of leading sub-blocks to omit.

    Raises:
        AssertionError: if skip_count is outside [0, block_size_factor).
        ValueError: if output is too short to hold the result.
    """
    # A negative skip_count would otherwise silently produce negative IDs.
    assert 0 <= skip_count < block_size_factor

    block_ids = np.asarray(block_ids)
    if block_ids.size == 0:
        return

    # One row per block holding that block's sub-block IDs; flattening in
    # row-major order keeps the order of block_ids. Since skip_count is
    # smaller than block_size_factor, dropping the first skip_count items
    # only ever trims the first block.
    sub_block_ids = (block_ids[:, None] * block_size_factor +
                     np.arange(block_size_factor)).reshape(-1)[skip_count:]
    output[:sub_block_ids.size] = sub_block_ids
class CpuGpuOffloadingHandler(OffloadingHandler):
    """
    Offloading handler that transfers KV-cache blocks between the given GPU
    caches and CPU tensors allocated by this handler.

    A CPU block spans block_size_factor GPU-sized blocks; the CPU tensors
    are indexed in GPU-block-sized units. Transfers are enqueued on a
    dedicated CUDA stream per direction and tracked with one CUDA event per
    job.
    """

    def __init__(self, gpu_block_size: int, cpu_block_size: int,
                 num_cpu_blocks: int, gpu_caches: dict[str, torch.Tensor],
                 attn_backends: dict[str, type[AttentionBackend]]):
        """
        Args:
            gpu_block_size: size of a GPU block.
            cpu_block_size: size of a CPU block; must be a multiple of
                gpu_block_size.
            num_cpu_blocks: number of CPU blocks to allocate per layer.
            gpu_caches: layer name -> GPU KV cache tensor.
            attn_backends: layer name -> attention backend class, used to
                find out where the num_blocks dimension sits in the layer's
                KV cache shape.
        """
        assert cpu_block_size % gpu_block_size == 0
        self.block_size_factor = cpu_block_size // gpu_block_size

        # one cuda stream per direction (gpu->cpu, cpu->gpu)
        self.d2h_stream = torch.cuda.Stream()
        self.h2d_stream = torch.cuda.Stream()

        # cuda event of every in-flight job, keyed by job_id
        self.transfer_events: dict[int, torch.cuda.Event] = {}
        # cuda events of finished jobs, kept for re-use
        self.events_pool: list[torch.cuda.Event] = []

        pin_memory = is_pin_memory_available()

        logger.info("Allocating %d CPU tensors...", len(gpu_caches))
        # The three lists below are parallel: one entry per layer, in the
        # iteration order of gpu_caches.
        self.gpu_tensors: list[torch.Tensor] = []
        self.cpu_tensors: list[torch.Tensor] = []
        self.kv_dim_before_num_blocks: list[bool] = []

        # CPU tensors are sized in GPU-block-sized units
        cpu_sub_block_count = num_cpu_blocks * self.block_size_factor
        # arbitrary marker value used to locate the num_blocks dimension
        probe_num_blocks = 1234

        for layer_name, gpu_tensor in gpu_caches.items():
            self.gpu_tensors.append(gpu_tensor)
            gpu_shape = gpu_tensor.shape

            probe_shape = attn_backends[layer_name].get_kv_cache_shape(
                num_blocks=probe_num_blocks,
                block_size=16,
                num_kv_heads=8,
                head_size=256)

            # either (num_blocks, ...) or (2, num_blocks, ...)
            kv_dim_first = probe_shape[0] != probe_num_blocks
            if kv_dim_first:
                assert probe_shape[0] == 2
                assert probe_shape[1] == probe_num_blocks
                assert gpu_shape[0] == 2
            self.kv_dim_before_num_blocks.append(kv_dim_first)

            # same shape as the GPU cache, except for the number of blocks
            cpu_shape = list(gpu_shape)
            cpu_shape[1 if kv_dim_first else 0] = cpu_sub_block_count

            logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
            cpu_tensor = torch.zeros(cpu_shape,
                                     dtype=gpu_tensor.dtype,
                                     device="cpu",
                                     pin_memory=pin_memory)
            self.cpu_tensors.append(cpu_tensor)

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """
        Enqueue an asynchronous transfer of blocks.

        Args:
            job_id: identifier under which the job's event is tracked.
            spec: (source spec, destination spec) pair; one side must be a
                CPULoadStoreSpec and the other a GPULoadStoreSpec.

        Returns:
            True (the transfer was enqueued).
        """
        src_spec, dst_spec = spec
        if isinstance(src_spec, CPULoadStoreSpec):
            # cpu -> gpu
            assert isinstance(dst_spec, GPULoadStoreSpec)
            stream = self.h2d_stream
            src_tensors, dst_tensors = self.cpu_tensors, self.gpu_tensors
            src_block_size_factor = self.block_size_factor
            dst_block_size_factor = 1
        else:
            # gpu -> cpu
            assert isinstance(src_spec, GPULoadStoreSpec)
            assert isinstance(dst_spec, CPULoadStoreSpec)
            stream = self.d2h_stream
            src_tensors, dst_tensors = self.gpu_tensors, self.cpu_tensors
            src_block_size_factor = 1
            dst_block_size_factor = self.block_size_factor

        src_blocks = src_spec.block_ids
        dst_blocks = dst_spec.block_ids
        assert src_blocks.ndim == 1
        assert dst_blocks.ndim == 1

        # When the source does not fill whole destination blocks, the
        # leading sub-blocks of the first destination block are skipped.
        dst_sub_blocks_to_skip = -src_blocks.size % dst_block_size_factor
        src_sub_block_count = src_blocks.size * src_block_size_factor
        dst_sub_block_count = (dst_blocks.size * dst_block_size_factor -
                               dst_sub_blocks_to_skip)
        assert src_sub_block_count == dst_sub_block_count

        # rows of (src sub-block id, dst sub-block id)
        src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64)
        expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0])
        expand_block_ids(dst_blocks,
                         dst_block_size_factor,
                         src_to_dst[:, 1],
                         skip_count=dst_sub_blocks_to_skip)
        src_to_dst_tensor = torch.from_numpy(src_to_dst)

        # re-use a pooled event when one is available
        if self.events_pool:
            event = self.events_pool.pop()
        else:
            event = torch.cuda.Event()

        with torch.cuda.stream(stream):
            for src_tensor, dst_tensor, kv_dim in zip(
                    src_tensors, dst_tensors, self.kv_dim_before_num_blocks):
                if not kv_dim:
                    ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
                    continue
                # leading dimension holds key (0) and value (1) caches
                for kv_idx in (0, 1):
                    ops.swap_blocks(src_tensor[kv_idx], dst_tensor[kv_idx],
                                    src_to_dst_tensor)
            event.record(stream)

        self.transfer_events[job_id] = event

        # success
        return True

    def get_finished(self) -> list[TransferResult]:
        """
        Collect the jobs whose CUDA event has completed.

        Finished jobs are dropped from transfer_events and their events are
        put back into events_pool.

        Returns:
            A list of (job_id, True), one entry per finished job.
        """
        finished: list[TransferResult] = [
            (job_id, True)
            for job_id, event in self.transfer_events.items()
            if event.query()
        ]
        # mutate the dict only after iterating over it
        for job_id, _ in finished:
            self.events_pool.append(self.transfer_events.pop(job_id))
        return finished
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment