Unverified commit 3865a941, authored by Graham King and committed by GitHub
Browse files

feat: Port vllm port allocator to Rust in bindings (#3125)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 19948b7f
...@@ -12,13 +12,13 @@ from vllm.engine.arg_utils import AsyncEngineArgs ...@@ -12,13 +12,13 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
from dynamo._core import get_reasoning_parser_names, get_tool_parser_names from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
from dynamo.runtime import DistributedRuntime
from . import __version__ from . import __version__
from .ports import ( from .ports import (
DEFAULT_DYNAMO_PORT_MAX, DEFAULT_DYNAMO_PORT_MAX,
DEFAULT_DYNAMO_PORT_MIN, DEFAULT_DYNAMO_PORT_MIN,
DynamoPortRange, DynamoPortRange,
EtcdContext,
PortAllocationRequest, PortAllocationRequest,
PortMetadata, PortMetadata,
allocate_and_reserve_port, allocate_and_reserve_port,
...@@ -195,10 +195,8 @@ def parse_args() -> Config: ...@@ -195,10 +195,8 @@ def parse_args() -> Config:
return config return config
async def configure_ports_with_etcd(config: Config, etcd_client): async def configure_ports(runtime: DistributedRuntime, config: Config):
"""Configure all settings that require ETCD, including port allocation and vLLM overrides.""" """Configure including port allocation and vLLM overrides."""
etcd_context = EtcdContext(client=etcd_client, namespace=config.namespace)
dp_rank = config.engine_args.data_parallel_rank or 0 dp_rank = config.engine_args.data_parallel_rank or 0
worker_id = f"vllm-{config.component}-dp{dp_rank}" worker_id = f"vllm-{config.component}-dp{dp_rank}"
...@@ -207,7 +205,8 @@ async def configure_ports_with_etcd(config: Config, etcd_client): ...@@ -207,7 +205,8 @@ async def configure_ports_with_etcd(config: Config, etcd_client):
if config.engine_args.enable_prefix_caching: if config.engine_args.enable_prefix_caching:
kv_metadata = PortMetadata(worker_id=worker_id, reason="zmq_kv_event_port") kv_metadata = PortMetadata(worker_id=worker_id, reason="zmq_kv_event_port")
kv_port = await allocate_and_reserve_port( kv_port = await allocate_and_reserve_port(
etcd_context=etcd_context, runtime=runtime,
namespace=config.namespace,
metadata=kv_metadata, metadata=kv_metadata,
port_range=config.port_range, port_range=config.port_range,
) )
...@@ -230,12 +229,13 @@ async def configure_ports_with_etcd(config: Config, etcd_client): ...@@ -230,12 +229,13 @@ async def configure_ports_with_etcd(config: Config, etcd_client):
worker_id=worker_id, reason="nixl_side_channel_port" worker_id=worker_id, reason="nixl_side_channel_port"
) )
nixl_request = PortAllocationRequest( nixl_request = PortAllocationRequest(
etcd_context=etcd_context,
metadata=nixl_metadata, metadata=nixl_metadata,
port_range=config.port_range, port_range=config.port_range,
block_size=tp_size, block_size=tp_size,
) )
allocated_ports = await allocate_and_reserve_port_block(nixl_request) allocated_ports = await allocate_and_reserve_port_block(
runtime, config.namespace, nixl_request
)
first_port_for_dp_rank = allocated_ports[0] first_port_for_dp_rank = allocated_ports[0]
# Calculate the base port that NIXL expects # Calculate the base port that NIXL expects
...@@ -273,7 +273,7 @@ def create_kv_events_config(config: Config) -> Optional[KVEventsConfig]: ...@@ -273,7 +273,7 @@ def create_kv_events_config(config: Config) -> Optional[KVEventsConfig]:
logger.info("Creating Dynamo default kv_events_config for prefix caching") logger.info("Creating Dynamo default kv_events_config for prefix caching")
if config.kv_port is None: if config.kv_port is None:
raise ValueError( raise ValueError(
"config.kv_port is not set; call configure_ports_with_etcd(...) before overwrite_args " "config.kv_port is not set; call configure_ports(...) before overwrite_args "
"or provide --kv-event-config to supply an explicit endpoint." "or provide --kv-event-config to supply an explicit endpoint."
) )
dp_rank = config.engine_args.data_parallel_rank or 0 dp_rank = config.engine_args.data_parallel_rank or 0
......
...@@ -22,13 +22,7 @@ from dynamo.llm import ( ...@@ -22,13 +22,7 @@ from dynamo.llm import (
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from .args import ( from .args import ENABLE_LMCACHE, Config, configure_ports, overwrite_args, parse_args
ENABLE_LMCACHE,
Config,
configure_ports_with_etcd,
overwrite_args,
parse_args,
)
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
from .health_check import VllmHealthCheckPayload from .health_check import VllmHealthCheckPayload
from .publisher import StatLoggerFactory from .publisher import StatLoggerFactory
...@@ -69,8 +63,7 @@ async def graceful_shutdown(runtime): ...@@ -69,8 +63,7 @@ async def graceful_shutdown(runtime):
async def worker(runtime: DistributedRuntime): async def worker(runtime: DistributedRuntime):
config = parse_args() config = parse_args()
etcd_client = runtime.do_not_use_etcd_client() await configure_ports(runtime, config)
await configure_ports_with_etcd(config, etcd_client)
overwrite_args(config) overwrite_args(config)
# Set up signal handler for graceful shutdown # Set up signal handler for graceful shutdown
...@@ -208,7 +201,7 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -208,7 +201,7 @@ async def init(runtime: DistributedRuntime, config: Config):
config, factory config, factory
) )
# TODO Hack to get data, move this to registering in ETCD # TODO Hack to get data, move this to registering in TBD
factory.set_num_gpu_blocks_all(vllm_config.cache_config.num_gpu_blocks) factory.set_num_gpu_blocks_all(vllm_config.cache_config.num_gpu_blocks)
factory.set_request_total_slots_all(vllm_config.scheduler_config.max_num_seqs) factory.set_request_total_slots_all(vllm_config.scheduler_config.max_num_seqs)
factory.init_publish() factory.init_publish()
......
...@@ -3,17 +3,14 @@ ...@@ -3,17 +3,14 @@
"""Port allocation and management utilities for Dynamo services.""" """Port allocation and management utilities for Dynamo services."""
import asyncio
import json import json
import logging import logging
import os import os
import random
import socket import socket
import time import time
from contextlib import contextmanager from dataclasses import dataclass
from dataclasses import dataclass, field
from dynamo.runtime import EtcdKvCache from dynamo.runtime import DistributedRuntime
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -40,77 +37,31 @@ class DynamoPortRange: ...@@ -40,77 +37,31 @@ class DynamoPortRange:
) )
@dataclass
class EtcdContext:
"""Context for ETCD operations"""
client: EtcdKvCache # etcd client instance
namespace: str # Namespace for keys (used in key prefix)
def make_port_key(self, port: int) -> str:
"""Generate ETCD key for a port reservation"""
node_ip = get_host_ip()
return f"dyn://{self.namespace}/ports/{node_ip}/{port}"
@dataclass @dataclass
class PortMetadata: class PortMetadata:
"""Metadata to store with port reservations in ETCD""" """Metadata to store with port reservations"""
worker_id: str # Worker identifier (e.g., "vllm-backend-dp0") worker_id: str # Worker identifier (e.g., "vllm-backend-dp0")
reason: str # Purpose of the port (e.g., "nixl_side_channel_port") reason: str # Purpose of the port (e.g., "nixl_side_channel_port")
block_info: dict = field(default_factory=dict) # Optional block allocation info
def to_etcd_value(self) -> dict:
"""Convert to dictionary for ETCD storage"""
value = {
"worker_id": self.worker_id,
"reason": self.reason,
"reserved_at": time.time(),
"pid": os.getpid(),
}
if self.block_info:
value.update(self.block_info)
return value
@dataclass @dataclass
class PortAllocationRequest: class PortAllocationRequest:
"""Parameters for port allocation""" """Parameters for port allocation"""
etcd_context: EtcdContext
metadata: PortMetadata metadata: PortMetadata
port_range: DynamoPortRange port_range: DynamoPortRange
block_size: int = 1 block_size: int = 1
max_attempts: int = 100
@contextmanager def __post_init__(self):
def hold_ports(ports: int | list[int]): if self.block_size < 1:
"""Context manager to hold port binding(s). raise ValueError("block_size must be >= 1")
range_len = self.port_range.max - self.port_range.min + 1
Holds socket bindings to ensure exclusive access to ports during reservation. if self.block_size > range_len:
Can handle a single port or multiple ports. raise ValueError(
f"block_size {self.block_size} exceeds range length {range_len} "
Args: f"({self.port_range.min}-{self.port_range.max})"
ports: Single port number or list of port numbers to hold )
"""
if isinstance(ports, int):
ports = [ports]
sockets = []
try:
for port in ports:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("", port))
sockets.append(sock)
yield
finally:
for sock in sockets:
sock.close()
def check_port_available(port: int) -> bool: def check_port_available(port: int) -> bool:
...@@ -123,140 +74,63 @@ def check_port_available(port: int) -> bool: ...@@ -123,140 +74,63 @@ def check_port_available(port: int) -> bool:
return False return False
async def reserve_port_in_etcd( async def allocate_and_reserve_port_block(
etcd_context: EtcdContext, runtime: DistributedRuntime, namespace: str, request: PortAllocationRequest
port: int, ) -> list[int]:
metadata: PortMetadata,
) -> None:
"""Reserve a single port in ETCD."""
key = etcd_context.make_port_key(port)
value = metadata.to_etcd_value()
await etcd_context.client.kv_create(
key=key,
value=json.dumps(value).encode(),
lease_id=etcd_context.client.primary_lease_id(),
)
async def allocate_and_reserve_port_block(request: PortAllocationRequest) -> list[int]:
""" """
Allocate a contiguous block of ports from the specified range and atomically reserve them in ETCD. Allocate a contiguous block of ports from the specified range and atomically reserve them.
Returns a list of all allocated ports in order. Returns a list of all allocated ports in order.
This function uses a context manager to hold port bindings while reserving in ETCD,
preventing race conditions between multiple processes.
Args: Args:
request: PortAllocationRequest containing all allocation parameters request: PortAllocationRequest containing all allocation parameters
Returns: Returns:
list[int]: List of all allocated ports in ascending order list[int]: List of all allocated ports in ascending order
Raises:
RuntimeError: If unable to reserve a port block within max_attempts
OSError: If unable to create sockets (system resource issues)
""" """
# Create a list of valid starting ports (must have room for the entire block) # Create a list of valid starting ports (must have room for the entire block)
max_start_port = request.port_range.max - request.block_size + 1
if max_start_port < request.port_range.min:
raise ValueError(
f"Port range {request.port_range.min}-{request.port_range.max} is too small for block size {request.block_size}"
)
available_start_ports = list(range(request.port_range.min, max_start_port + 1))
random.shuffle(available_start_ports)
actual_max_attempts = min(len(available_start_ports), request.max_attempts)
for attempt in range(1, actual_max_attempts + 1):
start_port = available_start_ports[attempt - 1]
ports_to_reserve = list(range(start_port, start_port + request.block_size))
try:
# Try to bind to all ports in the block atomically
with hold_ports(ports_to_reserve):
logger.debug(
f"Successfully bound to ports {ports_to_reserve}, now reserving in ETCD"
)
# We have exclusive access to these ports, now reserve them in ETCD context_json = {
for i, port in enumerate(ports_to_reserve): "worker_id": str(request.metadata.worker_id),
port_metadata = PortMetadata( "reason": request.metadata.reason,
worker_id=f"{request.metadata.worker_id}-{i}" "reserved_at": time.time(),
if request.block_size > 1 "pid": os.getpid(),
else request.metadata.worker_id,
reason=request.metadata.reason,
block_info={
"block_index": i,
"block_size": request.block_size, "block_size": request.block_size,
"block_start": start_port,
} }
if request.block_size > 1
else {},
)
await reserve_port_in_etcd(
etcd_context=request.etcd_context,
port=port,
metadata=port_metadata,
)
logger.debug(
f"Reserved port block {ports_to_reserve} from range {request.port_range.min}-{request.port_range.max} "
f"for {request.metadata.worker_id} (block_size={request.block_size})"
)
return ports_to_reserve
except OSError as e:
logger.debug(
f"Failed to bind to port block starting at {start_port} (attempt {attempt}): {e}"
)
except Exception as e:
logger.debug(
f"Failed to reserve port block starting at {start_port} in ETCD (attempt {attempt}): {e}"
)
if attempt < actual_max_attempts:
await asyncio.sleep(0.01)
raise RuntimeError( return await runtime.allocate_port_block(
f"Failed to allocate and reserve a port block of size {request.block_size} from range " namespace,
f"{request.port_range.min}-{request.port_range.max} after {actual_max_attempts} attempts" request.port_range.min,
request.port_range.max,
request.block_size,
json.dumps(context_json),
) )
async def allocate_and_reserve_port( async def allocate_and_reserve_port(
etcd_context: EtcdContext, runtime: DistributedRuntime,
namespace: str,
metadata: PortMetadata, metadata: PortMetadata,
port_range: DynamoPortRange, port_range: DynamoPortRange,
max_attempts: int = 100,
) -> int: ) -> int:
""" """
Allocate a port from the specified range and atomically reserve it in ETCD. Allocate a port from the specified range and atomically reserve it.
This is a convenience wrapper around allocate_and_reserve_port_block with block_size=1. This is a convenience wrapper around allocate_and_reserve_port_block with block_size=1.
Args: Args:
etcd_context: ETCD context for operations metadata: Port metadata / context
metadata: Port metadata for ETCD storage
port_range: DynamoPortRange object specifying min and max ports to try port_range: DynamoPortRange object specifying min and max ports to try
max_attempts: Maximum number of ports to try (default: 100)
Returns: Returns:
int: The allocated port number int: The allocated port number
Raises:
RuntimeError: If unable to reserve a port within max_attempts
OSError: If unable to create sockets (system resource issues)
""" """
request = PortAllocationRequest( request = PortAllocationRequest(
etcd_context=etcd_context,
metadata=metadata, metadata=metadata,
port_range=port_range, port_range=port_range,
block_size=1, block_size=1,
max_attempts=max_attempts,
) )
allocated_ports = await allocate_and_reserve_port_block(request) allocated_ports = await allocate_and_reserve_port_block(runtime, namespace, request)
if not allocated_ports:
raise RuntimeError("Failed to allocate required ports")
return allocated_ports[0] # Return the single allocated port return allocated_ports[0] # Return the single allocated port
......
...@@ -50,11 +50,11 @@ class DynamoStatLoggerPublisher(StatLoggerBase): ...@@ -50,11 +50,11 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
self.num_gpu_block = 1 self.num_gpu_block = 1
self.request_total_slots = 1 self.request_total_slots = 1
# TODO: Remove this and pass as metadata through etcd # TODO: Remove this and pass as metadata through shared storage
def set_num_gpu_block(self, num_blocks): def set_num_gpu_block(self, num_blocks):
self.num_gpu_block = num_blocks self.num_gpu_block = num_blocks
# TODO: Remove this and pass as metadata through etcd # TODO: Remove this and pass as metadata through shared storage
def set_num_request_total_slots(self, request_total_slots): def set_num_request_total_slots(self, request_total_slots):
self.request_total_slots = request_total_slots self.request_total_slots = request_total_slots
...@@ -66,7 +66,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase): ...@@ -66,7 +66,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
): ):
# request_total_slots and kv_total_blocks are properties of model + gpu # request_total_slots and kv_total_blocks are properties of model + gpu
# we should only publish them once, not every metric update # we should only publish them once, not every metric update
# they should be part of some runtime metadata tied to MDC or put in etcd ? # they should be part of some runtime metadata tied to MDC or put in shared storage ?
hit_rate = 0 hit_rate = 0
if scheduler_stats.prefix_cache_stats.queries > 0: if scheduler_stats.prefix_cache_stats.queries > 0:
hit_rate = ( hit_rate = (
...@@ -160,7 +160,7 @@ class StatLoggerFactory: ...@@ -160,7 +160,7 @@ class StatLoggerFactory:
def __call__(self, vllm_config: VllmConfig, dp_rank: int) -> StatLoggerBase: def __call__(self, vllm_config: VllmConfig, dp_rank: int) -> StatLoggerBase:
return self.create_stat_logger(dp_rank=dp_rank) return self.create_stat_logger(dp_rank=dp_rank)
# TODO Remove once we publish metadata to etcd # TODO Remove once we publish metadata to shared storage
def set_num_gpu_blocks_all(self, num_blocks): def set_num_gpu_blocks_all(self, num_blocks):
if self.created_logger: if self.created_logger:
self.created_logger.set_num_gpu_block(num_blocks) self.created_logger.set_num_gpu_block(num_blocks)
......
...@@ -28,7 +28,7 @@ from publisher import StatLoggerFactory ...@@ -28,7 +28,7 @@ from publisher import StatLoggerFactory
from utils.args import ( from utils.args import (
Config, Config,
base_parse_args, base_parse_args,
configure_ports_with_etcd, configure_ports,
overwrite_args, overwrite_args,
parse_endpoint, parse_endpoint,
) )
...@@ -420,8 +420,7 @@ async def worker(runtime: DistributedRuntime): ...@@ -420,8 +420,7 @@ async def worker(runtime: DistributedRuntime):
args, config = VllmBaseWorker.parse_args() args, config = VllmBaseWorker.parse_args()
# vLLM config overwrites # vLLM config overwrites
etcd_client = runtime.do_not_use_etcd_client() await configure_ports(runtime, config)
await configure_ports_with_etcd(config, etcd_client)
overwrite_args(config) overwrite_args(config)
await init(runtime, args, config) await init(runtime, args, config)
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse import argparse
import asyncio
import json import json
import logging import logging
import os import os
...@@ -27,6 +14,8 @@ from vllm.config import KVTransferConfig ...@@ -27,6 +14,8 @@ from vllm.config import KVTransferConfig
from vllm.distributed.kv_events import KVEventsConfig from vllm.distributed.kv_events import KVEventsConfig
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from dynamo.runtime import DistributedRuntime
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo") DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
...@@ -127,66 +116,43 @@ def base_parse_args( ...@@ -127,66 +116,43 @@ def base_parse_args(
async def allocate_and_reserve_port( async def allocate_and_reserve_port(
namespace, runtime: DistributedRuntime,
etcd_client, namespace: str,
worker_id: str, worker_id: str,
reason: str, reason: str,
max_attempts: int = 100,
) -> int: ) -> int:
""" """
Get an OS-assigned port and atomically reserve it in ETCD. Get an OS-assigned port and atomically reserve it.
Retries until successful or max_attempts reached. Retries until successful or internal max attempts reached.
Args:
max_attempts: Maximum number of ports to try (default: 100)
Raises:
RuntimeError: If unable to reserve a port within max_attempts
OSError: If unable to create sockets (system resource issues)
""" """
node_name = socket.gethostname() context_json = {
for attempt in range(1, max_attempts + 1):
# Hold socket open just long enough to reserve in ETCD
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(("", 0))
port = sock.getsockname()[1]
# Reserve in ETCD while holding the socket
key = f"dyn://{namespace}/ports/{node_name}/{port}"
value = {
"worker_id": worker_id, "worker_id": worker_id,
"reason": reason, "reason": reason,
"reserved_at": time.time(), "reserved_at": time.time(),
"pid": os.getpid(), "pid": os.getpid(),
"block_size": 1,
} }
try: # Any ephemeral port, equivalent to binding port 0
await etcd_client.kv_create( port_range_min = 32_768
key=key, port_range_max = 60_999
value=json.dumps(value).encode(), allocated_ports = await runtime.allocate_port_block(
lease_id=etcd_client.primary_lease_id(), namespace,
port_range_min,
port_range_max,
1, # how many ports to allocate
json.dumps(context_json),
) )
if not allocated_ports:
raise RuntimeError("allocate_port_block returned no ports")
port = allocated_ports[0]
logger.debug(f"Reserved OS-assigned port {port} for {worker_id}") logger.debug(f"Reserved OS-assigned port {port} for {worker_id}")
return port return port
except Exception as e:
logger.debug(
f"Port {port} on {node_name} was already reserved (attempt {attempt}): {e}"
)
if attempt < max_attempts:
await asyncio.sleep(0.01)
raise RuntimeError(
f"Failed to allocate and reserve a port after {max_attempts} attempts"
)
async def configure_ports(runtime: DistributedRuntime, config: Config):
async def configure_ports_with_etcd(config: Config, etcd_client): """Configure including port allocation and vLLM overrides."""
"""Configure all settings that require ETCD, including port allocation and vLLM overrides."""
# First, allocate ports # First, allocate ports
dp_rank = config.engine_args.data_parallel_rank or 0 dp_rank = config.engine_args.data_parallel_rank or 0
...@@ -194,16 +160,16 @@ async def configure_ports_with_etcd(config: Config, etcd_client): ...@@ -194,16 +160,16 @@ async def configure_ports_with_etcd(config: Config, etcd_client):
# Allocate KV events port # Allocate KV events port
kv_port = await allocate_and_reserve_port( kv_port = await allocate_and_reserve_port(
runtime=runtime,
namespace=config.namespace, namespace=config.namespace,
etcd_client=etcd_client,
worker_id=f"{worker_id}", worker_id=f"{worker_id}",
reason="zmq_kv_event_port", reason="zmq_kv_event_port",
) )
# Allocate side channel port # Allocate side channel port
side_channel_port = await allocate_and_reserve_port( side_channel_port = await allocate_and_reserve_port(
runtime=runtime,
namespace=config.namespace, namespace=config.namespace,
etcd_client=etcd_client,
worker_id=f"{worker_id}", worker_id=f"{worker_id}",
reason="nixl_side_channel_port", reason="nixl_side_channel_port",
) )
...@@ -215,12 +181,10 @@ async def configure_ports_with_etcd(config: Config, etcd_client): ...@@ -215,12 +181,10 @@ async def configure_ports_with_etcd(config: Config, etcd_client):
def overwrite_args(config): def overwrite_args(config):
"""Set vLLM defaults for Dynamo.""" """Set vLLM defaults for Dynamo."""
assert ( assert config.kv_port is not None, "Must set the kv_port, use configure_ports"
config.kv_port is not None
), "Must set the kv_port, use configure_ports_with_etcd"
assert ( assert (
config.side_channel_port is not None config.side_channel_port is not None
), "Must set the side_channel_port, use configure_ports_with_etcd" ), "Must set the side_channel_port, use configure_ports"
dp_rank = config.engine_args.data_parallel_rank or 0 dp_rank = config.engine_args.data_parallel_rank or 0
......
...@@ -1490,14 +1490,17 @@ dependencies = [ ...@@ -1490,14 +1490,17 @@ dependencies = [
"dynamo-runtime", "dynamo-runtime",
"either", "either",
"futures", "futures",
"local-ip-address",
"once_cell", "once_cell",
"prometheus", "prometheus",
"pyo3", "pyo3",
"pyo3-async-runtimes", "pyo3-async-runtimes",
"pythonize", "pythonize",
"rand 0.9.2",
"rstest", "rstest",
"serde", "serde",
"serde_json", "serde_json",
"socket2 0.6.0",
"thiserror 2.0.16", "thiserror 2.0.16",
"tokio", "tokio",
"tokio-stream", "tokio-stream",
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[workspace] [workspace]
# empty workspace to exclude from top level workspace # empty workspace to exclude from top level workspace
...@@ -49,7 +37,10 @@ async-trait = { version = "0.1" } ...@@ -49,7 +37,10 @@ async-trait = { version = "0.1" }
derive-getters = "0.5" derive-getters = "0.5"
either = { version = "1.13", features = ["serde"] } either = { version = "1.13", features = ["serde"] }
futures = { version = "0.3" } futures = { version = "0.3" }
local-ip-address = { version = "0.6" }
once_cell = { version = "1.20.3" } once_cell = { version = "1.20.3" }
rand = { version = "0.9" }
socket2 = { version = "0.6" }
serde = { version = "1" } serde = { version = "1" }
serde_json = { version = "1.0.138" } serde_json = { version = "1.0.138" }
thiserror = { version = "2.0" } thiserror = { version = "2.0" }
......
...@@ -8,8 +8,11 @@ use pyo3::types::PyBytes; ...@@ -8,8 +8,11 @@ use pyo3::types::PyBytes;
use pyo3::types::{PyDict, PyList, PyString}; use pyo3::types::{PyDict, PyList, PyString};
use pyo3::IntoPyObjectExt; use pyo3::IntoPyObjectExt;
use pyo3::{exceptions::PyException, prelude::*}; use pyo3::{exceptions::PyException, prelude::*};
use rand::seq::IteratorRandom as _;
use rs::pipeline::network::Ingress; use rs::pipeline::network::Ingress;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
use std::path::PathBuf; use std::path::PathBuf;
use std::time::Duration;
use std::{fmt::Display, sync::Arc}; use std::{fmt::Display, sync::Arc};
use tokio::sync::Mutex; use tokio::sync::Mutex;
...@@ -374,6 +377,137 @@ impl DistributedRuntime { ...@@ -374,6 +377,137 @@ impl DistributedRuntime {
}) })
} }
/// Allocate a contiguous block of ports from the specified range and atomically reserve them.
/// Returns a list of all allocated ports in order.
#[pyo3(signature = (namespace, port_min, port_max, block_size, context=None))]
fn allocate_port_block<'p>(
&self,
py: Python<'p>,
namespace: &str,
port_min: u16,
port_max: u16,
block_size: u16,
context: Option<String>, // Optional info to store alongside the reservation
) -> PyResult<Bound<'p, PyAny>> {
const MAX_ALLOCATE_ATTEMPTS: usize = 100;
if block_size == 0 {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Block size must be at least 1",
));
}
let Some(etcd_client) = self.inner.etcd_client() else {
return Err(PyErr::new::<PyException, _>(
"Static workers should not need to reserve ports",
));
};
let min = port_min;
let max = port_max;
// Compute maximum valid starting port (inclusive)
let max_start_port = max.saturating_sub(block_size.saturating_sub(1));
if max_start_port < min {
return Err(PyErr::new::<PyException, _>(format!(
"Port range {min}-{max} is too small for block size {block_size}",
)));
}
// Randomize candidate starting ports to reduce contention/races
let candidate_count =
(max_start_port - port_min + 1).min(MAX_ALLOCATE_ATTEMPTS as u16) as usize;
let mut rng = rand::rng();
let candidate_ports: Vec<u16> =
(port_min..=max_start_port).choose_multiple(&mut rng, candidate_count);
let local_ip = match local_ip() {
Ok(ip) => ip,
Err(err) => {
return Err(PyErr::new::<PyException, _>(format!(
"Failed fetching local IP address: {err}"
)));
}
};
let context_bytes = context.map(|s| s.as_bytes().to_vec()).unwrap_or_default();
let namespace = namespace.to_owned();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
for (attempt_idx, start_port) in candidate_ports.into_iter().enumerate() {
let end_port_exclusive = start_port + block_size;
let ports_to_reserve: Vec<u16> = (start_port..end_port_exclusive).collect();
// Hold/bind all ports in the block
let mut sockets = Vec::with_capacity(ports_to_reserve.len());
let mut bind_failed = false;
for &port in &ports_to_reserve {
match bind_tcp_port(port) {
Ok(sock) => sockets.push(sock),
Err(e) => {
tracing::error!(
"Failed to bind to port block starting at {start_port} (attempt {}): {e}",
attempt_idx + 1,
);
bind_failed = true;
break;
}
}
}
if bind_failed {
// Let previously bound sockets drop here
if attempt_idx < candidate_count - 1 {
tokio::time::sleep(Duration::from_millis(10)).await;
}
continue;
}
// With sockets held, reserve in ETCD
let mut reserved_keys = Vec::with_capacity(ports_to_reserve.len());
let mut reservation_failed = false;
for port in &ports_to_reserve {
let key = make_port_key(&namespace, local_ip, *port).map_err(to_pyerr)?;
if let Err(e) = etcd_client
.kv_create(&key, context_bytes.clone(), None)
.await
{
tracing::error!(
"Failed to reserve port block starting at {start_port} (attempt {}): {e}",
attempt_idx + 1,
);
reservation_failed = true;
break;
}
reserved_keys.push(key);
}
if reservation_failed {
// Cleanup partial reservations
for key in reserved_keys {
if let Err(e) = etcd_client.kv_delete(key.as_str(), None).await {
tracing::warn!("Failed to cleanup reserved port {key}: {e}");
}
}
// Sockets automatically released via RAII
if attempt_idx < candidate_count - 1 {
tokio::time::sleep(Duration::from_millis(10)).await;
}
continue;
}
// Success - sockets will be released automatically
tracing::debug!("Reserved port block {ports_to_reserve:?}");
return Ok(ports_to_reserve);
}
Err(PyErr::new::<PyException, _>(format!(
"Failed to allocate and reserve a port block of size {block_size} from range {min}-{max} after {candidate_count} attempts")))
})
}
fn do_not_use_etcd_client(&self) -> PyResult<Option<EtcdClient>> { fn do_not_use_etcd_client(&self) -> PyResult<Option<EtcdClient>> {
match self.inner.etcd_client().clone() { match self.inner.etcd_client().clone() {
Some(etcd_client) => Ok(Some(EtcdClient { inner: etcd_client })), Some(etcd_client) => Ok(Some(EtcdClient { inner: etcd_client })),
...@@ -390,6 +524,33 @@ impl DistributedRuntime { ...@@ -390,6 +524,33 @@ impl DistributedRuntime {
} }
} }
// Bind a TCP port and return a socket held until dropped.
fn bind_tcp_port(port: u16) -> std::io::Result<socket2::Socket> {
let sock = socket2::Socket::new(
socket2::Domain::IPV4,
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)?;
sock.set_reuse_address(true)?;
let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port));
sock.bind(&addr.into())?;
Ok(sock)
}
/// Build the key under which a port reservation for `port` on the node `node_ip`
/// is stored, scoped by `namespace`. Currently always returns `Ok`.
fn make_port_key(namespace: &str, node_ip: IpAddr, port: u16) -> anyhow::Result<String> {
    let key = format!("dyn://{namespace}/ports/{node_ip}/{port}");
    Ok(key)
}
fn local_ip() -> Result<IpAddr, local_ip_address::Error> {
local_ip_address::local_ip().or_else(|err| match err {
local_ip_address::Error::LocalIpAddressNotFound => {
// Fall back to IPv6 if no IPv4 addresses are found
local_ip_address::local_ipv6()
}
_ => Err(err),
})
}
#[pymethods] #[pymethods]
impl EtcdKvCache { impl EtcdKvCache {
#[new] #[new]
......
...@@ -48,6 +48,13 @@ class DistributedRuntime: ...@@ -48,6 +48,13 @@ class DistributedRuntime:
""" """
... ...
async def allocate_port_block(
    self, namespace: str, port_min: int, port_max: int, block_size: int, context=None
) -> List[int]:
    """
    Allocate a contiguous block of ports from the specified range and atomically reserve them.
    Returns a list of all allocated ports in order.

    The call must be awaited.

    Args:
        namespace: Namespace the reservation keys are stored under.
        port_min: Lowest port of the range (inclusive).
        port_max: Highest port of the range (inclusive).
        block_size: Number of consecutive ports to allocate; must be at least 1.
        context: Optional string stored alongside each reservation.
    """
    ...
def shutdown(self) -> None: def shutdown(self) -> None:
""" """
Shutdown the runtime by triggering the cancellation token Shutdown the runtime by triggering the cancellation token
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
from dynamo._core import DistributedRuntime
# TODO: add support for launching etcd
# pytestmark = pytest.mark.pre_merge
async def test_simple_put_get():
    """Write keys through the runtime's etcd client and read them back by prefix.

    Needs a reachable etcd instance. The runtime is shut down even when an
    assertion fails, and the read-back checks fail if a written key is missing
    rather than passing on an empty result.
    """
    # Initialize runtime
    loop = asyncio.get_running_loop()
    runtime = DistributedRuntime(loop, False)

    try:
        # Get etcd client
        etcd = runtime.do_not_use_etcd_client()
        assert etcd is not None, "runtime has no etcd client"

        # Write some key-value pairs
        test_keys = {
            "test/key1": b"value1",
            "test/key2": b"value2",
            "test/nested/key3": b"value3",
        }

        # Write each key-value pair
        for key, value in test_keys.items():
            print(f"Writing {key} = {value!r}")
            await etcd.kv_create_or_validate(key, value, None)

        print("Successfully wrote all keys to etcd")

        # Test kv_put
        put_key = "test/put_key"
        put_value = b"put_value"
        test_keys[put_key] = put_value
        print(f"Using kv_put to write {put_key} = {put_value!r}")
        await etcd.kv_put(put_key, put_value, None)

        # Test kv_get_prefix to read all keys
        print("\nReading all keys with prefix 'test/':")
        keys_values = await etcd.kv_get_prefix("test/")
        for item in keys_values:
            print(f"Retrieved {item['key']} = {item['value']!r}")
            assert test_keys[item["key"]] == item["value"]

        # Every key written above must come back; an empty result is a failure
        missing = set(test_keys) - {item["key"] for item in keys_values}
        assert not missing, f"keys not returned for prefix 'test/': {sorted(missing)}"

        # Verify prefix filtering works
        print("\nReading keys with prefix 'test/nested/':")
        nested_keys_values = await etcd.kv_get_prefix("test/nested/")
        for item in nested_keys_values:
            print(f"Retrieved {item['key']} = {item['value']!r}")
            assert test_keys[item["key"]] == item["value"]

        nested_keys = {item["key"] for item in nested_keys_values}
        assert "test/nested/key3" in nested_keys
        # Keys outside the nested prefix must be filtered out
        assert all(key.startswith("test/nested/") for key in nested_keys)
    finally:
        # Shutdown runtime
        runtime.shutdown()
if __name__ == "__main__":
asyncio.run(test_simple_put_get())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment