Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f4135232
Unverified
Commit
f4135232
authored
Jul 31, 2025
by
wxsm
Committed by
GitHub
Jul 30, 2025
Browse files
feat(distributed): add `get_required_kvcache_layout` class method to kv connector api (#20433)
Signed-off-by:
wxsm
<
wxsms@foxmail.com
>
parent
4904e53c
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
186 additions
and
28 deletions
+186
-28
tests/distributed/test_kvlayout.py
tests/distributed/test_kvlayout.py
+72
-0
vllm/distributed/kv_transfer/kv_connector/base.py
vllm/distributed/kv_transfer/kv_connector/base.py
+15
-1
vllm/distributed/kv_transfer/kv_connector/factory.py
vllm/distributed/kv_transfer/kv_connector/factory.py
+21
-16
vllm/distributed/kv_transfer/kv_connector/utils.py
vllm/distributed/kv_transfer/kv_connector/utils.py
+10
-9
vllm/distributed/kv_transfer/kv_connector/v1/base.py
vllm/distributed/kv_transfer/kv_connector/v1/base.py
+14
-0
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
...istributed/kv_transfer/kv_connector/v1/multi_connector.py
+33
-0
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+21
-2
No files found.
tests/distributed/test_kvlayout.py
0 → 100644
View file @
f4135232
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.config
import
(
DeviceConfig
,
KVTransferConfig
,
ModelConfig
,
VllmConfig
,
set_current_vllm_config
)
from
vllm.distributed.kv_transfer.kv_connector.utils
import
(
get_kv_connector_cache_layout
)
from
vllm.logger
import
init_logger
# Name the logger after this module rather than a string copied from
# another test file, so log records point at the right source.
logger = init_logger(__name__)
def test_get_kv_connector_cache_layout_without_kv_connector():
    """With no KV transfer config set, the reported layout must be NHD."""
    config = VllmConfig(device_config=DeviceConfig("cpu"))
    with set_current_vllm_config(config):
        # Only the device is configured; no connector is involved.
        detected = get_kv_connector_cache_layout()
        assert detected == "NHD"
def test_get_kv_connector_cache_layout_with_lmcache_connector():
    """With LMCacheConnectorV1 configured, the reported layout must be NHD."""
    transfer_cfg = KVTransferConfig(
        kv_connector="LMCacheConnectorV1",
        kv_role="kv_both",
    )
    config = VllmConfig(device_config=DeviceConfig("cpu"),
                        kv_transfer_config=transfer_cfg)
    with set_current_vllm_config(config):
        # No model config is supplied in this scenario.
        detected = get_kv_connector_cache_layout()
        assert detected == "NHD"
def test_get_kv_connector_cache_layout_with_nixl_connector():
    """With NixlConnector and a default model config, expect HND."""
    transfer_cfg = KVTransferConfig(
        kv_connector="NixlConnector",
        kv_role="kv_both",
    )
    default_model_cfg = ModelConfig()
    config = VllmConfig(device_config=DeviceConfig("cpu"),
                        model_config=default_model_cfg,
                        kv_transfer_config=transfer_cfg)
    with set_current_vllm_config(config):
        # A model config is passed explicitly alongside the connector.
        detected = get_kv_connector_cache_layout()
        assert detected == "HND"
def test_get_kv_connector_cache_layout_with_multi_connector():
    """With a MultiConnector wrapping SharedStorage + Nixl, expect HND."""
    # Two sub-connectors, listed in the same order the connector receives
    # them: SharedStorageConnector first, NixlConnector second.
    sub_connectors = [{
        "kv_connector": "SharedStorageConnector",
        "kv_role": "kv_both"
    }, {
        "kv_connector": "NixlConnector",
        "kv_role": "kv_both"
    }]
    transfer_cfg = KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"connectors": sub_connectors})
    default_model_cfg = ModelConfig()
    config = VllmConfig(device_config=DeviceConfig("cpu"),
                        model_config=default_model_cfg,
                        kv_transfer_config=transfer_cfg)
    with set_current_vllm_config(config):
        detected = get_kv_connector_cache_layout()
        assert detected == "HND"
vllm/distributed/kv_transfer/kv_connector/base.py
View file @
f4135232
...
@@ -9,7 +9,7 @@ The class provides two primary abstract methods:
...
@@ -9,7 +9,7 @@ The class provides two primary abstract methods:
"""
"""
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
TYPE_CHECKING
,
Union
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
import
torch
import
torch
...
@@ -124,5 +124,19 @@ class KVConnectorBase(ABC):
...
@@ -124,5 +124,19 @@ class KVConnectorBase(ABC):
raise
NotImplementedError
raise
NotImplementedError
@classmethod
def get_required_kvcache_layout(
        cls, vllm_config: "VllmConfig") -> Optional[str]:
    """
    Report which KV cache layout this connector needs, if any.

    This default implementation ignores ``vllm_config`` and states no
    requirement; subclasses may override it.

    Args:
        vllm_config (VllmConfig): the vllm config.

    Returns:
        str: the required KV cache layout. e.g. HND, or NHD.
        None if the connector does not require a specific layout.
    """
    # No layout requirement by default.
    return None
KVConnectorBaseType
=
Union
[
KVConnectorBase
,
KVConnectorBase_V1
]
KVConnectorBaseType
=
Union
[
KVConnectorBase
,
KVConnectorBase_V1
]
vllm/distributed/kv_transfer/kv_connector/factory.py
View file @
f4135232
...
@@ -5,6 +5,7 @@ import importlib
...
@@ -5,6 +5,7 @@ import importlib
from
typing
import
TYPE_CHECKING
,
Callable
from
typing
import
TYPE_CHECKING
,
Callable
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
KVTransferConfig
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBaseType
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBaseType
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
KVConnectorRole
)
KVConnectorRole
)
...
@@ -41,25 +42,15 @@ class KVConnectorFactory:
...
@@ -41,25 +42,15 @@ class KVConnectorFactory:
raise
ValueError
(
"Attempting to initialize a V0 Connector, "
raise
ValueError
(
"Attempting to initialize a V0 Connector, "
f
"but found
{
envs
.
VLLM_USE_V1
=
}
"
)
f
"but found
{
envs
.
VLLM_USE_V1
=
}
"
)
connector_name
=
config
.
kv_transfer_config
.
kv_connector
connector_cls
=
cls
.
get_connector_class
(
config
.
kv_transfer_config
)
if
connector_name
not
in
cls
.
_registry
:
raise
ValueError
(
f
"Unsupported connector type:
{
connector_name
}
"
)
connector_cls
=
cls
.
_registry
[
connector_name
]()
assert
issubclass
(
connector_cls
,
KVConnectorBase
)
assert
issubclass
(
connector_cls
,
KVConnectorBase
)
return
connector_cls
(
rank
,
local_rank
,
config
)
return
connector_cls
(
rank
,
local_rank
,
config
)
@
classmethod
@
classmethod
def
create_connector_v1
(
def
get_connector_class
(
cls
,
cls
,
kv_transfer_config
:
"KVTransferConfig"
config
:
"VllmConfig"
,
)
->
type
[
KVConnectorBaseType
]:
role
:
KVConnectorRole
,
"""Get the connector class by name."""
)
->
KVConnectorBase_V1
:
if
not
envs
.
VLLM_USE_V1
:
raise
ValueError
(
"Attempting to initialize a V1 Connector, "
f
"but found
{
envs
.
VLLM_USE_V1
=
}
"
)
kv_transfer_config
=
config
.
kv_transfer_config
connector_name
=
kv_transfer_config
.
kv_connector
connector_name
=
kv_transfer_config
.
kv_connector
if
connector_name
in
cls
.
_registry
:
if
connector_name
in
cls
.
_registry
:
connector_cls
=
cls
.
_registry
[
connector_name
]()
connector_cls
=
cls
.
_registry
[
connector_name
]()
...
@@ -70,9 +61,23 @@ class KVConnectorFactory:
...
@@ -70,9 +61,23 @@ class KVConnectorFactory:
f
"Unsupported connector type:
{
connector_name
}
"
)
f
"Unsupported connector type:
{
connector_name
}
"
)
connector_module
=
importlib
.
import_module
(
connector_module_path
)
connector_module
=
importlib
.
import_module
(
connector_module_path
)
connector_cls
=
getattr
(
connector_module
,
connector_name
)
connector_cls
=
getattr
(
connector_module
,
connector_name
)
return
connector_cls
@
classmethod
def
create_connector_v1
(
cls
,
config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
)
->
KVConnectorBase_V1
:
if
not
envs
.
VLLM_USE_V1
:
raise
ValueError
(
"Attempting to initialize a V1 Connector, "
f
"but found
{
envs
.
VLLM_USE_V1
=
}
"
)
kv_transfer_config
=
config
.
kv_transfer_config
connector_cls
=
cls
.
get_connector_class
(
kv_transfer_config
)
assert
issubclass
(
connector_cls
,
KVConnectorBase_V1
)
assert
issubclass
(
connector_cls
,
KVConnectorBase_V1
)
logger
.
info
(
"Creating v1 connector with name: %s and engine_id: %s"
,
logger
.
info
(
"Creating v1 connector with name: %s and engine_id: %s"
,
connector_name
,
kv_transfer_config
.
engine_id
)
connector_
cls
.
__
name
__
,
kv_transfer_config
.
engine_id
)
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
# Scheduler connector:
# Scheduler connector:
# - Co-locate with scheduler process
# - Co-locate with scheduler process
...
...
vllm/distributed/kv_transfer/kv_connector/utils.py
View file @
f4135232
...
@@ -13,6 +13,8 @@ import torch
...
@@ -13,6 +13,8 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
from
vllm.config
import
VllmConfig
,
get_current_vllm_config
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
KVConnectorFactory
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
...
@@ -103,15 +105,14 @@ def get_kv_connector_cache_layout():
...
@@ -103,15 +105,14 @@ def get_kv_connector_cache_layout():
# used for faster transfer.
# used for faster transfer.
vllm_config
=
get_current_vllm_config
()
vllm_config
=
get_current_vllm_config
()
kv_config
=
vllm_config
.
kv_transfer_config
kv_config
=
vllm_config
.
kv_transfer_config
if
kv_config
is
not
None
and
vllm_config
.
model_config
is
None
:
if
kv_config
is
not
None
:
logger
.
warning_once
(
"Unable to detect current VLLM config. "
\
connector_cls
=
KVConnectorFactory
.
get_connector_class
(
kv_config
)
"Defaulting to NHD kv cache layout."
)
required_kvcache_layout
=
connector_cls
.
get_required_kvcache_layout
(
elif
kv_config
is
not
None
:
vllm_config
)
use_mla
=
vllm_config
.
model_config
.
use_mla
if
required_kvcache_layout
is
not
None
:
if
not
use_mla
and
kv_config
.
kv_connector
==
"NixlConnector"
:
return
required_kvcache_layout
logger
.
info_once
(
"NixlConnector detected. Setting KV cache "
\
logger
.
info_once
(
"Connectors do not specify a "
\
"layout to HND for better xfer performance."
)
"kv cache layout, defaulting to NHD."
)
return
"HND"
return
"NHD"
return
"NHD"
...
...
vllm/distributed/kv_transfer/kv_connector/v1/base.py
View file @
f4135232
...
@@ -299,3 +299,17 @@ class KVConnectorBase_V1(ABC):
...
@@ -299,3 +299,17 @@ class KVConnectorBase_V1(ABC):
returned by the engine.
returned by the engine.
"""
"""
return
False
,
None
return
False
,
None
@classmethod
def get_required_kvcache_layout(
        cls, vllm_config: "VllmConfig") -> Optional[str]:
    """
    Report which KV cache layout this connector needs, if any.

    This default implementation ignores ``vllm_config`` and states no
    requirement; subclasses may override it.

    Args:
        vllm_config (VllmConfig): the vllm config.

    Returns:
        str: the required KV cache layout. e.g. HND, or NHD.
        None if the connector does not require a specific layout.
    """
    # No layout requirement by default.
    return None
vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
View file @
f4135232
...
@@ -202,3 +202,36 @@ class MultiConnector(KVConnectorBase_V1):
...
@@ -202,3 +202,36 @@ class MultiConnector(KVConnectorBase_V1):
self
.
_requests_to_connector
.
pop
(
request
.
request_id
,
None
)
self
.
_requests_to_connector
.
pop
(
request
.
request_id
,
None
)
return
async_saves
>
0
,
kv_txfer_params
return
async_saves
>
0
,
kv_txfer_params
@classmethod
def get_required_kvcache_layout(
        cls, vllm_config: "VllmConfig") -> Optional[str]:
    """
    Get the KV cache layout required by the configured sub-connectors.

    Every entry under ``kv_connector_extra_config["connectors"]`` is turned
    into a ``KVTransferConfig``; the matching connector class is then asked
    for its required layout.

    Args:
        vllm_config (VllmConfig): the vllm config.

    Returns:
        str: the single layout the sub-connectors require. e.g. HND, or NHD.
        None if no sub-connector requires a specific layout.

    Raises:
        ValueError: if the sub-connectors require different layouts.
    """
    ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
        "connectors")
    assert ktcs is not None
    layouts: set[str] = set()
    # Work on a shallow copy so that rebinding kv_transfer_config below
    # does not touch the caller's config object.
    temp_vllm_config = copy.copy(vllm_config)
    for ktc in ktcs:
        kv_transfer_config = KVTransferConfig(**ktc)
        temp_vllm_config.kv_transfer_config = kv_transfer_config
        required_kvcache_layout = KVConnectorFactory.get_connector_class(
            kv_transfer_config).get_required_kvcache_layout(
                temp_vllm_config)
        if required_kvcache_layout is not None:
            layouts.add(required_kvcache_layout)

    if len(layouts) > 1:
        # Sort for a deterministic message (set order is arbitrary) and
        # keep a space between the two sentences.
        raise ValueError(
            f"KV cache layout mismatch: "
            f"found {len(layouts)} different layouts "
            f"({', '.join(sorted(layouts))}). "
            f"All connectors must use the same layout.")
    return next(iter(layouts), None)
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
f4135232
...
@@ -133,6 +133,25 @@ class NixlConnector(KVConnectorBase_V1):
...
@@ -133,6 +133,25 @@ class NixlConnector(KVConnectorBase_V1):
self
.
connector_worker
=
NixlConnectorWorker
(
self
.
connector_worker
=
NixlConnectorWorker
(
vllm_config
,
self
.
engine_id
)
vllm_config
,
self
.
engine_id
)
############################################################
# Class Methods
############################################################
@classmethod
def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
    """
    Get the KV cache layout required by NixlConnector.

    Args:
        vllm_config (VllmConfig): the vllm config.

    Returns:
        "HND" when a model config is present and it does not use MLA.
        None when the model config is missing or the model uses MLA.
    """
    model_config = vllm_config.model_config
    if model_config is None:
        logger.warning_once("Unable to detect current VLLM config. "
                            "Fallback to default kv cache layout.")
        return None

    if model_config.use_mla:
        # With MLA the layout should not matter, so state no requirement
        # and let the caller fall back to its default behavior.
        return None

    logger.info_once("NixlConnector setting KV cache "
                     "layout to HND for better xfer performance.")
    return "HND"
############################################################
############################################################
# Scheduler Side Methods
# Scheduler Side Methods
############################################################
############################################################
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment