Unverified Commit 1e3e4a0c authored by Alec's avatar Alec Committed by GitHub
Browse files

fix: port race condition through deterministic ports (#1937)

parent 4ad281f2
...@@ -13,9 +13,13 @@ ...@@ -13,9 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import json
import logging import logging
import os
import socket import socket
import sys import sys
import time
from typing import Optional from typing import Optional
from vllm.config import KVTransferConfig from vllm.config import KVTransferConfig
...@@ -30,14 +34,6 @@ DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate" ...@@ -30,14 +34,6 @@ DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B" DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
def find_free_port() -> int:
    """Return a port number the OS handed out for a throwaway TCP socket.

    The socket is bound to port 0 on all interfaces so the kernel picks the
    port, and it is closed again before this function returns. Nothing keeps
    the port reserved afterwards.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
        probe.bind(("", 0))
        # getsockname() on an AF_INET socket yields (host, port).
        _host, assigned_port = probe.getsockname()
    return assigned_port
class Config: class Config:
"""Command line parameters or defaults""" """Command line parameters or defaults"""
...@@ -45,8 +41,9 @@ class Config: ...@@ -45,8 +41,9 @@ class Config:
namespace: str namespace: str
component: str component: str
endpoint: str endpoint: str
kv_events_port: int
is_prefill_worker: bool is_prefill_worker: bool
kv_port: Optional[int] = None
side_channel_port: Optional[int] = None
# mirror vLLM # mirror vLLM
model: str model: str
...@@ -56,38 +53,6 @@ class Config: ...@@ -56,38 +53,6 @@ class Config:
engine_args: AsyncEngineArgs engine_args: AsyncEngineArgs
def overwrite_args(config):
    """Force the vLLM engine arguments that Dynamo relies on.

    Every entry below is written onto ``config.engine_args`` unconditionally,
    replacing whatever was there before.

    Raises:
        ValueError: If ``config.engine_args`` has no attribute for one of the
            forced settings.
    """
    # Always set up KV Events for routing
    events_config = KVEventsConfig(
        enable_kv_cache_events=True,
        publisher="zmq",
        endpoint=f"tcp://*:{config.kv_events_port}",
    )
    # Always setting up kv transfer for disagg
    transfer_config = KVTransferConfig(
        kv_connector="NixlConnector", kv_role="kv_both"
    )

    forced_settings = {
        "task": "generate",
        "skip_tokenizer_init": True,
        "disable_log_requests": True,
        "enable_prefix_caching": True,
        # KV routing relies on logging KV metrics
        "disable_log_stats": False,
        "kv_events_config": events_config,
        "kv_transfer_config": transfer_config,
    }

    # Made decision to always overwrite.
    # Respecting users original cmd line args at all costs requires a bunch of arg parse work
    logger.debug("Setting Dynamo defaults for vLLM")
    engine_args = config.engine_args
    for name, forced in forced_settings.items():
        if not hasattr(engine_args, name):
            raise ValueError(f"{name} not found in AsyncEngineArgs from vLLM.")
        setattr(engine_args, name, forced)
        logger.debug(f" engine_args.{name} = {forced}")
def parse_args() -> Config: def parse_args() -> Config:
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="vLLM server integrated with Dynamo LLM." description="vLLM server integrated with Dynamo LLM."
...@@ -103,12 +68,6 @@ def parse_args() -> Config: ...@@ -103,12 +68,6 @@ def parse_args() -> Config:
action="store_true", action="store_true",
help="Enable prefill functionality for this worker. Currently overwrites the --endpoint to be a specially chosen dyn://dynamo.prefill.generate", help="Enable prefill functionality for this worker. Currently overwrites the --endpoint to be a specially chosen dyn://dynamo.prefill.generate",
) )
parser.add_argument(
"--kv-events-port",
type=int,
default=find_free_port(),
help="Endpoint where vLLM publishes metrics for dynamo. For DP, we handle the port iteration.",
)
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args() args = parser.parse_args()
...@@ -143,7 +102,6 @@ def parse_args() -> Config: ...@@ -143,7 +102,6 @@ def parse_args() -> Config:
config.endpoint = parsed_endpoint_name config.endpoint = parsed_endpoint_name
config.engine_args = engine_args config.engine_args = engine_args
config.is_prefill_worker = args.is_prefill_worker config.is_prefill_worker = args.is_prefill_worker
config.kv_events_port = args.kv_events_port
if config.engine_args.block_size is None: if config.engine_args.block_size is None:
config.engine_args.block_size = 16 config.engine_args.block_size = 16
...@@ -151,6 +109,153 @@ def parse_args() -> Config: ...@@ -151,6 +109,153 @@ def parse_args() -> Config:
f"Setting reasonable default of {config.engine_args.block_size} for block_size" f"Setting reasonable default of {config.engine_args.block_size} for block_size"
) )
overwrite_args(config)
return config return config
async def allocate_and_reserve_port(
    namespace,
    etcd_client,
    worker_id: str,
    reason: str,
    max_attempts: int = 100,
) -> int:
    """
    Get an OS-assigned port and atomically reserve it in ETCD.
    Retries until successful or max_attempts reached.

    Each attempt binds a throwaway socket to port 0 to learn a port number and,
    while that socket is still bound, creates the key
    ``dyn://{namespace}/ports/{hostname}/{port}`` under the client's primary
    lease. Any exception from ``kv_create`` is treated as a failed reservation
    (presumably the key already exists -- confirm against the ETCD client) and
    a new port is tried after a short pause.

    Args:
        namespace: Prefix component of the reservation key.
        etcd_client: Client providing an awaitable
            ``kv_create(key=, value=, lease_id=)`` and ``primary_lease_id()``.
        worker_id: Identifier stored in the reservation record.
        reason: Purpose string stored in the reservation record.
        max_attempts: Maximum number of ports to try (default: 100)

    Returns:
        The port number that was reserved.

    Raises:
        RuntimeError: If unable to reserve a port within max_attempts. The
            last reservation error, if any, is attached as ``__cause__``.
        OSError: If unable to create sockets (system resource issues)
    """
    node_name = socket.gethostname()
    # Remember the most recent failure so the final error is diagnosable
    # (e.g. an unreachable ETCD looks very different from a key collision).
    last_error: Optional[Exception] = None

    for attempt in range(1, max_attempts + 1):
        # Hold socket open just long enough to reserve in ETCD
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(("", 0))
            port = sock.getsockname()[1]

            # Reserve in ETCD while holding the socket
            key = f"dyn://{namespace}/ports/{node_name}/{port}"
            value = {
                "worker_id": worker_id,
                "reason": reason,
                "reserved_at": time.time(),
                "pid": os.getpid(),
            }

            try:
                await etcd_client.kv_create(
                    key=key,
                    value=json.dumps(value).encode(),
                    lease_id=etcd_client.primary_lease_id(),
                )
            except Exception as e:
                # The client's exception types are not visible here, so every
                # failure is retried; the cause is kept for the final error.
                last_error = e
                logger.debug(
                    f"Port {port} on {node_name} was already reserved (attempt {attempt}): {e}"
                )
            else:
                logger.debug(f"Reserved OS-assigned port {port} for {worker_id}")
                return port

        # Brief pause between attempts; skipped after the final one.
        if attempt < max_attempts:
            await asyncio.sleep(0.01)

    raise RuntimeError(
        f"Failed to allocate and reserve a port after {max_attempts} attempts"
    ) from last_error
async def configure_ports_with_etcd(config: Config, etcd_client):
    """Reserve the ports this worker needs via ETCD and record them on ``config``.

    Two ports are reserved, one for the ZMQ KV-event publisher and one for the
    NIXL side channel. They are stored as ``config.kv_port`` and
    ``config.side_channel_port`` once both reservations have succeeded.
    """
    rank = config.engine_args.data_parallel_rank or 0
    owner = f"vllm-{config.component}-dp{rank}"

    async def _reserve(purpose: str) -> int:
        # Same worker identity for every reservation; only the reason differs.
        return await allocate_and_reserve_port(
            namespace=config.namespace,
            etcd_client=etcd_client,
            worker_id=owner,
            reason=purpose,
        )

    events_port = await _reserve("zmq_kv_event_port")
    nixl_port = await _reserve("nixl_side_channel_port")

    # Publish the results only after both reservations are in hand.
    config.kv_port = events_port
    config.side_channel_port = nixl_port
def overwrite_args(config):
    """Set vLLM defaults for Dynamo.

    Requires ``config.kv_port`` and ``config.side_channel_port`` to be set
    already (see ``configure_ports_with_etcd``). Calls
    ``set_side_channel_host_and_port(config)`` and then overwrites the listed
    attributes on ``config.engine_args`` unconditionally.

    Raises:
        AssertionError: If either port has not been set on ``config``.
        ValueError: If ``config.engine_args`` lacks one of the forced settings.
    """
    assert (
        config.kv_port is not None
    ), "Must set the kv_port, use configure_ports_with_etcd"
    assert (
        config.side_channel_port is not None
    ), "Must set the side_channel_port, use configure_ports_with_etcd"

    dp_rank = config.engine_args.data_parallel_rank or 0

    defaults = {
        "task": "generate",
        "skip_tokenizer_init": True,
        "disable_log_requests": True,
        "enable_prefix_caching": True,
        # KV routing relies on logging KV metrics
        "disable_log_stats": False,
        # Always setting up kv transfer for disagg
        "kv_transfer_config": KVTransferConfig(
            kv_connector="NixlConnector", kv_role="kv_both"
        ),
        "kv_events_config": KVEventsConfig(
            enable_kv_cache_events=True,
            publisher="zmq",
            # vLLM will iterate dp_rank for us, so we need to subtract it out
            # TODO: fix in vLLM
            endpoint=f"tcp://*:{config.kv_port - dp_rank}",
        ),
    }

    set_side_channel_host_and_port(config)

    logger.debug("Setting Dynamo defaults for vLLM")
    for key, value in defaults.items():
        if hasattr(config.engine_args, key):
            setattr(config.engine_args, key, value)
            logger.debug(f" engine_args.{key} = {value}")
        else:
            raise ValueError(f"{key} not found in AsyncEngineArgs from vLLM.")
def set_side_channel_host_and_port(config: Config, hostname: Optional[str] = None):
    """Export the NIXL side-channel host and port through environment variables.

    vLLM V1 NixlConnector creates a side channel to exchange metadata with
    other NIXL connectors. This writes ``VLLM_NIXL_SIDE_CHANNEL_HOST`` and
    ``VLLM_NIXL_SIDE_CHANNEL_PORT``; the port is taken from
    ``config.side_channel_port``.

    When ``hostname`` is not given, the machine's hostname is used if a socket
    can be bound to it, otherwise ``127.0.0.1``. An explicitly passed hostname
    is used as-is.
    """
    if hostname is None:
        hostname = socket.gethostname()
        # Probe the hostname with a throwaway bind before advertising it.
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.bind((hostname, 0))
        except OSError:
            # socket.error is an alias of OSError and socket.gaierror is a
            # subclass of it, so this covers the same failures.
            logger.warning(
                f"Hostname '{hostname}' is not usable, falling back to '127.0.0.1'"
            )
            hostname = "127.0.0.1"

    os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = hostname
    os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(config.side_channel_port)
    logger.debug(f"Set NIXL side channel to {hostname}:{config.side_channel_port}")
...@@ -17,11 +17,9 @@ import asyncio ...@@ -17,11 +17,9 @@ import asyncio
import logging import logging
import os import os
import signal import signal
import socket
from typing import Optional
import uvloop import uvloop
from args import Config, find_free_port, parse_args from args import Config, configure_ports_with_etcd, overwrite_args, parse_args
from handlers import DecodeWorkerHandler, PrefillWorkerHandler from handlers import DecodeWorkerHandler, PrefillWorkerHandler
from publisher import StatLoggerFactory from publisher import StatLoggerFactory
from vllm.distributed.kv_events import ZmqEventPublisher from vllm.distributed.kv_events import ZmqEventPublisher
...@@ -57,6 +55,10 @@ async def graceful_shutdown(runtime): ...@@ -57,6 +55,10 @@ async def graceful_shutdown(runtime):
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
config = parse_args() config = parse_args()
etcd_client = runtime.etcd_client()
await configure_ports_with_etcd(config, etcd_client)
overwrite_args(config)
# Set up signal handler for graceful shutdown # Set up signal handler for graceful shutdown
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
...@@ -78,8 +80,6 @@ def setup_vllm_engine(config, stat_logger=None): ...@@ -78,8 +80,6 @@ def setup_vllm_engine(config, stat_logger=None):
os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
set_side_channel_host_and_port()
engine_args = config.engine_args engine_args = config.engine_args
# Load default sampling params from `generation_config.json` # Load default sampling params from `generation_config.json`
default_sampling_params = ( default_sampling_params = (
...@@ -105,32 +105,6 @@ def setup_vllm_engine(config, stat_logger=None): ...@@ -105,32 +105,6 @@ def setup_vllm_engine(config, stat_logger=None):
return engine_client, vllm_config, default_sampling_params return engine_client, vllm_config, default_sampling_params
def set_side_channel_host_and_port(
    hostname: Optional[str] = None, port: Optional[int] = None
):
    """Export the NIXL side-channel host and port through environment variables.

    vLLM V1 NixlConnector creates a side channel to exchange metadata with
    other NIXL connectors. This writes ``VLLM_NIXL_SIDE_CHANNEL_HOST`` and
    ``VLLM_NIXL_SIDE_CHANNEL_PORT``.

    When ``hostname`` is omitted, the machine's hostname is used if a socket
    can be bound to it, otherwise ``127.0.0.1``. When ``port`` is omitted,
    ``find_free_port()`` supplies one.
    """
    if hostname is None:
        hostname = socket.gethostname()
        # Probe the hostname with a throwaway bind before advertising it.
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.bind((hostname, 0))
        except OSError:
            # socket.error is an alias of OSError and socket.gaierror is a
            # subclass of it, so this covers the same failures.
            logger.warning(
                f"Hostname '{hostname}' is not usable, falling back to '127.0.0.1'"
            )
            hostname = "127.0.0.1"

    if port is None:
        port = find_free_port()

    logger.debug("Setting VLLM_NIXL_SIDE_CHANNEL_HOST to %s", hostname)
    os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = hostname
    logger.debug("Setting VLLM_NIXL_SIDE_CHANNEL_PORT to %s", port)
    os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(port)
async def init_prefill(runtime: DistributedRuntime, config: Config): async def init_prefill(runtime: DistributedRuntime, config: Config):
""" """
Instantiate and serve Instantiate and serve
......
...@@ -16,8 +16,7 @@ for i in {0..3}; do ...@@ -16,8 +16,7 @@ for i in {0..3}; do
--data-parallel-rank $i \ --data-parallel-rank $i \
--data-parallel-size 4 \ --data-parallel-size 4 \
--enable-expert-parallel \ --enable-expert-parallel \
--enforce-eager \ --enforce-eager &
--kv-events-port 49500 &
done done
echo "All workers starting. (press Ctrl+C to stop)..." echo "All workers starting. (press Ctrl+C to stop)..."
......
...@@ -98,8 +98,7 @@ for ((i=0; i<GPUS_PER_NODE; i++)); do ...@@ -98,8 +98,7 @@ for ((i=0; i<GPUS_PER_NODE; i++)); do
--data-parallel-address $MASTER_ADDR \ --data-parallel-address $MASTER_ADDR \
--data-parallel-rpc-port 13345 \ --data-parallel-rpc-port 13345 \
--gpu-memory-utilization 0.95 \ --gpu-memory-utilization 0.95 \
--enforce-eager \ --enforce-eager 2>&1 | tee $LOG_DIR/dsr1_dep_${dp_rank}.log &
--kv-events-port 49700 2>&1 | tee $LOG_DIR/dsr1_dep_${dp_rank}.log &
done done
echo "All workers starting. (press Ctrl+C to stop)..." echo "All workers starting. (press Ctrl+C to stop)..."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment