Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b89443b8
Unverified
Commit
b89443b8
authored
Jan 07, 2026
by
Kfir Toledo
Committed by
GitHub
Jan 07, 2026
Browse files
[KVConnector]: Enable Cross-layers KV cache layout for MultiConnector (#30761)
Signed-off-by:
Kfir Toledo
<
kfir.toledo@ibm.com
>
parent
1d9e9ae8
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
71 additions
and
11 deletions
+71
-11
tests/v1/kv_connector/unit/test_multi_connector.py
tests/v1/kv_connector/unit/test_multi_connector.py
+45
-0
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+8
-8
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+14
-1
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
...buted/kv_transfer/kv_connector/v1/offloading_connector.py
+4
-2
No files found.
tests/v1/kv_connector/unit/test_multi_connector.py
View file @
b89443b8
...
@@ -49,6 +49,33 @@ class MockConnector(KVConnectorBase_V1):
...
@@ -49,6 +49,33 @@ class MockConnector(KVConnectorBase_V1):
)
->
KVConnectorStats
|
None
:
)
->
KVConnectorStats
|
None
:
return
MockConnectorStats
(
data
=
data
)
if
data
is
not
None
else
None
return
MockConnectorStats
(
data
=
data
)
if
data
is
not
None
else
None
def
start_load_kv
(
self
,
forward_context
,
**
kwargs
):
pass
def
wait_for_layer_load
(
self
,
layer_name
):
pass
def
save_kv_layer
(
self
,
layer_name
,
kv_layer
,
attn_metadata
,
**
kwargs
):
pass
def
wait_for_save
(
self
):
pass
def
build_connector_meta
(
self
,
scheduler_output
):
return
None
def
get_num_new_matched_tokens
(
self
,
request
,
num_computed_tokens
):
return
(
0
,
False
)
def
update_state_after_alloc
(
self
,
request
,
blocks
,
num_tokens
)
->
None
:
pass
class
MockCrossLayerConnector
(
MockConnector
):
@
property
def
prefer_cross_layer_blocks
(
self
)
->
bool
:
return
True
# Register the mock connector
# Register the mock connector
KVConnectorFactory
.
register_connector
(
"MockConnector"
,
__name__
,
MockConnector
.
__name__
)
KVConnectorFactory
.
register_connector
(
"MockConnector"
,
__name__
,
MockConnector
.
__name__
)
...
@@ -601,3 +628,21 @@ class TestMultiConnectorStats:
...
@@ -601,3 +628,21 @@ class TestMultiConnectorStats:
# One non-empty
# One non-empty
stats
.
data
[
"NixlConnector"
].
data
[
"transfer_duration"
].
append
(
1.0
)
stats
.
data
[
"NixlConnector"
].
data
[
"transfer_duration"
].
append
(
1.0
)
assert
not
stats
.
is_empty
()
assert
not
stats
.
is_empty
()
class
TestMultiConnectorPreferCrossLayerBlocks
:
def
test_all_connectors_prefer_cross_layer_blocks
(
self
):
mc
=
MultiConnector
.
__new__
(
MultiConnector
)
mc
.
_connectors
=
[
MockCrossLayerConnector
.
__new__
(
MockCrossLayerConnector
),
MockCrossLayerConnector
.
__new__
(
MockCrossLayerConnector
),
]
assert
mc
.
prefer_cross_layer_blocks
is
True
def
test_mixed_connectors_do_not_prefer_cross_layer_blocks
(
self
):
mc
=
MultiConnector
.
__new__
(
MultiConnector
)
mc
.
_connectors
=
[
MockCrossLayerConnector
.
__new__
(
MockCrossLayerConnector
),
MockConnector
.
__new__
(
MockConnector
),
# default False
]
assert
mc
.
prefer_cross_layer_blocks
is
False
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
b89443b8
...
@@ -38,7 +38,7 @@ The class provides the following primitives:
...
@@ -38,7 +38,7 @@ The class provides the following primitives:
import
enum
import
enum
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections.abc
import
Callable
,
Iterable
from
collections.abc
import
Callable
,
Iterable
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Literal
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
Optional
import
torch
import
torch
...
@@ -144,15 +144,15 @@ class KVConnectorMetadata(ABC): # noqa: B024
...
@@ -144,15 +144,15 @@ class KVConnectorMetadata(ABC): # noqa: B024
class
KVConnectorBase_V1
(
ABC
):
class
KVConnectorBase_V1
(
ABC
):
"""
"""
Base class for KV connectors.
Base class for KV connectors.
Attributes:
prefer_cross_layer_blocks (bool): Indicates whether this connector
prefers KV blocks that hold KV data for all layers (for speeding
up KV data transfers).
Defaults to False.
"""
"""
prefer_cross_layer_blocks
:
ClassVar
[
bool
]
=
False
@
property
def
prefer_cross_layer_blocks
(
self
)
->
bool
:
"""
Indicates whether this connector prefers KV blocks that hold KV data for all
layers, which can speed up KV data transfers. Defaults to False.
"""
return
False
def
__init__
(
def
__init__
(
self
,
self
,
...
...
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
b89443b8
...
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
...
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
import
torch
import
torch
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.abstract
import
AttentionBackend
,
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.kv_transfer
import
KVTransferConfig
from
vllm.config.kv_transfer
import
KVTransferConfig
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBaseType
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBaseType
...
@@ -138,6 +138,12 @@ class MultiConnector(KVConnectorBase_V1):
...
@@ -138,6 +138,12 @@ class MultiConnector(KVConnectorBase_V1):
# Propagated from scheduler to worker side via the connector metadata.
# Propagated from scheduler to worker side via the connector metadata.
self
.
_extra_async_saves
:
dict
[
str
,
int
]
=
{}
self
.
_extra_async_saves
:
dict
[
str
,
int
]
=
{}
@
property
def
prefer_cross_layer_blocks
(
self
)
->
bool
:
if
not
self
.
_connectors
:
return
False
return
all
(
c
.
prefer_cross_layer_blocks
for
c
in
self
.
_connectors
)
@
classmethod
@
classmethod
def
_get_connector_classes_and_configs
(
def
_get_connector_classes_and_configs
(
cls
,
vllm_config
:
"VllmConfig"
cls
,
vllm_config
:
"VllmConfig"
...
@@ -164,6 +170,13 @@ class MultiConnector(KVConnectorBase_V1):
...
@@ -164,6 +170,13 @@ class MultiConnector(KVConnectorBase_V1):
)
)
return
ret
return
ret
def
register_cross_layers_kv_cache
(
self
,
kv_cache
:
torch
.
Tensor
,
attn_backend
:
type
[
AttentionBackend
]
):
# Register on all connectors
for
c
in
self
.
_connectors
:
c
.
register_cross_layers_kv_cache
(
kv_cache
,
attn_backend
)
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
for
c
in
self
.
_connectors
:
for
c
in
self
.
_connectors
:
c
.
register_kv_caches
(
kv_caches
)
c
.
register_kv_caches
(
kv_caches
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
View file @
b89443b8
...
@@ -4,7 +4,7 @@ from collections import defaultdict
...
@@ -4,7 +4,7 @@ from collections import defaultdict
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
itertools
import
islice
from
itertools
import
islice
from
typing
import
Any
,
ClassVar
from
typing
import
Any
import
torch
import
torch
...
@@ -44,7 +44,9 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):
...
@@ -44,7 +44,9 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):
class
OffloadingConnector
(
KVConnectorBase_V1
):
class
OffloadingConnector
(
KVConnectorBase_V1
):
prefer_cross_layer_blocks
:
ClassVar
[
bool
]
=
True
@
property
def
prefer_cross_layer_blocks
(
self
)
->
bool
:
return
True
def
__init__
(
def
__init__
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment