Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cfaf4668
Unverified
Commit
cfaf4668
authored
Mar 13, 2026
by
Or Ozeri
Committed by
GitHub
Mar 13, 2026
Browse files
[kv_offload+HMA][1/N]: Support multiple KV groups in OffloadingSpec (#36610)
Signed-off-by:
Or Ozeri
<
oro@il.ibm.com
>
parent
99a57bdf
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
80 additions
and
22 deletions
+80
-22
tests/v1/kv_connector/unit/test_offloading_connector.py
tests/v1/kv_connector/unit/test_offloading_connector.py
+37
-5
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
...buted/kv_transfer/kv_connector/v1/offloading_connector.py
+5
-3
vllm/v1/kv_offload/cpu.py
vllm/v1/kv_offload/cpu.py
+10
-6
vllm/v1/kv_offload/factory.py
vllm/v1/kv_offload/factory.py
+1
-1
vllm/v1/kv_offload/spec.py
vllm/v1/kv_offload/spec.py
+27
-7
No files found.
tests/v1/kv_connector/unit/test_offloading_connector.py
View file @
cfaf4668
...
...
@@ -26,8 +26,13 @@ from vllm.v1.core.kv_cache_utils import (
get_request_block_hasher
,
init_none_hash
,
)
from
vllm.v1.core.sched.async_scheduler
import
AsyncScheduler
from
vllm.v1.core.sched.scheduler
import
Scheduler
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
)
from
vllm.v1.kv_offload.abstract
import
(
LoadStoreSpec
,
OffloadingEvent
,
...
...
@@ -43,11 +48,11 @@ from vllm.v1.kv_offload.worker.worker import (
)
from
vllm.v1.outputs
import
EMPTY_MODEL_RUNNER_OUTPUT
,
KVConnectorOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.structured_output
import
StructuredOutputManager
from
.utils
import
(
EOS_TOKEN_ID
,
create_model_runner_output
,
create_scheduler
,
create_vllm_config
,
)
...
...
@@ -175,10 +180,37 @@ class RequestRunner:
},
)
self
.
scheduler
:
Scheduler
=
create_scheduler
(
vllm_config
,
num_blocks
=
num_gpu_blocks
block_size
=
vllm_config
.
cache_config
.
block_size
kv_cache_config
=
KVCacheConfig
(
num_blocks
=
num_gpu_blocks
,
kv_cache_tensors
=
[],
kv_cache_groups
=
[
KVCacheGroupSpec
(
[
"layer"
],
FullAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
1
,
head_size
=
1
,
dtype
=
torch
.
float32
,
),
)
],
)
vllm_config
.
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
self
.
num_kv_groups
=
len
(
kv_cache_config
.
kv_cache_groups
)
scheduler_cls
=
AsyncScheduler
if
async_scheduling
else
Scheduler
self
.
scheduler
=
scheduler_cls
(
vllm_config
=
vllm_config
,
kv_cache_config
=
kv_cache_config
,
log_stats
=
True
,
structured_output_manager
=
StructuredOutputManager
(
vllm_config
),
block_size
=
block_size
,
)
self
.
worker_connector
=
OffloadingConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
,
kv_cache_config
)
self
.
worker_connector
=
OffloadingConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
# register worker kv_caches to enable OffloadingWorker creations
self
.
worker_connector
.
register_cross_layers_kv_cache
(
...
...
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
View file @
cfaf4668
...
...
@@ -126,6 +126,7 @@ class OffloadingConnector(KVConnectorBase_V1):
):
super
().
__init__
(
vllm_config
,
role
,
kv_cache_config
)
assert
kv_cache_config
is
not
None
spec
=
OffloadingSpecFactory
.
create_spec
(
vllm_config
,
kv_cache_config
)
self
.
connector_scheduler
:
OffloadingConnectorScheduler
|
None
=
None
...
...
@@ -245,9 +246,10 @@ class OffloadingConnectorScheduler:
"""Implementation of Scheduler side methods"""
def
__init__
(
self
,
spec
:
OffloadingSpec
):
self
.
gpu_block_size
=
spec
.
gpu_block_size
self
.
offloaded_block_size
=
spec
.
offloaded_block_size
self
.
block_size_factor
=
self
.
offloaded_block_size
//
self
.
gpu_block_size
assert
len
(
spec
.
gpu_block_size
)
==
1
self
.
gpu_block_size
=
spec
.
gpu_block_size
[
0
]
self
.
offloaded_block_size
=
self
.
gpu_block_size
*
spec
.
block_size_factor
self
.
block_size_factor
=
spec
.
block_size_factor
self
.
manager
:
OffloadingManager
=
spec
.
get_manager
()
self
.
_requests
:
dict
[
ReqId
,
Request
]
=
{}
...
...
vllm/v1/kv_offload/cpu.py
View file @
cfaf4668
...
...
@@ -42,10 +42,8 @@ class CPUOffloadingSpec(OffloadingSpec):
*
len
(
kv_cache_config
.
kv_cache_tensors
)
*
vllm_config
.
parallel_config
.
world_size
)
kv_bytes_per_offloaded_block
=
kv_bytes_per_block
*
(
self
.
offloaded_block_size
//
self
.
gpu_block_size
)
kv_bytes_per_offloaded_block
=
kv_bytes_per_block
*
self
.
block_size_factor
self
.
num_blocks
=
(
int
(
cpu_bytes_to_use
)
//
kv_bytes_per_offloaded_block
if
kv_bytes_per_offloaded_block
>
0
...
...
@@ -67,8 +65,11 @@ class CPUOffloadingSpec(OffloadingSpec):
kv_events_config
is
not
None
and
kv_events_config
.
enable_kv_cache_events
)
assert
len
(
self
.
gpu_block_size
)
==
1
gpu_block_size
=
self
.
gpu_block_size
[
0
]
offloaded_block_size
=
gpu_block_size
*
self
.
block_size_factor
backend
=
CPUBackend
(
block_size
=
self
.
offloaded_block_size
,
num_blocks
=
self
.
num_blocks
block_size
=
offloaded_block_size
,
num_blocks
=
self
.
num_blocks
)
if
self
.
eviction_policy
==
"lru"
:
...
...
@@ -111,10 +112,13 @@ class CPUOffloadingSpec(OffloadingSpec):
"CPU Offloading is currently only supported on CUDA-alike GPUs"
)
assert
len
(
self
.
gpu_block_size
)
==
1
gpu_block_size
=
self
.
gpu_block_size
[
0
]
self
.
_handlers
=
CpuGpuOffloadingHandlers
(
attn_backends
=
attn_backends
,
gpu_block_size
=
self
.
gpu_block_size
,
cpu_block_size
=
self
.
offloaded_
block_size
,
gpu_block_size
=
gpu_block_size
,
cpu_block_size
=
gpu_block_size
*
self
.
block_size
_factor
,
num_cpu_blocks
=
self
.
num_blocks
,
gpu_caches
=
kv_caches
,
)
...
...
vllm/v1/kv_offload/factory.py
View file @
cfaf4668
...
...
@@ -33,7 +33,7 @@ class OffloadingSpecFactory:
def
create_spec
(
cls
,
config
:
"VllmConfig"
,
kv_cache_config
:
"KVCacheConfig
| None
"
,
kv_cache_config
:
"KVCacheConfig"
,
)
->
OffloadingSpec
:
kv_transfer_config
=
config
.
kv_transfer_config
assert
kv_transfer_config
is
not
None
...
...
vllm/v1/kv_offload/spec.py
View file @
cfaf4668
...
...
@@ -21,9 +21,7 @@ logger = init_logger(__name__)
class
OffloadingSpec
(
ABC
):
"""Spec for an offloading connector"""
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
kv_cache_config
:
"KVCacheConfig | None"
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
kv_cache_config
:
"KVCacheConfig"
):
logger
.
warning
(
"Initializing OffloadingSpec. This API is experimental and "
"subject to change in the future as we iterate the design."
...
...
@@ -35,12 +33,34 @@ class OffloadingSpec(ABC):
assert
kv_transfer_config
is
not
None
self
.
extra_config
=
kv_transfer_config
.
kv_connector_extra_config
self
.
gpu_block_size
=
vllm_config
.
cache_config
.
block_size
self
.
offloaded_block_size
=
int
(
self
.
extra_config
.
get
(
"block_size"
,
self
.
gpu_block_size
)
# block size used by vLLM for hashing request tokens for the sake
# of enabling prefix caching
self
.
hash_block_size
=
vllm_config
.
cache_config
.
block_size
# gpu block size per group
self
.
gpu_block_size
:
tuple
[
int
,
...]
=
tuple
(
kv_cache_group
.
kv_cache_spec
.
block_size
for
kv_cache_group
in
kv_cache_config
.
kv_cache_groups
)
assert
self
.
offloaded_block_size
%
self
.
gpu_block_size
==
0
for
block_size
in
self
.
gpu_block_size
:
assert
block_size
%
self
.
hash_block_size
==
0
# offloaded_block_size / gpu_block_size
self
.
block_size_factor
:
int
=
1
offloaded_block_size
=
self
.
extra_config
.
get
(
"block_size"
)
if
offloaded_block_size
is
not
None
:
offloaded_block_size_int
=
int
(
offloaded_block_size
)
gpu_block_sizes
=
set
(
self
.
gpu_block_size
)
assert
len
(
gpu_block_sizes
)
==
1
,
(
"If 'block_size' is specified in kv_connector_extra_config, "
"there must be at least one KV cache group, "
"and all groups must have the same block size."
)
gpu_block_size
=
gpu_block_sizes
.
pop
()
assert
offloaded_block_size_int
%
gpu_block_size
==
0
self
.
block_size_factor
=
offloaded_block_size_int
//
gpu_block_size
@
abstractmethod
def
get_manager
(
self
)
->
OffloadingManager
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment