Unverified commit a13d8c03, authored by Yashwant Bezawada, committed by GitHub
Browse files

[KVConnector] Auto-downgrade to PIECEWISE cudagraph mode for layerwise async ops (#31057)


Signed-off-by: Yashwant Bezawada <yashwant_b@me.com>
parent 9433acb8
......@@ -925,6 +925,33 @@ class VllmConfig:
CUDAGraphMode.FULL_DECODE_ONLY
)
# Check if KV connector requires PIECEWISE mode for CUDA graphs
if (
self.kv_transfer_config is not None
and self.kv_transfer_config.is_kv_transfer_instance
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
):
# Lazy import to avoid circular dependencies
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory,
)
connector_cls = KVConnectorFactory.get_connector_class(
self.kv_transfer_config
)
if connector_cls.requires_piecewise_for_cudagraph(
self.kv_transfer_config.kv_connector_extra_config
):
logger.warning_once(
"KV connector %s requires PIECEWISE CUDA graph mode "
"due to layerwise async operations that cannot be "
"captured in CUDA graphs. "
"Overriding cudagraph_mode from %s to PIECEWISE.",
connector_cls.__name__,
self.compilation_config.cudagraph_mode.name,
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:
logger.info("Cudagraph is disabled under eager mode")
......
......@@ -543,6 +543,28 @@ class KVConnectorBase_V1(ABC):
)
return None
@classmethod
def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
    """Report whether this connector needs PIECEWISE CUDA graph mode.

    The base implementation always answers ``False``. A connector that
    relies on asynchronous, layer-by-layer operations
    (``wait_for_layer_load`` / ``save_kv_layer``) should override this
    hook and answer ``True`` whenever those operations are enabled.
    Such operations cannot be captured in CUDA graphs and would be
    skipped during replay, causing data races; PIECEWISE mode lets
    Python code run between graph pieces so synchronization still
    happens.

    Args:
        extra_config: The ``kv_connector_extra_config`` dict taken from
            ``KVTransferConfig``. Unused by the base implementation.

    Returns:
        ``True`` if the connector requires PIECEWISE CUDA graph mode,
        ``False`` otherwise (always ``False`` here).
    """
    # Default: no layerwise async work, so no restriction on graph mode.
    return False
def get_finished_count(self) -> int | None:
"""
Get the count of requests expected to complete send/receive operations
......
......@@ -70,6 +70,16 @@ class LMCacheKVEvents(KVConnectorKVEvents):
class LMCacheConnectorV1(KVConnectorBase_V1):
@classmethod
def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
    """Report whether LMCache needs PIECEWISE CUDA graph mode.

    LMCache requires PIECEWISE mode when layerwise operations are
    enabled: its ``wait_for_layer_load`` and ``save_kv_layer`` methods
    perform actual async synchronization that cannot be captured in
    CUDA graphs.

    Args:
        extra_config: The ``kv_connector_extra_config`` dict; only the
            ``"use_layerwise"`` key is consulted.

    Returns:
        ``True`` if ``"use_layerwise"`` is present and truthy,
        ``False`` otherwise.
    """
    # The config value is user-supplied and may not be a real bool
    # (e.g. 1 or None); coerce it so the declared ``-> bool`` contract
    # holds without changing the truthiness callers observe.
    return bool(extra_config.get("use_layerwise", False))
def __init__(
self,
vllm_config: "VllmConfig",
......
......@@ -112,6 +112,21 @@ class MultiConnector(KVConnectorBase_V1):
- Save to all connectors.
"""
@classmethod
def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
    """Report whether any child connector needs PIECEWISE CUDA graph mode.

    Each entry under the ``"connectors"`` key of ``extra_config`` is
    turned into a ``KVTransferConfig``, resolved to its connector class
    through ``KVConnectorFactory``, and asked the same question with
    that entry's own ``"kv_connector_extra_config"`` (an empty dict if
    the key is absent).

    Args:
        extra_config: The MultiConnector's ``kv_connector_extra_config``
            dict; child configs are read from its ``"connectors"`` list.

    Returns:
        ``True`` as soon as one child requires PIECEWISE mode, ``False``
        if none do or no children are configured.
    """

    def _child_requires_piecewise(child_config):
        # Resolve the child's connector class from its own transfer config.
        child_cls = KVConnectorFactory.get_connector_class(
            KVTransferConfig(**child_config)
        )
        return child_cls.requires_piecewise_for_cudagraph(
            child_config.get("kv_connector_extra_config", {})
        )

    # any() stops at the first child that answers truthy, matching the
    # early-return behavior of an explicit loop.
    return any(
        _child_requires_piecewise(child_config)
        for child_config in extra_config.get("connectors", [])
    )
def __init__(
self,
vllm_config: "VllmConfig",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment