Unverified Commit ee3de9e4 authored by Alec's avatar Alec Committed by GitHub
Browse files

feat: remove localhost fallback from NIXL side-channel host detection (#6539)


Signed-off-by: alec-flowers <aflowers@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
parent 149bf5a2
......@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import ipaddress
import json
import logging
import os
......@@ -459,32 +460,116 @@ def _reject_connector_flag(dynamo_config: Config) -> None:
def get_host_ip() -> str:
    """Get a routable IP address of the host for NIXL side-channel coordination.

    Tries the following strategies in order and returns the first address
    that is usable for cross-node traffic (not loopback, link-local,
    unspecified, or multicast):

    1. Hostname resolution (IPv4 and IPv6 candidates, in resolver order).
    2. UDP connect trick against a public IPv4 address, then a public IPv6
       address, to discover the default outbound interface IP.

    There is intentionally no fallback to 127.0.0.1: a loopback address
    would be advertised to peers that cannot reach it.

    On multi-NIC clusters (e.g. SLURM with InfiniBand), auto-detection picks
    the default egress interface which may not be correct. Set
    VLLM_NIXL_SIDE_CHANNEL_HOST explicitly in those environments.

    Returns:
        The detected IP address as a string.

    Raises:
        RuntimeError: If no usable IP can be determined.
    """
    # Strategy 1: hostname resolution.
    host_ip = _try_hostname_resolution()
    if host_ip and _is_routable(host_ip):
        logger.info(
            "NIXL side-channel host determined via hostname resolution: %s",
            host_ip,
        )
        return host_ip

    # Strategy 2: UDP connect trick — finds the IP of the interface
    # that would route to an external address (no data is sent).
    # Try IPv4 first, then IPv6.
    host_ip = _try_udp_connect(socket.AF_INET, ("8.8.8.8", 80))
    if host_ip and _is_routable(host_ip):
        logger.info(
            "NIXL side-channel host determined via outbound interface detection (IPv4): %s",
            host_ip,
        )
        return host_ip

    host_ip = _try_udp_connect(socket.AF_INET6, ("2001:4860:4860::8888", 80))
    if host_ip and _is_routable(host_ip):
        logger.info(
            "NIXL side-channel host determined via outbound interface detection (IPv6): %s",
            host_ip,
        )
        return host_ip

    raise RuntimeError(
        "Unable to determine a routable host IP for NIXL side-channel. "
        "Hostname resolution and outbound interface detection both failed or "
        "returned a non-routable address (loopback, link-local, etc.). "
        "Please set the VLLM_NIXL_SIDE_CHANNEL_HOST environment variable to "
        "the IP address that peer nodes can reach this host on."
    )
def _is_routable(ip_str: str) -> bool:
"""Return True if the IP is usable for cross-node communication.
Rejects loopback (127.x / ::1), link-local (169.254.x / fe80::),
unspecified (0.0.0.0 / ::), and multicast addresses.
RFC1918 private addresses (10.x, 172.16-31.x, 192.168.x) are allowed.
"""
try:
addr = ipaddress.ip_address(ip_str)
return not (
addr.is_loopback
or addr.is_link_local
or addr.is_unspecified
or addr.is_multicast
)
return "127.0.0.1"
except ValueError:
return False
def _try_hostname_resolution() -> str | None:
    """Look up this machine's hostname and pick an address peers could use.

    The lookup goes through ``getaddrinfo`` with ``AF_UNSPEC`` so both IPv4
    and IPv6 results are considered, in the order the resolver returns them.
    A candidate is accepted only if it passes ``_is_routable`` and a socket
    of the matching family can be bound to it on an ephemeral port.

    Returns:
        The first accepted address, or None when resolution fails or no
        candidate qualifies.
    """
    try:
        candidates = socket.getaddrinfo(
            socket.gethostname(), None, socket.AF_UNSPEC, socket.SOCK_STREAM
        )
    except OSError as exc:
        logger.debug("Hostname resolution failed: %s", exc)
        return None

    for family, socktype, _proto, _canonname, sockaddr in candidates:
        address = sockaddr[0]
        if _is_routable(address):
            try:
                # Binding to port 0 proves the address belongs to a local
                # interface without occupying a fixed port.
                with socket.socket(family, socktype) as probe:
                    probe.bind((address, 0))
            except OSError:
                continue
            return address
    return None
def _try_udp_connect(family: socket.AddressFamily, target: tuple) -> str | None:
    """Discover the local address the OS would use to reach ``target``.

    A datagram socket is "connected" to the target and the local address
    the socket ends up with is read back.

    Args:
        family: socket.AF_INET or socket.AF_INET6
        target: (address, port) tuple to "connect" to (no data is sent)

    Returns:
        The local IP address as a string, or None if any socket call
        raises OSError.
    """
    try:
        with socket.socket(family, socket.SOCK_DGRAM) as probe:
            probe.connect(target)
            local_endpoint = probe.getsockname()
    except OSError as exc:
        logger.debug("UDP connect detection failed (family=%s): %s", family, exc)
        return None
    return local_endpoint[0]
def ensure_side_channel_host():
......@@ -492,11 +577,9 @@ def ensure_side_channel_host():
existing_host = os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST")
if existing_host:
logger.debug(
"Preserving existing VLLM_NIXL_SIDE_CHANNEL_HOST=%s", existing_host
)
logger.info("Using existing VLLM_NIXL_SIDE_CHANNEL_HOST=%s", existing_host)
return
host_ip = get_host_ip()
os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = host_ip
logger.debug("Set VLLM_NIXL_SIDE_CHANNEL_HOST to %s", host_ip)
logger.info("Set VLLM_NIXL_SIDE_CHANNEL_HOST to %s (auto-detected)", host_ip)
......@@ -5,16 +5,21 @@
import json
import re
import socket
import warnings
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from dynamo.vllm.args import (
_connector_to_kv_transfer_json,
_is_routable,
_uses_dynamo_connector,
_uses_nixl_connector,
ensure_side_channel_host,
get_host_ip,
parse_args,
)
from dynamo.vllm.constants import DisaggregationMode
......@@ -422,3 +427,134 @@ def test_explicit_default_mode_with_legacy_flag_raises(mock_vllm_cli):
)
with pytest.raises(ValueError, match="Cannot combine"):
parse_args()
# --- _is_routable tests (pure logic, no mocking) ---
class TestIsRoutable:
    """Checks ``_is_routable`` against literal addresses; nothing is mocked."""

    def test_accepts_private_ipv4(self):
        for address in ("10.0.0.5", "192.168.1.1"):
            assert _is_routable(address) is True

    def test_accepts_private_ipv6(self):
        verdict = _is_routable("fd00::1")
        assert verdict is True

    def test_rejects_loopback_v4(self):
        verdict = _is_routable("127.0.0.1")
        assert verdict is False

    def test_rejects_loopback_v6(self):
        verdict = _is_routable("::1")
        assert verdict is False

    def test_rejects_link_local_v4(self):
        verdict = _is_routable("169.254.1.1")
        assert verdict is False

    def test_rejects_link_local_v6(self):
        verdict = _is_routable("fe80::1")
        assert verdict is False

    def test_rejects_unspecified(self):
        for address in ("0.0.0.0", "::"):
            assert _is_routable(address) is False

    def test_rejects_multicast(self):
        verdict = _is_routable("224.0.0.1")
        assert verdict is False

    def test_rejects_invalid(self):
        verdict = _is_routable("not-an-ip")
        assert verdict is False
# --- get_host_ip tests (mock socket module functions) ---
class TestGetHostIp:
    """Exercises the strategy ordering of ``get_host_ip`` with both helpers patched."""

    @staticmethod
    def _udp_answering(only_family, address):
        """Build a ``_try_udp_connect`` stand-in that succeeds for one family only."""

        def fake_udp_connect(family, target):
            if family == only_family:
                return address
            return None

        return fake_udp_connect

    def test_hostname_resolution_success(self):
        """Hostname helper yields a routable IPv4 address → it is returned."""
        hostname_patch = patch(
            "dynamo.vllm.args._try_hostname_resolution", return_value="10.0.0.5"
        )
        with hostname_patch:
            detected = get_host_ip()
        assert detected == "10.0.0.5"

    def test_hostname_loopback_falls_through_to_udp(self):
        """Hostname helper yields 127.0.0.1, IPv4 UDP probe yields 10.0.0.5 → 10.0.0.5."""
        with patch(
            "dynamo.vllm.args._try_hostname_resolution", return_value="127.0.0.1"
        ), patch(
            "dynamo.vllm.args._try_udp_connect",
            side_effect=self._udp_answering(socket.AF_INET, "10.0.0.5"),
        ):
            detected = get_host_ip()
        assert detected == "10.0.0.5"

    def test_hostname_link_local_falls_through_to_udp(self):
        """Hostname helper yields 169.254.1.1, IPv4 UDP probe yields 10.0.0.5 → 10.0.0.5."""
        with patch(
            "dynamo.vllm.args._try_hostname_resolution", return_value="169.254.1.1"
        ), patch(
            "dynamo.vllm.args._try_udp_connect",
            side_effect=self._udp_answering(socket.AF_INET, "10.0.0.5"),
        ):
            detected = get_host_ip()
        assert detected == "10.0.0.5"

    def test_ipv6_fallback(self):
        """Hostname helper and IPv4 probe yield nothing, IPv6 probe yields fd00::1 → fd00::1."""
        with patch(
            "dynamo.vllm.args._try_hostname_resolution", return_value=None
        ), patch(
            "dynamo.vllm.args._try_udp_connect",
            side_effect=self._udp_answering(socket.AF_INET6, "fd00::1"),
        ):
            detected = get_host_ip()
        assert detected == "fd00::1"

    def test_all_fail_raises_runtime_error(self):
        """Every strategy yields nothing → RuntimeError naming VLLM_NIXL_SIDE_CHANNEL_HOST."""
        with patch(
            "dynamo.vllm.args._try_hostname_resolution", return_value=None
        ), patch("dynamo.vllm.args._try_udp_connect", return_value=None), pytest.raises(
            RuntimeError, match="VLLM_NIXL_SIDE_CHANNEL_HOST"
        ):
            get_host_ip()
# --- ensure_side_channel_host tests ---
class TestEnsureSideChannelHost:
    """Tests for ``ensure_side_channel_host`` with ``get_host_ip`` patched out."""

    def test_preserves_existing_env_var(self, monkeypatch):
        """Pre-set env var → detection is skipped and the value is not overwritten."""
        import os

        monkeypatch.setenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "192.168.99.99")
        with patch("dynamo.vllm.args.get_host_ip") as mock_get:
            ensure_side_channel_host()
            mock_get.assert_not_called()
        assert os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] == "192.168.99.99"

    def test_sets_env_var_on_successful_detection(self, monkeypatch):
        """No env var set, successful detection populates the side-channel host."""
        import os

        # ensure_side_channel_host() writes os.environ directly. delenv()
        # alone records nothing to undo when the variable is absent, so the
        # detected value would leak into later tests. Setting a placeholder
        # first makes monkeypatch track the variable and restore the original
        # state (set or unset) at teardown.
        monkeypatch.setenv("VLLM_NIXL_SIDE_CHANNEL_HOST", "placeholder")
        monkeypatch.delenv("VLLM_NIXL_SIDE_CHANNEL_HOST")
        with patch("dynamo.vllm.args.get_host_ip", return_value="10.0.0.5"):
            ensure_side_channel_host()
        assert os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] == "10.0.0.5"

    def test_raises_when_detection_fails_and_no_env(self, monkeypatch):
        """Detection fails, no env var → RuntimeError and the variable stays unset."""
        import os

        monkeypatch.delenv("VLLM_NIXL_SIDE_CHANNEL_HOST", raising=False)
        with patch(
            "dynamo.vllm.args.get_host_ip",
            side_effect=RuntimeError("Unable to determine"),
        ):
            with pytest.raises(RuntimeError, match="Unable to determine"):
                ensure_side_channel_host()
        assert "VLLM_NIXL_SIDE_CHANNEL_HOST" not in os.environ
......@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import ipaddress
import logging
import os
import socket
......@@ -116,24 +117,84 @@ def get_kv_port() -> int:
def ensure_side_channel_host():
    """Ensure the NIXL side-channel host is available without overriding user settings.

    If VLLM_NIXL_SIDE_CHANNEL_HOST is already set to a non-empty value it is
    left untouched. Otherwise the host IP is auto-detected and written to
    that environment variable:

    1. Hostname resolution (IPv4 and IPv6 candidates, in resolver order),
       accepting the first address that is routable and bindable.
    2. UDP connect trick (IPv4, then IPv6) to find the default outbound
       interface IP.

    There is no fallback to 127.0.0.1.

    Raises:
        RuntimeError: If no routable IP can be determined.
    """
    existing_host = os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST")
    if existing_host:
        logger.info("Using existing VLLM_NIXL_SIDE_CHANNEL_HOST=%s", existing_host)
        return

    def is_routable(ip_str: str) -> bool:
        """Return False for loopback, link-local, unspecified, multicast or unparsable IPs."""
        try:
            addr = ipaddress.ip_address(ip_str)
        except ValueError:
            return False
        return not (
            addr.is_loopback
            or addr.is_link_local
            or addr.is_unspecified
            or addr.is_multicast
        )

    # Strategy 1: hostname resolution (AF_UNSPEC for IPv4+IPv6)
    host_ip = None
    detection_method = None
    try:
        host_name = socket.gethostname()
        infos = socket.getaddrinfo(
            host_name, None, socket.AF_UNSPEC, socket.SOCK_STREAM
        )
        for family, socktype, _, _, sockaddr in infos:
            candidate = sockaddr[0]
            # Cheap classification first; only open a socket for addresses
            # that could actually be advertised to peers.
            if not is_routable(candidate):
                continue
            try:
                # Binding to port 0 confirms the address is local and usable.
                with socket.socket(family, socktype) as s:
                    s.bind((candidate, 0))
            except OSError:
                continue
            host_ip = candidate
            detection_method = "hostname resolution"
            break
    except OSError as exc:
        logger.debug("Hostname resolution failed: %s", exc)

    # Strategy 2: UDP connect trick (IPv4 then IPv6)
    if not host_ip:
        for family, target, label in [
            (socket.AF_INET, ("8.8.8.8", 80), "outbound interface detection (IPv4)"),
            (
                socket.AF_INET6,
                ("2001:4860:4860::8888", 80),
                "outbound interface detection (IPv6)",
            ),
        ]:
            try:
                with socket.socket(family, socket.SOCK_DGRAM) as s:
                    s.connect(target)
                    candidate = s.getsockname()[0]
            except OSError:
                continue
            if is_routable(candidate):
                host_ip = candidate
                detection_method = label
                break

    if not host_ip:
        raise RuntimeError(
            "Unable to determine a routable host IP for NIXL side-channel. "
            "Please set the VLLM_NIXL_SIDE_CHANNEL_HOST environment variable to "
            "the IP address that peer nodes can reach this host on."
        )

    os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = host_ip
    logger.info(
        "Set VLLM_NIXL_SIDE_CHANNEL_HOST=%s (detected via %s)",
        host_ip,
        detection_method,
    )
def configure_ports(config: Config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment