Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
58279c60
Unverified
Commit
58279c60
authored
Nov 04, 2025
by
Mark McLoughlin
Committed by
GitHub
Nov 03, 2025
Browse files
[KV Connector] Make KVCacheConfig an explicit constructor argument (#27887)
Signed-off-by:
Mark McLoughlin
<
markmc@redhat.com
>
parent
2f84ae1f
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
410 additions
and
43 deletions
+410
-43
tests/v1/kv_connector/unit/test_backwards_compatibility.py
tests/v1/kv_connector/unit/test_backwards_compatibility.py
+275
-0
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+1
-1
vllm/distributed/kv_transfer/kv_connector/factory.py
vllm/distributed/kv_transfer/kv_connector/factory.py
+33
-8
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+15
-1
vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py
...ted/kv_transfer/kv_connector/v1/decode_bench_connector.py
+9
-3
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
...tributed/kv_transfer/kv_connector/v1/lmcache_connector.py
+10
-2
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+11
-3
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+10
-2
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
...buted/kv_transfer/kv_connector/v1/offloading_connector.py
+8
-2
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
...ted/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+13
-3
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
...d/kv_transfer/kv_connector/v1/shared_storage_connector.py
+13
-3
vllm/distributed/kv_transfer/kv_transfer_state.py
vllm/distributed/kv_transfer/kv_transfer_state.py
+8
-3
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+3
-9
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+1
-3
No files found.
tests/v1/kv_connector/unit/test_backwards_compatibility.py
0 → 100644
View file @
58279c60
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for backwards compatibility with external KV connector implementations.
This test ensures that external connectors (loaded via kv_connector_module_path)
implemented with the old signature continue to work:
- Old signature: __init__(self, vllm_config, role)
- New signature: __init__(self, vllm_config, role, kv_cache_config)
"""
from
typing
import
TYPE_CHECKING
from
unittest.mock
import
patch
import
pytest
from
vllm.distributed.kv_transfer.kv_connector.factory
import
KVConnectorFactory
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
KVConnectorRole
,
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
.utils
import
create_scheduler
,
create_vllm_config
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.request
import
Request
class
OldStyleTestConnector
(
KVConnectorBase_V1
):
"""
Test connector using the old signature with 2 required arguments.
This simulates external connectors that haven't been updated yet.
"""
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
# Old-style call to super().__init__ with only 2 arguments
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
)
def
get_num_new_matched_tokens
(
self
,
request
:
"Request"
,
num_computed_tokens
:
int
)
->
tuple
[
int
|
None
,
bool
]:
return
0
,
False
def
update_state_after_alloc
(
self
,
request
:
"Request"
,
blocks
:
"KVCacheBlocks"
,
num_external_tokens
:
int
,
):
pass
def
build_connector_meta
(
self
,
scheduler_output
:
SchedulerOutput
):
return
None
def
start_load_kv
(
self
,
forward_context
:
"ForwardContext"
,
**
kwargs
)
->
None
:
pass
def
wait_for_layer_load
(
self
,
layer_name
:
str
)
->
None
:
pass
def
save_kv_layer
(
self
,
layer_name
:
str
,
kv_layer
,
attn_metadata
:
"AttentionMetadata"
,
**
kwargs
,
)
->
None
:
pass
def
wait_for_save
(
self
):
pass
class
NewStyleTestConnector
(
KVConnectorBase_V1
):
"""
Test connector using the new signature with 3 required arguments.
"""
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
kv_cache_config
:
"KVCacheConfig"
,
):
# New-style call to super().__init__ with all 3 arguments
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
,
kv_cache_config
=
kv_cache_config
)
def
get_num_new_matched_tokens
(
self
,
request
:
"Request"
,
num_computed_tokens
:
int
)
->
tuple
[
int
|
None
,
bool
]:
return
0
,
False
def
update_state_after_alloc
(
self
,
request
:
"Request"
,
blocks
:
"KVCacheBlocks"
,
num_external_tokens
:
int
,
):
pass
def
build_connector_meta
(
self
,
scheduler_output
:
SchedulerOutput
):
return
None
def
start_load_kv
(
self
,
forward_context
:
"ForwardContext"
,
**
kwargs
)
->
None
:
pass
def
wait_for_layer_load
(
self
,
layer_name
:
str
)
->
None
:
pass
def
save_kv_layer
(
self
,
layer_name
:
str
,
kv_layer
,
attn_metadata
:
"AttentionMetadata"
,
**
kwargs
,
)
->
None
:
pass
def
wait_for_save
(
self
):
pass
@
pytest
.
mark
.
parametrize
(
"role"
,
[
KVConnectorRole
.
SCHEDULER
,
KVConnectorRole
.
WORKER
])
def
test_external_old_signature_factory_instantiation
(
role
):
"""
Test that external connectors with old signature (2 required args) loaded
via kv_connector_module_path are correctly instantiated with backwards
compatibility support.
"""
vllm_config
=
create_vllm_config
()
vllm_config
.
kv_transfer_config
.
kv_connector
=
"OldStyleTestConnector"
vllm_config
.
kv_transfer_config
.
kv_connector_module_path
=
(
"tests.v1.kv_connector.unit.test_backwards_compatibility"
)
scheduler
=
create_scheduler
(
vllm_config
)
kv_cache_config
=
scheduler
.
kv_cache_config
connector
=
KVConnectorFactory
.
create_connector
(
vllm_config
,
role
,
kv_cache_config
)
assert
connector
is
not
None
assert
isinstance
(
connector
,
OldStyleTestConnector
)
assert
connector
.
role
==
role
assert
connector
.
_kv_cache_config
is
None
@
pytest
.
mark
.
parametrize
(
"role"
,
[
KVConnectorRole
.
SCHEDULER
,
KVConnectorRole
.
WORKER
])
def
test_external_new_signature_factory_instantiation
(
role
):
"""
Test that external connectors with new signature (3 required args) loaded
via kv_connector_module_path are correctly instantiated.
"""
vllm_config
=
create_vllm_config
()
vllm_config
.
kv_transfer_config
.
kv_connector
=
"NewStyleTestConnector"
vllm_config
.
kv_transfer_config
.
kv_connector_module_path
=
(
"tests.v1.kv_connector.unit.test_backwards_compatibility"
)
scheduler
=
create_scheduler
(
vllm_config
)
kv_cache_config
=
scheduler
.
kv_cache_config
connector
=
KVConnectorFactory
.
create_connector
(
vllm_config
,
role
,
kv_cache_config
)
assert
connector
is
not
None
assert
isinstance
(
connector
,
NewStyleTestConnector
)
assert
connector
.
role
==
role
assert
connector
.
_kv_cache_config
is
not
None
assert
connector
.
_kv_cache_config
==
kv_cache_config
@
pytest
.
mark
.
parametrize
(
"role"
,
[
KVConnectorRole
.
SCHEDULER
,
KVConnectorRole
.
WORKER
])
def
test_old_signature_super_init
(
role
):
"""
Test that old-style connectors can call super().__init__() without
kv_cache_config parameter.
"""
vllm_config
=
create_vllm_config
()
connector
=
OldStyleTestConnector
(
vllm_config
,
role
)
assert
connector
is
not
None
assert
connector
.
role
==
role
assert
connector
.
_kv_cache_config
is
None
def
test_old_signature_super_init_with_kwargs
():
"""
Test that old-style connectors can call super().__init__() with keyword
arguments in different orders.
"""
vllm_config
=
create_vllm_config
()
# Test with vllm_config= and role= kwargs
connector1
=
OldStyleTestConnector
(
vllm_config
=
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
)
assert
connector1
is
not
None
assert
connector1
.
_kv_cache_config
is
None
# Test with role= and vllm_config= in reversed order
connector2
=
OldStyleTestConnector
(
role
=
KVConnectorRole
.
WORKER
,
vllm_config
=
vllm_config
)
assert
connector2
is
not
None
assert
connector2
.
_kv_cache_config
is
None
def
test_internal_connector_uses_new_signature
():
"""
Test that internal connectors (registered in factory) always use the new
signature and get kv_cache_config.
"""
from
vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector
import
(
SharedStorageConnector
,
)
vllm_config
=
create_vllm_config
()
vllm_config
.
kv_transfer_config
.
kv_connector
=
"SharedStorageConnector"
scheduler
=
create_scheduler
(
vllm_config
)
kv_cache_config
=
scheduler
.
kv_cache_config
connector
=
KVConnectorFactory
.
create_connector
(
vllm_config
,
KVConnectorRole
.
SCHEDULER
,
kv_cache_config
)
assert
connector
is
not
None
assert
isinstance
(
connector
,
SharedStorageConnector
)
assert
connector
.
_kv_cache_config
is
not
None
assert
connector
.
_kv_cache_config
==
kv_cache_config
def
test_signature_detection_with_mocking
():
"""
Test that the factory correctly applies compat_sig flag returned from
_get_connector_class_with_compat.
"""
vllm_config
=
create_vllm_config
()
scheduler
=
create_scheduler
(
vllm_config
)
kv_cache_config
=
scheduler
.
kv_cache_config
# Mock _get_connector_class_with_compat to return old-style connector
with
patch
.
object
(
KVConnectorFactory
,
"_get_connector_class_with_compat"
,
return_value
=
(
OldStyleTestConnector
,
True
),
):
old_connector
=
KVConnectorFactory
.
create_connector
(
vllm_config
,
KVConnectorRole
.
SCHEDULER
,
kv_cache_config
)
assert
old_connector
is
not
None
assert
isinstance
(
old_connector
,
OldStyleTestConnector
)
assert
old_connector
.
_kv_cache_config
is
None
# Mock _get_connector_class_with_compat to return new-style connector
with
patch
.
object
(
KVConnectorFactory
,
"_get_connector_class_with_compat"
,
return_value
=
(
NewStyleTestConnector
,
False
),
):
new_connector
=
KVConnectorFactory
.
create_connector
(
vllm_config
,
KVConnectorRole
.
SCHEDULER
,
kv_cache_config
)
assert
new_connector
is
not
None
assert
isinstance
(
new_connector
,
NewStyleTestConnector
)
assert
new_connector
.
_kv_cache_config
is
not
None
assert
new_connector
.
_kv_cache_config
==
kv_cache_config
tests/v1/kv_connector/unit/utils.py
View file @
58279c60
...
...
@@ -254,7 +254,7 @@ def create_model_runner_output(
class
TestSharedStorageConnector
(
SharedStorageConnector
):
def
__init__
(
self
,
config
:
VllmConfig
,
role
):
def
__init__
(
self
,
config
:
VllmConfig
,
role
,
kv_cache_config
):
self
.
name
=
config
.
kv_transfer_config
.
kv_connector_extra_config
[
"name"
]
self
.
_connector
=
SharedStorageConnector
(
config
,
role
)
self
.
call_record
:
dict
[
str
,
int
]
=
defaultdict
(
int
)
...
...
vllm/distributed/kv_transfer/kv_connector/factory.py
View file @
58279c60
...
...
@@ -3,10 +3,9 @@
import
importlib
from
collections.abc
import
Callable
from
typing
import
TYPE_CHECKING
,
cast
from
typing
import
TYPE_CHECKING
,
Optional
,
cast
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.base
import
(
KVConnectorBase
,
KVConnectorBaseType
,
...
...
@@ -16,9 +15,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
supports_hma
,
)
from
vllm.logger
import
init_logger
from
vllm.utils.func_utils
import
supports_kw
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.config.kv_transfer
import
KVTransferConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
logger
=
init_logger
(
__name__
)
...
...
@@ -41,8 +43,9 @@ class KVConnectorFactory:
@
classmethod
def
create_connector
(
cls
,
config
:
VllmConfig
,
config
:
"
VllmConfig
"
,
role
:
KVConnectorRole
,
kv_cache_config
:
Optional
[
"KVCacheConfig"
]
=
None
,
)
->
KVConnectorBase
:
if
not
envs
.
VLLM_USE_V1
:
raise
ValueError
(
...
...
@@ -53,7 +56,9 @@ class KVConnectorFactory:
kv_transfer_config
=
config
.
kv_transfer_config
if
kv_transfer_config
is
None
:
raise
ValueError
(
"kv_transfer_config must be set to create a connector"
)
connector_cls
=
cls
.
get_connector_class
(
kv_transfer_config
)
connector_cls
,
compat_sig
=
cls
.
_get_connector_class_with_compat
(
kv_transfer_config
)
# check if the connector supports HMA
hma_enabled
=
not
config
.
scheduler_config
.
disable_hybrid_kv_cache_manager
...
...
@@ -76,7 +81,12 @@ class KVConnectorFactory:
# - Co-locate with worker process
# - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation
if
compat_sig
:
# Old signature: __init__(self, vllm_config, role)
return
connector_cls
(
config
,
role
)
else
:
# New signature: __init__(self, vllm_config, role, kv_cache_config)
return
connector_cls
(
config
,
role
,
kv_cache_config
)
@
classmethod
def
get_connector_class_by_name
(
...
...
@@ -97,13 +107,13 @@ class KVConnectorFactory:
return
cls
.
_registry
[
connector_name
]()
@
classmethod
def
get_connector_class
(
def
_
get_connector_class
_with_compat
(
cls
,
kv_transfer_config
:
"KVTransferConfig"
)
->
type
[
KVConnectorBaseType
]:
"""Get the connector class by name."""
)
->
tuple
[
type
[
KVConnectorBaseType
],
bool
]:
connector_name
=
kv_transfer_config
.
kv_connector
if
connector_name
is
None
:
raise
ValueError
(
"Connector name is not set in KVTransferConfig"
)
compat_sig
=
False
if
connector_name
in
cls
.
_registry
:
connector_cls
=
cls
.
_registry
[
connector_name
]()
else
:
...
...
@@ -118,6 +128,21 @@ class KVConnectorFactory:
f
"Class
{
connector_name
}
not found in
{
connector_module_path
}
"
)
from
e
connector_cls
=
cast
(
type
[
KVConnectorBaseType
],
connector_cls
)
if
not
supports_kw
(
connector_cls
,
"kv_cache_config"
):
compat_sig
=
True
logger
.
warning
(
"Connector %s uses deprecated signature with 2 required arguments. "
"Please update to include kv_cache_config as the second argument."
,
connector_cls
.
__name__
,
)
return
connector_cls
,
compat_sig
@
classmethod
def
get_connector_class
(
cls
,
kv_transfer_config
:
"KVTransferConfig"
)
->
type
[
KVConnectorBaseType
]:
"""Get the connector class by name."""
connector_cls
,
_
=
cls
.
_get_connector_class_with_compat
(
kv_transfer_config
)
return
connector_cls
...
...
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
58279c60
...
...
@@ -58,6 +58,7 @@ if TYPE_CHECKING:
)
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.request
import
Request
# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
...
...
@@ -141,7 +142,12 @@ class KVConnectorMetadata(ABC): # noqa: B024
class
KVConnectorBase_V1
(
ABC
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
kv_cache_config
:
Optional
[
"KVCacheConfig"
]
=
None
,
):
logger
.
warning
(
"Initializing KVConnectorBase_V1. This API is experimental and "
"subject to change in the future as we iterate the design."
...
...
@@ -152,6 +158,14 @@ class KVConnectorBase_V1(ABC):
self
.
_kv_transfer_config
=
vllm_config
.
kv_transfer_config
else
:
raise
ValueError
(
"kv_transfer_config must be set for KVConnectorBase_V1"
)
self
.
_kv_cache_config
=
kv_cache_config
if
self
.
_kv_cache_config
is
None
:
logger
.
warning
(
"KVConnectorBase_V1 initialized without kv_cache_config. "
"This is deprecated - please update your connector to accept "
"kv_cache_config as the third constructor argument and pass it "
"to super().__init__()."
)
self
.
_role
=
role
@
property
...
...
vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py
View file @
58279c60
...
...
@@ -32,7 +32,7 @@ Usage:
"""
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
...
...
@@ -50,6 +50,7 @@ if TYPE_CHECKING:
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
...
...
@@ -79,8 +80,13 @@ class DecodeBenchConnector(KVConnectorBase_V1):
testing of the decoder with larger input sequence lengths.
"""
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
super
().
__init__
(
vllm_config
,
role
)
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
kv_cache_config
:
Optional
[
"KVCacheConfig"
]
=
None
,
):
super
().
__init__
(
vllm_config
,
role
,
kv_cache_config
)
self
.
connector_scheduler
:
DecodeBenchConnectorScheduler
|
None
=
None
self
.
connector_worker
:
DecodeBenchConnectorWorker
|
None
=
None
...
...
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
View file @
58279c60
...
...
@@ -20,14 +20,22 @@ if TYPE_CHECKING:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
class
LMCacheConnectorV1
(
KVConnectorBase_V1
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
)
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
kv_cache_config
:
"KVCacheConfig"
,
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
,
kv_cache_config
=
kv_cache_config
)
assert
vllm_config
.
kv_transfer_config
is
not
None
use_native
=
vllm_config
.
kv_transfer_config
.
get_from_extra_config
(
"use_native"
,
False
...
...
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
58279c60
...
...
@@ -31,6 +31,7 @@ if TYPE_CHECKING:
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
...
...
@@ -109,15 +110,22 @@ class MultiConnector(KVConnectorBase_V1):
- Save to all connectors.
"""
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
)
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
kv_cache_config
:
"KVCacheConfig"
,
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
,
kv_cache_config
=
kv_cache_config
)
self
.
_connectors
:
list
[
KVConnectorBase_V1
]
=
[]
self
.
_ktc_kv_transfer_config
=
[]
for
connector_cls
,
temp_config
in
self
.
_get_connector_classes_and_configs
(
vllm_config
):
self
.
_connectors
.
append
(
connector_cls
(
temp_config
,
role
))
self
.
_connectors
.
append
(
connector_cls
(
temp_config
,
role
,
kv_cache_config
))
self
.
_ktc_kv_transfer_config
.
append
(
temp_config
.
kv_transfer_config
)
# A mapping from request id to the index of the connector chosen to
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
58279c60
...
...
@@ -13,7 +13,7 @@ from collections import defaultdict
from
collections.abc
import
Iterator
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
msgspec
import
numpy
as
np
...
...
@@ -52,6 +52,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.request
import
Request
Transfer
=
tuple
[
int
,
float
]
# (xfer_handle, start_time)
...
...
@@ -150,7 +151,14 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class
NixlConnector
(
KVConnectorBase_V1
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
role
:
KVConnectorRole
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
role
:
KVConnectorRole
,
kv_cache_config
:
Optional
[
"KVCacheConfig"
]
=
None
,
):
super
().
__init__
(
vllm_config
,
role
,
kv_cache_config
)
assert
vllm_config
.
kv_transfer_config
is
not
None
assert
vllm_config
.
kv_transfer_config
.
engine_id
is
not
None
self
.
engine_id
:
EngineId
=
vllm_config
.
kv_transfer_config
.
engine_id
...
...
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
View file @
58279c60
...
...
@@ -21,6 +21,7 @@ from vllm.logger import init_logger
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.core.kv_cache_utils
import
BlockHash
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_offload.abstract
import
OffloadingManager
from
vllm.v1.kv_offload.factory
import
OffloadingSpecFactory
from
vllm.v1.kv_offload.mediums
import
GPULoadStoreSpec
...
...
@@ -41,8 +42,13 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):
class
OffloadingConnector
(
KVConnectorBase_V1
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
role
:
KVConnectorRole
):
super
().
__init__
(
vllm_config
,
role
)
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
role
:
KVConnectorRole
,
kv_cache_config
:
KVCacheConfig
|
None
=
None
,
):
super
().
__init__
(
vllm_config
,
role
,
kv_cache_config
)
spec
=
OffloadingSpecFactory
.
create_spec
(
vllm_config
)
...
...
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
View file @
58279c60
...
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
regex
as
re
import
torch
...
...
@@ -25,6 +25,7 @@ if TYPE_CHECKING:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
...
...
@@ -71,8 +72,17 @@ class P2pNcclConnectorMetadata(KVConnectorMetadata):
class
P2pNcclConnector
(
KVConnectorBase_V1
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
)
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
kv_cache_config
:
Optional
[
"KVCacheConfig"
]
=
None
,
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
,
kv_cache_config
=
kv_cache_config
,
)
self
.
_block_size
=
vllm_config
.
cache_config
.
block_size
self
.
_requests_need_load
:
dict
[
str
,
Any
]
=
{}
self
.
is_producer
=
self
.
_kv_transfer_config
.
is_kv_producer
...
...
vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
View file @
58279c60
...
...
@@ -3,7 +3,7 @@
import
hashlib
import
os
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
safetensors
import
torch
...
...
@@ -22,6 +22,7 @@ if TYPE_CHECKING:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
...
...
@@ -86,8 +87,17 @@ class SharedStorageConnector(KVConnectorBase_V1):
# It does extra work which will overwrite the existing prefix-cache in GPU
# - to remove the overhead, need to add some "mask" in the ReqMeta class
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
)
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
kv_cache_config
:
Optional
[
"KVCacheConfig"
]
=
None
,
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
,
kv_cache_config
=
kv_cache_config
,
)
self
.
_block_size
=
vllm_config
.
cache_config
.
block_size
self
.
_requests_need_load
:
dict
[
str
,
Request
]
=
{}
self
.
_storage_path
=
self
.
_kv_transfer_config
.
get_from_extra_config
(
...
...
vllm/distributed/kv_transfer/kv_transfer_state.py
View file @
58279c60
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
,
Optional
from
vllm
import
envs
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBaseType
...
...
@@ -12,6 +12,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
_KV_CONNECTOR_AGENT
:
KVConnectorBaseType
|
None
=
None
...
...
@@ -48,7 +49,9 @@ def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> boo
return
isinstance
(
connector
,
KVConnectorBase_V1
)
def
ensure_kv_transfer_initialized
(
vllm_config
:
"VllmConfig"
)
->
None
:
def
ensure_kv_transfer_initialized
(
vllm_config
:
"VllmConfig"
,
kv_cache_config
:
Optional
[
"KVCacheConfig"
]
=
None
)
->
None
:
"""
Initialize KV cache transfer parallel group.
"""
...
...
@@ -64,7 +67,9 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
):
if
envs
.
VLLM_USE_V1
:
_KV_CONNECTOR_AGENT
=
KVConnectorFactory
.
create_connector
(
config
=
vllm_config
,
role
=
KVConnectorRole
.
WORKER
config
=
vllm_config
,
role
=
KVConnectorRole
.
WORKER
,
kv_cache_config
=
kv_cache_config
,
)
else
:
raise
ValueError
(
"V0 is no longer supported"
)
...
...
vllm/v1/core/sched/scheduler.py
View file @
58279c60
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
import
itertools
import
time
from
collections
import
defaultdict
...
...
@@ -92,15 +91,10 @@ class Scheduler(SchedulerInterface):
assert
not
self
.
is_encoder_decoder
,
(
"Encoder-decoder models are not currently supported with KV connectors"
)
connector_vllm_config
=
copy
.
copy
(
self
.
vllm_config
)
# We're dynamically inserting a kv_cache_config variable into the
# connector_vllm_config. This is distinct from the cache_config
# that is already in there.
connector_vllm_config
.
kv_cache_config
=
copy
.
copy
(
kv_cache_config
)
# type: ignore[attr-defined]
self
.
connector
=
KVConnectorFactory
.
create_connector
(
config
=
connector_vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
,
kv_cache_config
=
self
.
kv_cache_config
,
)
if
self
.
log_stats
:
self
.
connector_prefix_cache_stats
=
PrefixCacheStats
()
...
...
vllm/v1/worker/gpu_worker.py
View file @
58279c60
...
...
@@ -380,9 +380,7 @@ class Worker(WorkerBase):
# NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
# because `initialize_kv_cache` will inject kv cache groups not
# related to kv cache connector (e.g. kv cache sharing layers).
connector_vllm_config
=
copy
.
copy
(
self
.
vllm_config
)
connector_vllm_config
.
kv_cache_config
=
copy
.
copy
(
kv_cache_config
)
ensure_kv_transfer_initialized
(
connector_vllm_config
)
ensure_kv_transfer_initialized
(
self
.
vllm_config
,
kv_cache_config
)
if
self
.
vllm_config
.
model_config
.
enable_sleep_mode
:
from
vllm.device_allocator.cumem
import
CuMemAllocator
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment