Unverified Commit bc290e7c authored by Alec's avatar Alec Committed by GitHub
Browse files

feat: allow user to adjust kv_transfer_config (#2517)

parent e31c8790
...@@ -163,6 +163,7 @@ vLLM workers are configured through command-line arguments. Key parameters inclu ...@@ -163,6 +163,7 @@ vLLM workers are configured through command-line arguments. Key parameters inclu
- `--model`: Model to serve (e.g., `Qwen/Qwen3-0.6B`) - `--model`: Model to serve (e.g., `Qwen/Qwen3-0.6B`)
- `--is-prefill-worker`: Enable prefill-only mode for disaggregated serving - `--is-prefill-worker`: Enable prefill-only mode for disaggregated serving
- `--metrics-endpoint-port`: Port for publishing KV metrics to Dynamo - `--metrics-endpoint-port`: Port for publishing KV metrics to Dynamo
- `--connector`: Specify which kv_transfer_config you want vLLM to use `[nixl, lmcache, kvbm, none]`. This is a helper flag that overwrites the engine's KVTransferConfig.
See `args.py` for the full list of configuration options and their defaults. See `args.py` for the full list of configuration options and their defaults.
......
...@@ -9,4 +9,4 @@ python -m dynamo.frontend & ...@@ -9,4 +9,4 @@ python -m dynamo.frontend &
# run worker # run worker
# --enforce-eager is added for quick deployment. For production use, remove this flag # --enforce-eager is added for quick deployment. For production use, remove this flag
python -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager --no-enable-prefix-caching python -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager --connector none
...@@ -12,4 +12,4 @@ ENABLE_LMCACHE=1 \ ...@@ -12,4 +12,4 @@ ENABLE_LMCACHE=1 \
LMCACHE_CHUNK_SIZE=256 \ LMCACHE_CHUNK_SIZE=256 \
LMCACHE_LOCAL_CPU=True \ LMCACHE_LOCAL_CPU=True \
LMCACHE_MAX_LOCAL_CPU_SIZE=20 \ LMCACHE_MAX_LOCAL_CPU_SIZE=20 \
python -m dynamo.vllm --model Qwen/Qwen3-0.6B python -m dynamo.vllm --model Qwen/Qwen3-0.6B
\ No newline at end of file
...@@ -9,6 +9,6 @@ python -m dynamo.frontend --router-mode kv & ...@@ -9,6 +9,6 @@ python -m dynamo.frontend --router-mode kv &
# run workers # run workers
# --enforce-eager is added for quick deployment. For production use, remove this flag # --enforce-eager is added for quick deployment. For production use, remove this flag
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager & CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager --connector none &
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager --connector none
...@@ -21,4 +21,5 @@ LMCACHE_MAX_LOCAL_CPU_SIZE=20 \ ...@@ -21,4 +21,5 @@ LMCACHE_MAX_LOCAL_CPU_SIZE=20 \
CUDA_VISIBLE_DEVICES=1 \ CUDA_VISIBLE_DEVICES=1 \
python3 -m dynamo.vllm \ python3 -m dynamo.vllm \
--model Qwen/Qwen3-0.6B \ --model Qwen/Qwen3-0.6B \
--is-prefill-worker --is-prefill-worker \
\ No newline at end of file --connector lmcache nixl
...@@ -30,6 +30,8 @@ logger = logging.getLogger(__name__) ...@@ -30,6 +30,8 @@ logger = logging.getLogger(__name__)
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate" DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B" DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
VALID_CONNECTORS = {"nixl", "lmcache", "kvbm", "null", "none"}
# Global LMCache configuration - initialize once on module import # Global LMCache configuration - initialize once on module import
ENABLE_LMCACHE = os.getenv("ENABLE_LMCACHE", "0").lower() in ("1", "true", "yes") ENABLE_LMCACHE = os.getenv("ENABLE_LMCACHE", "0").lower() in ("1", "true", "yes")
...@@ -44,7 +46,6 @@ class Config: ...@@ -44,7 +46,6 @@ class Config:
is_prefill_worker: bool is_prefill_worker: bool
migration_limit: int = 0 migration_limit: int = 0
kv_port: Optional[int] = None kv_port: Optional[int] = None
side_channel_port: Optional[int] = None
port_range: DynamoPortRange port_range: DynamoPortRange
# mirror vLLM # mirror vLLM
...@@ -54,6 +55,9 @@ class Config: ...@@ -54,6 +55,9 @@ class Config:
# rest vLLM args # rest vLLM args
engine_args: AsyncEngineArgs engine_args: AsyncEngineArgs
# Connector list from CLI
connector_list: Optional[list] = None
def parse_args() -> Config: def parse_args() -> Config:
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
...@@ -91,6 +95,13 @@ def parse_args() -> Config: ...@@ -91,6 +95,13 @@ def parse_args() -> Config:
default=DEFAULT_DYNAMO_PORT_MAX, default=DEFAULT_DYNAMO_PORT_MAX,
help=f"Maximum port number for Dynamo services (default: {DEFAULT_DYNAMO_PORT_MAX}). Must be in registered ports range (1024-49151).", help=f"Maximum port number for Dynamo services (default: {DEFAULT_DYNAMO_PORT_MAX}). Must be in registered ports range (1024-49151).",
) )
parser.add_argument(
"--connector",
nargs="*",
default=["nixl"],
help="List of connectors to use in order (e.g., --connector nixl lmcache). "
"Options: nixl, lmcache, kvbm, null, none. Default: nixl. Order will be preserved in MultiConnector.",
)
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
...@@ -141,6 +152,36 @@ def parse_args() -> Config: ...@@ -141,6 +152,36 @@ def parse_args() -> Config:
min=args.dynamo_port_min, max=args.dynamo_port_max min=args.dynamo_port_min, max=args.dynamo_port_max
) )
# Check for conflicting flags
has_kv_transfer_config = (
hasattr(engine_args, "kv_transfer_config")
and engine_args.kv_transfer_config is not None
)
has_connector_flag = args.connector is not None
if has_kv_transfer_config and has_connector_flag:
raise ValueError(
"Cannot specify both --kv-transfer-config and --connector flags"
)
if has_connector_flag:
normalized = [c.lower() for c in args.connector]
invalid = [c for c in normalized if c not in VALID_CONNECTORS]
if invalid:
raise ValueError(
f"Invalid connector(s): {', '.join(invalid)}. Valid options are: {', '.join(sorted(VALID_CONNECTORS))}"
)
if "none" in normalized or "null" in normalized:
if len(normalized) > 1:
raise ValueError(
"'none' and 'null' cannot be combined with other connectors"
)
config.connector_list = []
else:
config.connector_list = normalized
if config.engine_args.block_size is None: if config.engine_args.block_size is None:
config.engine_args.block_size = 16 config.engine_args.block_size = 16
logger.debug( logger.debug(
...@@ -169,84 +210,101 @@ async def configure_ports_with_etcd(config: Config, etcd_client): ...@@ -169,84 +210,101 @@ async def configure_ports_with_etcd(config: Config, etcd_client):
config.kv_port = kv_port config.kv_port = kv_port
logger.info(f"Allocated ZMQ KV events port: {kv_port} (worker_id={worker_id})") logger.info(f"Allocated ZMQ KV events port: {kv_port} (worker_id={worker_id})")
# Allocate side channel ports # Check if NIXL is needed based on connector list
# https://github.com/vllm-project/vllm/blob/releases/v0.10.1/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py#L443 needs_nixl = config.connector_list and "nixl" in config.connector_list
# NIXL calculates ports as: base_port + (dp_rank * tp_size) + tp_rank
# For dp_rank, we need to reserve tp_size consecutive ports
tp_size = config.engine_args.tensor_parallel_size or 1
# The first port for this dp_rank will be at: base_port + (dp_rank * tp_size)
# We need to allocate tp_size consecutive ports starting from there
nixl_metadata = PortMetadata(worker_id=worker_id, reason="nixl_side_channel_port")
nixl_request = PortAllocationRequest(
etcd_context=etcd_context,
metadata=nixl_metadata,
port_range=config.port_range,
block_size=tp_size,
)
allocated_ports = await allocate_and_reserve_port_block(nixl_request)
first_port_for_dp_rank = allocated_ports[0]
# Calculate the base port that NIXL expects if needs_nixl:
# base_port = first_port_for_dp_rank - (dp_rank * tp_size) # Allocate side channel ports
nixl_offset = dp_rank * tp_size # https://github.com/vllm-project/vllm/blob/releases/v0.10.0/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py#L372
base_side_channel_port = first_port_for_dp_rank - nixl_offset # NIXL calculates ports as: base_port + (dp_rank * tp_size) + tp_rank
# For dp_rank, we need to reserve tp_size consecutive ports
tp_size = config.engine_args.tensor_parallel_size or 1
if base_side_channel_port < 0: # The first port for this dp_rank will be at: base_port + (dp_rank * tp_size)
raise ValueError( # We need to allocate tp_size consecutive ports starting from there
f"NIXL base port calculation resulted in negative port: " nixl_metadata = PortMetadata(
f"first_allocated_port={first_port_for_dp_rank}, offset={nixl_offset}, " worker_id=worker_id, reason="nixl_side_channel_port"
f"base_port={base_side_channel_port}. Current range: {config.port_range.min}-{config.port_range.max}. "
f"Consider using a higher port range."
) )
nixl_request = PortAllocationRequest(
etcd_context=etcd_context,
metadata=nixl_metadata,
port_range=config.port_range,
block_size=tp_size,
)
allocated_ports = await allocate_and_reserve_port_block(nixl_request)
first_port_for_dp_rank = allocated_ports[0]
# Calculate the base port that NIXL expects
# base_port = first_port_for_dp_rank - (dp_rank * tp_size)
nixl_offset = dp_rank * tp_size
base_side_channel_port = first_port_for_dp_rank - nixl_offset
if base_side_channel_port < 0:
raise ValueError(
f"NIXL base port calculation resulted in negative port: "
f"first_allocated_port={first_port_for_dp_rank}, offset={nixl_offset}, "
f"base_port={base_side_channel_port}. Current range: {config.port_range.min}-{config.port_range.max}. "
f"Consider using a higher port range."
)
config.side_channel_port = base_side_channel_port logger.info(
f"Allocated NIXL side channel ports: base={base_side_channel_port}, "
logger.info( f"allocated_ports={allocated_ports} (worker_id={worker_id}, dp_rank={dp_rank}, tp_size={tp_size})"
f"Allocated NIXL side channel ports: base={base_side_channel_port}, " )
f"allocated_ports={allocated_ports} (worker_id={worker_id}, dp_rank={dp_rank}, tp_size={tp_size})" set_side_channel_host_and_port(base_side_channel_port)
)
def overwrite_args(config): def create_kv_transfer_config(config: Config) -> Optional[KVTransferConfig]:
"""Set vLLM defaults for Dynamo.""" """Create KVTransferConfig based on user config or connector list.
assert (
config.side_channel_port is not None
), "Must set the kv_port, use configure_ports_with_etcd"
dp_rank = config.engine_args.data_parallel_rank or 0 Handles logging and returns the appropriate config or None.
"""
has_user_kv_config = (
hasattr(config.engine_args, "kv_transfer_config")
and config.engine_args.kv_transfer_config is not None
)
# Set kv_transfer_config based on LMCache setting if has_user_kv_config:
if ENABLE_LMCACHE: logger.info("Using user-provided kv_transfer_config from --kv-transfer-config")
if config.is_prefill_worker: return None # Let vLLM use the user's config
# Prefill worker use LMCache with disaggregated serving (MultiConnector) for disaggregated serving
kv_transfer_config = KVTransferConfig( # No connector list or empty list means no config
kv_connector="MultiConnector", if not config.connector_list:
kv_role="kv_both", logger.info("Using vLLM defaults for kv_transfer_config")
kv_connector_extra_config={ return None
"connectors": [
{"kv_connector": "LMCacheConnectorV1", "kv_role": "kv_both"}, logger.info(f"Creating kv_transfer_config from --connector {config.connector_list}")
{
"kv_connector": "NixlConnector", # Create connector configs in specified order
"kv_role": "kv_both", multi_connectors = []
}, for connector in config.connector_list:
] if connector == "lmcache":
}, connector_cfg = {"kv_connector": "LMCacheConnectorV1", "kv_role": "kv_both"}
) elif connector == "nixl":
logger.info("Using LMCache with MultiConnector serving") connector_cfg = {"kv_connector": "NixlConnector", "kv_role": "kv_both"}
else: elif connector == "kvbm":
# If enable lmcache, single node in default uses single connector serving connector_cfg = {
kv_transfer_config = KVTransferConfig( "kv_connector": "DynamoConnector",
kv_connector="LMCacheConnectorV1", kv_role="kv_both" "kv_connector_module_path": "dynamo.llm.vllm_integration.connector",
) "kv_role": "kv_both",
logger.info("Using LMCache with LMCacheConnector serving") }
multi_connectors.append(connector_cfg)
# For single connector, return direct config
if len(multi_connectors) == 1:
cfg = multi_connectors[0]
return KVTransferConfig(**cfg)
# For multiple connectors, use MultiConnector
return KVTransferConfig(
kv_connector="MultiConnector",
kv_role="kv_both",
kv_connector_extra_config={"connectors": multi_connectors},
)
else:
kv_transfer_config = KVTransferConfig(
kv_connector="NixlConnector", kv_role="kv_both"
)
logger.info("Using NixlConnector configuration")
def overwrite_args(config):
"""Set vLLM defaults for Dynamo."""
defaults = { defaults = {
"task": "generate", "task": "generate",
# As of vLLM >=0.10.0 the engine unconditionally calls # As of vLLM >=0.10.0 the engine unconditionally calls
...@@ -257,11 +315,14 @@ def overwrite_args(config): ...@@ -257,11 +315,14 @@ def overwrite_args(config):
"disable_log_requests": True, "disable_log_requests": True,
# KV routing relies on logging KV metrics # KV routing relies on logging KV metrics
"disable_log_stats": False, "disable_log_stats": False,
"kv_transfer_config": kv_transfer_config,
} }
kv_config = create_kv_transfer_config(config)
if kv_config:
defaults["kv_transfer_config"] = kv_config
if config.engine_args.enable_prefix_caching: if config.engine_args.enable_prefix_caching:
# If caching, send events dp_rank = config.engine_args.data_parallel_rank or 0
defaults |= { defaults |= {
# Always setting up kv events if enable prefix cache. # Always setting up kv events if enable prefix cache.
"kv_events_config": KVEventsConfig( "kv_events_config": KVEventsConfig(
...@@ -271,8 +332,6 @@ def overwrite_args(config): ...@@ -271,8 +332,6 @@ def overwrite_args(config):
) )
} }
set_side_channel_host_and_port(config)
logger.debug("Setting Dynamo defaults for vLLM") logger.debug("Setting Dynamo defaults for vLLM")
for key, value in defaults.items(): for key, value in defaults.items():
if hasattr(config.engine_args, key): if hasattr(config.engine_args, key):
...@@ -282,11 +341,11 @@ def overwrite_args(config): ...@@ -282,11 +341,11 @@ def overwrite_args(config):
raise ValueError(f"{key} not found in AsyncEngineArgs from vLLM.") raise ValueError(f"{key} not found in AsyncEngineArgs from vLLM.")
def set_side_channel_host_and_port(config: Config): def set_side_channel_host_and_port(side_channel_port: int):
"""vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors. """vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors.
This sets the port number for the side channel. This sets the port number for the side channel.
""" """
host_ip = get_host_ip() host_ip = get_host_ip()
os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = host_ip os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = host_ip
os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(config.side_channel_port) os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(side_channel_port)
logger.debug(f"Set NIXL side channel to {host_ip}:{config.side_channel_port}") logger.debug(f"Set NIXL side channel to {host_ip}:{side_channel_port}")
...@@ -80,7 +80,7 @@ async def worker(runtime: DistributedRuntime): ...@@ -80,7 +80,7 @@ async def worker(runtime: DistributedRuntime):
for sig in (signal.SIGTERM, signal.SIGINT): for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler) loop.add_signal_handler(sig, signal_handler)
logging.info("Signal handlers set up for graceful shutdown") logging.debug("Signal handlers set up for graceful shutdown")
if config.is_prefill_worker: if config.is_prefill_worker:
await init_prefill(runtime, config) await init_prefill(runtime, config)
...@@ -99,7 +99,7 @@ def setup_vllm_engine(config, stat_logger=None): ...@@ -99,7 +99,7 @@ def setup_vllm_engine(config, stat_logger=None):
setup_lmcache_environment() setup_lmcache_environment()
logger.info("LMCache enabled for VllmWorker") logger.info("LMCache enabled for VllmWorker")
else: else:
logger.info("LMCache is disabled") logger.debug("LMCache is disabled")
# Load default sampling params from `generation_config.json` # Load default sampling params from `generation_config.json`
default_sampling_params = ( default_sampling_params = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment