Unverified Commit bc290e7c authored by Alec's avatar Alec Committed by GitHub
Browse files

feat: allow user to adjust kv_transfer_config (#2517)

parent e31c8790
...@@ -163,6 +163,7 @@ vLLM workers are configured through command-line arguments. Key parameters inclu ...@@ -163,6 +163,7 @@ vLLM workers are configured through command-line arguments. Key parameters inclu
- `--model`: Model to serve (e.g., `Qwen/Qwen3-0.6B`) - `--model`: Model to serve (e.g., `Qwen/Qwen3-0.6B`)
- `--is-prefill-worker`: Enable prefill-only mode for disaggregated serving - `--is-prefill-worker`: Enable prefill-only mode for disaggregated serving
- `--metrics-endpoint-port`: Port for publishing KV metrics to Dynamo - `--metrics-endpoint-port`: Port for publishing KV metrics to Dynamo
- `--connector`: Specify which kv_transfer_config you want vLLM to use, as one or more of `[nixl, lmcache, kvbm, none]` (default: `nixl`). This is a helper flag which overwrites the engine's KVTransferConfig; it cannot be combined with `--kv-transfer-config`.
See `args.py` for the full list of configuration options and their defaults. See `args.py` for the full list of configuration options and their defaults.
......
...@@ -9,4 +9,4 @@ python -m dynamo.frontend & ...@@ -9,4 +9,4 @@ python -m dynamo.frontend &
# run worker # run worker
# --enforce-eager is added for quick deployment. for production use, need to remove this flag # --enforce-eager is added for quick deployment. for production use, need to remove this flag
python -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager --no-enable-prefix-caching python -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager --connector none
...@@ -9,6 +9,6 @@ python -m dynamo.frontend --router-mode kv & ...@@ -9,6 +9,6 @@ python -m dynamo.frontend --router-mode kv &
# run workers # run workers
# --enforce-eager is added for quick deployment. for production use, need to remove this flag # --enforce-eager is added for quick deployment. for production use, need to remove this flag
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager & CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager --connector none &
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager --connector none
...@@ -21,4 +21,5 @@ LMCACHE_MAX_LOCAL_CPU_SIZE=20 \ ...@@ -21,4 +21,5 @@ LMCACHE_MAX_LOCAL_CPU_SIZE=20 \
CUDA_VISIBLE_DEVICES=1 \ CUDA_VISIBLE_DEVICES=1 \
python3 -m dynamo.vllm \ python3 -m dynamo.vllm \
--model Qwen/Qwen3-0.6B \ --model Qwen/Qwen3-0.6B \
--is-prefill-worker --is-prefill-worker \
\ No newline at end of file --connector lmcache nixl
...@@ -30,6 +30,8 @@ logger = logging.getLogger(__name__) ...@@ -30,6 +30,8 @@ logger = logging.getLogger(__name__)
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate" DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B" DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
VALID_CONNECTORS = {"nixl", "lmcache", "kvbm", "null", "none"}
# Global LMCache configuration - initialize once on module import # Global LMCache configuration - initialize once on module import
ENABLE_LMCACHE = os.getenv("ENABLE_LMCACHE", "0").lower() in ("1", "true", "yes") ENABLE_LMCACHE = os.getenv("ENABLE_LMCACHE", "0").lower() in ("1", "true", "yes")
...@@ -44,7 +46,6 @@ class Config: ...@@ -44,7 +46,6 @@ class Config:
is_prefill_worker: bool is_prefill_worker: bool
migration_limit: int = 0 migration_limit: int = 0
kv_port: Optional[int] = None kv_port: Optional[int] = None
side_channel_port: Optional[int] = None
port_range: DynamoPortRange port_range: DynamoPortRange
# mirror vLLM # mirror vLLM
...@@ -54,6 +55,9 @@ class Config: ...@@ -54,6 +55,9 @@ class Config:
# rest vLLM args # rest vLLM args
engine_args: AsyncEngineArgs engine_args: AsyncEngineArgs
# Connector list from CLI
connector_list: Optional[list] = None
def parse_args() -> Config: def parse_args() -> Config:
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
...@@ -91,6 +95,13 @@ def parse_args() -> Config: ...@@ -91,6 +95,13 @@ def parse_args() -> Config:
default=DEFAULT_DYNAMO_PORT_MAX, default=DEFAULT_DYNAMO_PORT_MAX,
help=f"Maximum port number for Dynamo services (default: {DEFAULT_DYNAMO_PORT_MAX}). Must be in registered ports range (1024-49151).", help=f"Maximum port number for Dynamo services (default: {DEFAULT_DYNAMO_PORT_MAX}). Must be in registered ports range (1024-49151).",
) )
parser.add_argument(
"--connector",
nargs="*",
default=["nixl"],
help="List of connectors to use in order (e.g., --connector nixl lmcache). "
"Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector.",
)
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
...@@ -141,6 +152,36 @@ def parse_args() -> Config: ...@@ -141,6 +152,36 @@ def parse_args() -> Config:
min=args.dynamo_port_min, max=args.dynamo_port_max min=args.dynamo_port_min, max=args.dynamo_port_max
) )
# Check for conflicting flags
has_kv_transfer_config = (
hasattr(engine_args, "kv_transfer_config")
and engine_args.kv_transfer_config is not None
)
has_connector_flag = args.connector is not None
if has_kv_transfer_config and has_connector_flag:
raise ValueError(
"Cannot specify both --kv-transfer-config and --connector flags"
)
if has_connector_flag:
normalized = [c.lower() for c in args.connector]
invalid = [c for c in normalized if c not in VALID_CONNECTORS]
if invalid:
raise ValueError(
f"Invalid connector(s): {', '.join(invalid)}. Valid options are: {', '.join(sorted(VALID_CONNECTORS))}"
)
if "none" in normalized or "null" in normalized:
if len(normalized) > 1:
raise ValueError(
"'none' and 'null' cannot be combined with other connectors"
)
config.connector_list = []
else:
config.connector_list = normalized
if config.engine_args.block_size is None: if config.engine_args.block_size is None:
config.engine_args.block_size = 16 config.engine_args.block_size = 16
logger.debug( logger.debug(
...@@ -169,15 +210,21 @@ async def configure_ports_with_etcd(config: Config, etcd_client): ...@@ -169,15 +210,21 @@ async def configure_ports_with_etcd(config: Config, etcd_client):
config.kv_port = kv_port config.kv_port = kv_port
logger.info(f"Allocated ZMQ KV events port: {kv_port} (worker_id={worker_id})") logger.info(f"Allocated ZMQ KV events port: {kv_port} (worker_id={worker_id})")
# Check if NIXL is needed based on connector list
needs_nixl = config.connector_list and "nixl" in config.connector_list
if needs_nixl:
# Allocate side channel ports # Allocate side channel ports
# https://github.com/vllm-project/vllm/blob/releases/v0.10.1/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py#L443 # https://github.com/vllm-project/vllm/blob/releases/v0.10.0/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py#L372
# NIXL calculates ports as: base_port + (dp_rank * tp_size) + tp_rank # NIXL calculates ports as: base_port + (dp_rank * tp_size) + tp_rank
# For dp_rank, we need to reserve tp_size consecutive ports # For dp_rank, we need to reserve tp_size consecutive ports
tp_size = config.engine_args.tensor_parallel_size or 1 tp_size = config.engine_args.tensor_parallel_size or 1
# The first port for this dp_rank will be at: base_port + (dp_rank * tp_size) # The first port for this dp_rank will be at: base_port + (dp_rank * tp_size)
# We need to allocate tp_size consecutive ports starting from there # We need to allocate tp_size consecutive ports starting from there
nixl_metadata = PortMetadata(worker_id=worker_id, reason="nixl_side_channel_port") nixl_metadata = PortMetadata(
worker_id=worker_id, reason="nixl_side_channel_port"
)
nixl_request = PortAllocationRequest( nixl_request = PortAllocationRequest(
etcd_context=etcd_context, etcd_context=etcd_context,
metadata=nixl_metadata, metadata=nixl_metadata,
...@@ -200,53 +247,64 @@ async def configure_ports_with_etcd(config: Config, etcd_client): ...@@ -200,53 +247,64 @@ async def configure_ports_with_etcd(config: Config, etcd_client):
f"Consider using a higher port range." f"Consider using a higher port range."
) )
config.side_channel_port = base_side_channel_port
logger.info( logger.info(
f"Allocated NIXL side channel ports: base={base_side_channel_port}, " f"Allocated NIXL side channel ports: base={base_side_channel_port}, "
f"allocated_ports={allocated_ports} (worker_id={worker_id}, dp_rank={dp_rank}, tp_size={tp_size})" f"allocated_ports={allocated_ports} (worker_id={worker_id}, dp_rank={dp_rank}, tp_size={tp_size})"
) )
set_side_channel_host_and_port(base_side_channel_port)
def create_kv_transfer_config(config: Config) -> Optional[KVTransferConfig]:
    """Create KVTransferConfig based on user config or connector list.

    Handles logging and returns the appropriate config or None.

    Args:
        config: Worker configuration. Only ``config.engine_args`` and
            ``config.connector_list`` are read.

    Returns:
        ``None`` when the user already supplied ``kv_transfer_config`` on the
        engine args, or when the connector list is missing/empty. Otherwise a
        ``KVTransferConfig`` for the single requested connector, or a
        ``MultiConnector`` config wrapping all requested connectors in the
        order they were given.

    Raises:
        ValueError: If ``config.connector_list`` contains a name that has no
            connector mapping.
    """
    # getattr with a default covers both "attribute missing" and "set to None".
    if getattr(config.engine_args, "kv_transfer_config", None) is not None:
        logger.info("Using user-provided kv_transfer_config from --kv-transfer-config")
        return None  # Let vLLM use the user's config

    # No connector list or empty list means no config
    if not config.connector_list:
        logger.info("Using vLLM defaults for kv_transfer_config")
        return None

    logger.info(f"Creating kv_transfer_config from --connector {config.connector_list}")

    # Keyword arguments for each supported connector name.
    connector_specs = {
        "lmcache": {"kv_connector": "LMCacheConnectorV1", "kv_role": "kv_both"},
        "nixl": {"kv_connector": "NixlConnector", "kv_role": "kv_both"},
        "kvbm": {
            "kv_connector": "DynamoConnector",
            "kv_connector_module_path": "dynamo.llm.vllm_integration.connector",
            "kv_role": "kv_both",
        },
    }

    # Create connector configs in specified order
    multi_connectors = []
    for connector in config.connector_list:
        if connector not in connector_specs:
            # Fail loudly instead of hitting an unbound local or silently
            # re-appending the previous connector's config.
            raise ValueError(
                f"Unsupported connector '{connector}'. "
                f"Supported connectors are: {', '.join(sorted(connector_specs))}"
            )
        # Copy so every entry is an independent dict, even for repeated names.
        multi_connectors.append(dict(connector_specs[connector]))

    # For single connector, return direct config
    if len(multi_connectors) == 1:
        return KVTransferConfig(**multi_connectors[0])

    # For multiple connectors, use MultiConnector
    return KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"connectors": multi_connectors},
    )
def overwrite_args(config):
"""Set vLLM defaults for Dynamo."""
defaults = { defaults = {
"task": "generate", "task": "generate",
# As of vLLM >=0.10.0 the engine unconditionally calls # As of vLLM >=0.10.0 the engine unconditionally calls
...@@ -257,11 +315,14 @@ def overwrite_args(config): ...@@ -257,11 +315,14 @@ def overwrite_args(config):
"disable_log_requests": True, "disable_log_requests": True,
# KV routing relies on logging KV metrics # KV routing relies on logging KV metrics
"disable_log_stats": False, "disable_log_stats": False,
"kv_transfer_config": kv_transfer_config,
} }
kv_config = create_kv_transfer_config(config)
if kv_config:
defaults["kv_transfer_config"] = kv_config
if config.engine_args.enable_prefix_caching: if config.engine_args.enable_prefix_caching:
# If caching, send events dp_rank = config.engine_args.data_parallel_rank or 0
defaults |= { defaults |= {
# Always setting up kv events if enable prefix cache. # Always setting up kv events if enable prefix cache.
"kv_events_config": KVEventsConfig( "kv_events_config": KVEventsConfig(
...@@ -271,8 +332,6 @@ def overwrite_args(config): ...@@ -271,8 +332,6 @@ def overwrite_args(config):
) )
} }
set_side_channel_host_and_port(config)
logger.debug("Setting Dynamo defaults for vLLM") logger.debug("Setting Dynamo defaults for vLLM")
for key, value in defaults.items(): for key, value in defaults.items():
if hasattr(config.engine_args, key): if hasattr(config.engine_args, key):
...@@ -282,11 +341,11 @@ def overwrite_args(config): ...@@ -282,11 +341,11 @@ def overwrite_args(config):
raise ValueError(f"{key} not found in AsyncEngineArgs from vLLM.") raise ValueError(f"{key} not found in AsyncEngineArgs from vLLM.")
def set_side_channel_host_and_port(side_channel_port: int) -> None:
    """vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors.

    This sets the host and port number for the side channel by exporting
    them as environment variables.

    Args:
        side_channel_port: Port to publish in ``VLLM_NIXL_SIDE_CHANNEL_PORT``.
            The host published in ``VLLM_NIXL_SIDE_CHANNEL_HOST`` is whatever
            ``get_host_ip()`` returns.
    """
    host_ip = get_host_ip()
    # Environment values must be strings, hence the explicit str() on the port.
    os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = host_ip
    os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(side_channel_port)
    logger.debug(f"Set NIXL side channel to {host_ip}:{side_channel_port}")
...@@ -80,7 +80,7 @@ async def worker(runtime: DistributedRuntime): ...@@ -80,7 +80,7 @@ async def worker(runtime: DistributedRuntime):
for sig in (signal.SIGTERM, signal.SIGINT): for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler) loop.add_signal_handler(sig, signal_handler)
logging.info("Signal handlers set up for graceful shutdown") logging.debug("Signal handlers set up for graceful shutdown")
if config.is_prefill_worker: if config.is_prefill_worker:
await init_prefill(runtime, config) await init_prefill(runtime, config)
...@@ -99,7 +99,7 @@ def setup_vllm_engine(config, stat_logger=None): ...@@ -99,7 +99,7 @@ def setup_vllm_engine(config, stat_logger=None):
setup_lmcache_environment() setup_lmcache_environment()
logger.info("LMCache enabled for VllmWorker") logger.info("LMCache enabled for VllmWorker")
else: else:
logger.info("LMCache is disabled") logger.debug("LMCache is disabled")
# Load default sampling params from `generation_config.json` # Load default sampling params from `generation_config.json`
default_sampling_params = ( default_sampling_params = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment