Unverified commit ee3de9e4, authored by Alec, committed by GitHub
Browse files

feat: remove localhost fallback from NIXL side-channel host detection (#6539)


Signed-off-by: alec-flowers <aflowers@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
parent 149bf5a2
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import argparse import argparse
import ipaddress
import json import json
import logging import logging
import os import os
...@@ -459,32 +460,116 @@ def _reject_connector_flag(dynamo_config: Config) -> None: ...@@ -459,32 +460,116 @@ def _reject_connector_flag(dynamo_config: Config) -> None:
def get_host_ip() -> str:
    """Get a routable IP address of the host for NIXL side-channel coordination.

    Tries multiple strategies to find a usable (non-loopback, non-link-local) IP:

    1. Resolve the hostname (IPv4 and IPv6 candidates, in the order the
       resolver returns them).
    2. UDP connect trick (finds the default outbound interface IP; IPv4,
       then IPv6).

    On multi-NIC clusters (e.g. SLURM with InfiniBand), auto-detection picks
    the default egress interface which may not be correct. Set
    VLLM_NIXL_SIDE_CHANNEL_HOST explicitly in those environments.

    Returns:
        The first candidate address that passes ``_is_routable``.

    Raises:
        RuntimeError: If no usable IP can be determined. There is
            deliberately no fallback to 127.0.0.1.
    """
    # Strategy 1: hostname resolution.
    host_ip = _try_hostname_resolution()
    if host_ip and _is_routable(host_ip):
        logger.info(
            "NIXL side-channel host determined via hostname resolution: %s",
            host_ip,
        )
        return host_ip

    # Strategy 2: UDP connect trick — finds the IP of the interface
    # that would route to an external address (no data is sent).
    # Try IPv4 first, then IPv6.
    udp_probes = (
        (socket.AF_INET, ("8.8.8.8", 80), "IPv4"),
        (socket.AF_INET6, ("2001:4860:4860::8888", 80), "IPv6"),
    )
    for family, target, label in udp_probes:
        host_ip = _try_udp_connect(family, target)
        if host_ip and _is_routable(host_ip):
            logger.info(
                "NIXL side-channel host determined via outbound interface detection (%s): %s",
                label,
                host_ip,
            )
            return host_ip

    raise RuntimeError(
        "Unable to determine a routable host IP for NIXL side-channel. "
        "Hostname resolution and outbound interface detection both failed or "
        "returned a non-routable address (loopback, link-local, etc.). "
        "Please set the VLLM_NIXL_SIDE_CHANNEL_HOST environment variable to "
        "the IP address that peer nodes can reach this host on."
    )
def _is_routable(ip_str: str) -> bool:
"""Return True if the IP is usable for cross-node communication.
Rejects loopback (127.x / ::1), link-local (169.254.x / fe80::),
unspecified (0.0.0.0 / ::), and multicast addresses.
RFC1918 private addresses (10.x, 172.16-31.x, 192.168.x) are allowed.
"""
try:
addr = ipaddress.ip_address(ip_str)
return not (
addr.is_loopback
or addr.is_link_local
or addr.is_unspecified
or addr.is_multicast
) )
return "127.0.0.1" except ValueError:
return False
def _try_hostname_resolution() -> str | None:
    """Resolve the local hostname to a routable, bindable IP.

    Looks the hostname up with getaddrinfo using AF_UNSPEC, so both IPv4 and
    IPv6 candidates are considered. A candidate is accepted only if it
    passes ``_is_routable`` and a throwaway socket can be bound to it.

    Returns:
        The first accepted address, or None when no candidate qualifies or
        the lookup itself fails.
    """
    try:
        candidates = socket.getaddrinfo(
            socket.gethostname(), None, socket.AF_UNSPEC, socket.SOCK_STREAM
        )
        for family, socktype, _proto, _canonname, sockaddr in candidates:
            address = sockaddr[0]
            if _is_routable(address):
                try:
                    # Binding to port 0 checks that the address belongs to a
                    # local interface.
                    with socket.socket(family, socktype) as probe:
                        probe.bind((address, 0))
                except OSError:
                    continue
                return address
    except OSError as exc:
        logger.debug("Hostname resolution failed: %s", exc)
    return None
def _try_udp_connect(family: socket.AddressFamily, target: tuple) -> str | None:
    """Find the outbound interface IP by "connecting" a UDP socket.

    Args:
        family: socket.AF_INET or socket.AF_INET6
        target: (address, port) tuple to "connect" to (no data is sent)

    Returns:
        The local address the socket reports after the connect, or None if
        any socket call raises OSError.
    """
    try:
        with socket.socket(family, socket.SOCK_DGRAM) as probe:
            probe.connect(target)
            local_address = probe.getsockname()[0]
    except OSError as exc:
        logger.debug("UDP connect detection failed (family=%s): %s", family, exc)
        return None
    return local_address
def ensure_side_channel_host(): def ensure_side_channel_host():
...@@ -492,11 +577,9 @@ def ensure_side_channel_host(): ...@@ -492,11 +577,9 @@ def ensure_side_channel_host():
existing_host = os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST") existing_host = os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST")
if existing_host: if existing_host:
logger.debug( logger.info("Using existing VLLM_NIXL_SIDE_CHANNEL_HOST=%s", existing_host)
"Preserving existing VLLM_NIXL_SIDE_CHANNEL_HOST=%s", existing_host
)
return return
host_ip = get_host_ip() host_ip = get_host_ip()
os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = host_ip os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = host_ip
logger.debug("Set VLLM_NIXL_SIDE_CHANNEL_HOST to %s", host_ip) logger.info("Set VLLM_NIXL_SIDE_CHANNEL_HOST to %s (auto-detected)", host_ip)
...@@ -5,16 +5,21 @@ ...@@ -5,16 +5,21 @@
import json import json
import re import re
import socket
import warnings import warnings
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import patch
import pytest import pytest
from dynamo.vllm.args import ( from dynamo.vllm.args import (
_connector_to_kv_transfer_json, _connector_to_kv_transfer_json,
_is_routable,
_uses_dynamo_connector, _uses_dynamo_connector,
_uses_nixl_connector, _uses_nixl_connector,
ensure_side_channel_host,
get_host_ip,
parse_args, parse_args,
) )
from dynamo.vllm.constants import DisaggregationMode from dynamo.vllm.constants import DisaggregationMode
...@@ -422,3 +427,134 @@ def test_explicit_default_mode_with_legacy_flag_raises(mock_vllm_cli): ...@@ -422,3 +427,134 @@ def test_explicit_default_mode_with_legacy_flag_raises(mock_vllm_cli):
) )
with pytest.raises(ValueError, match="Cannot combine"): with pytest.raises(ValueError, match="Cannot combine"):
parse_args() parse_args()
# --- _is_routable tests (pure logic, no mocking) ---
class TestIsRoutable:
    """Pure-logic checks for ``_is_routable``; no sockets or mocks involved."""

    def test_accepts_private_ipv4(self):
        for address in ("10.0.0.5", "192.168.1.1"):
            assert _is_routable(address) is True

    def test_accepts_private_ipv6(self):
        unique_local = "fd00::1"
        assert _is_routable(unique_local) is True

    def test_rejects_loopback_v4(self):
        loopback = "127.0.0.1"
        assert _is_routable(loopback) is False

    def test_rejects_loopback_v6(self):
        loopback = "::1"
        assert _is_routable(loopback) is False

    def test_rejects_link_local_v4(self):
        link_local = "169.254.1.1"
        assert _is_routable(link_local) is False

    def test_rejects_link_local_v6(self):
        link_local = "fe80::1"
        assert _is_routable(link_local) is False

    def test_rejects_unspecified(self):
        for address in ("0.0.0.0", "::"):
            assert _is_routable(address) is False

    def test_rejects_multicast(self):
        multicast = "224.0.0.1"
        assert _is_routable(multicast) is False

    def test_rejects_invalid(self):
        not_an_address = "not-an-ip"
        assert _is_routable(not_an_address) is False
# --- get_host_ip tests (mock socket module functions) ---
class TestGetHostIp:
    """Strategy-ordering checks for ``get_host_ip`` with both helpers patched."""

    @staticmethod
    def _udp_answering_for(routed_family, address):
        """Build a fake UDP probe that yields ``address`` for one family only."""

        def fake_udp_connect(family, target):
            return address if family == routed_family else None

        return fake_udp_connect

    def test_hostname_resolution_success(self):
        """Hostname helper yields a routable IPv4 address → it is returned."""
        with patch(
            "dynamo.vllm.args._try_hostname_resolution", return_value="10.0.0.5"
        ):
            detected = get_host_ip()
        assert detected == "10.0.0.5"

    def test_hostname_loopback_falls_through_to_udp(self):
        """Hostname helper yields 127.0.0.1, IPv4 UDP probe yields 10.0.0.5."""
        fake_udp = self._udp_answering_for(socket.AF_INET, "10.0.0.5")
        with patch(
            "dynamo.vllm.args._try_hostname_resolution", return_value="127.0.0.1"
        ), patch("dynamo.vllm.args._try_udp_connect", side_effect=fake_udp):
            detected = get_host_ip()
        assert detected == "10.0.0.5"

    def test_hostname_link_local_falls_through_to_udp(self):
        """Hostname helper yields 169.254.1.1, IPv4 UDP probe yields 10.0.0.5."""
        fake_udp = self._udp_answering_for(socket.AF_INET, "10.0.0.5")
        with patch(
            "dynamo.vllm.args._try_hostname_resolution", return_value="169.254.1.1"
        ), patch("dynamo.vllm.args._try_udp_connect", side_effect=fake_udp):
            detected = get_host_ip()
        assert detected == "10.0.0.5"

    def test_ipv6_fallback(self):
        """Hostname and IPv4 probe yield nothing, IPv6 UDP probe yields fd00::1."""
        fake_udp = self._udp_answering_for(socket.AF_INET6, "fd00::1")
        with patch(
            "dynamo.vllm.args._try_hostname_resolution", return_value=None
        ), patch("dynamo.vllm.args._try_udp_connect", side_effect=fake_udp):
            detected = get_host_ip()
        assert detected == "fd00::1"

    def test_all_fail_raises_runtime_error(self):
        """Every strategy yields None → RuntimeError naming the env var."""
        with patch(
            "dynamo.vllm.args._try_hostname_resolution", return_value=None
        ), patch("dynamo.vllm.args._try_udp_connect", return_value=None):
            with pytest.raises(RuntimeError, match="VLLM_NIXL_SIDE_CHANNEL_HOST"):
                get_host_ip()
# --- ensure_side_channel_host tests ---
class TestEnsureSideChannelHost:
    """Checks for ``ensure_side_channel_host`` with ``get_host_ip`` patched."""

    ENV_VAR = "VLLM_NIXL_SIDE_CHANNEL_HOST"

    @classmethod
    def _unset_env(cls, monkeypatch):
        """Remove the env var so that monkeypatch restores it at teardown.

        ``monkeypatch.delenv(..., raising=False)`` records nothing when the
        variable is already absent, so a value written by the code under
        test would leak into later tests. Setting it first makes
        monkeypatch remember the original state, absent or not.
        """
        monkeypatch.setenv(cls.ENV_VAR, "placeholder")
        monkeypatch.delenv(cls.ENV_VAR)

    def test_preserves_existing_env_var(self, monkeypatch):
        """Pre-set env var → verify not overwritten."""
        import os

        monkeypatch.setenv(self.ENV_VAR, "192.168.99.99")
        with patch("dynamo.vllm.args.get_host_ip") as mock_get:
            ensure_side_channel_host()
            mock_get.assert_not_called()
        assert os.environ[self.ENV_VAR] == "192.168.99.99"

    def test_sets_env_var_on_successful_detection(self, monkeypatch):
        """No env var set, successful detection populates the side-channel host."""
        import os

        self._unset_env(monkeypatch)
        with patch("dynamo.vllm.args.get_host_ip", return_value="10.0.0.5"):
            ensure_side_channel_host()
        assert os.environ[self.ENV_VAR] == "10.0.0.5"

    def test_raises_when_detection_fails_and_no_env(self, monkeypatch):
        """Detection raises and no env var is set → RuntimeError propagates."""
        self._unset_env(monkeypatch)
        with patch(
            "dynamo.vllm.args.get_host_ip",
            side_effect=RuntimeError("Unable to determine"),
        ):
            with pytest.raises(RuntimeError, match="Unable to determine"):
                ensure_side_channel_host()
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import argparse import argparse
import ipaddress
import logging import logging
import os import os
import socket import socket
...@@ -116,24 +117,84 @@ def get_kv_port() -> int: ...@@ -116,24 +117,84 @@ def get_kv_port() -> int:
def ensure_side_channel_host():
    """Ensure the NIXL side-channel host is available without overriding user settings.

    If VLLM_NIXL_SIDE_CHANNEL_HOST is already set to a non-empty value it is
    left untouched. Otherwise the host IP is auto-detected, first by
    hostname resolution and then by a UDP connect fallback. Both steps
    support IPv4 and IPv6. The result is written to the environment
    variable.

    Raises:
        RuntimeError: If no routable IP can be determined. There is
            deliberately no fallback to 127.0.0.1.
    """
    existing_host = os.getenv("VLLM_NIXL_SIDE_CHANNEL_HOST")
    if existing_host:
        logger.info("Using existing VLLM_NIXL_SIDE_CHANNEL_HOST=%s", existing_host)
        return

    def is_routable(ip_str: str) -> bool:
        """Reject loopback, link-local, unspecified, multicast and invalid IPs."""
        try:
            addr = ipaddress.ip_address(ip_str)
        except ValueError:
            return False
        # Judge IPv4-mapped IPv6 (::ffff:a.b.c.d) by the embedded IPv4
        # address; only IPv6Address exposes ``ipv4_mapped``.
        mapped = getattr(addr, "ipv4_mapped", None)
        if mapped is not None:
            addr = mapped
        return not (
            addr.is_loopback
            or addr.is_link_local
            or addr.is_unspecified
            or addr.is_multicast
        )

    # Strategy 1: hostname resolution (AF_UNSPEC for IPv4+IPv6)
    host_ip = None
    detection_method = None
    try:
        host_name = socket.gethostname()
        infos = socket.getaddrinfo(
            host_name, None, socket.AF_UNSPEC, socket.SOCK_STREAM
        )
        for family, socktype, _, _, sockaddr in infos:
            candidate = sockaddr[0]
            # Filter before binding so no socket is opened for an address
            # that would be rejected anyway.
            if not is_routable(candidate):
                continue
            try:
                # Binding to port 0 checks the address is local to this host.
                with socket.socket(family, socktype) as s:
                    s.bind((candidate, 0))
            except OSError:
                continue
            host_ip = candidate
            detection_method = "hostname resolution"
            break
    except OSError as exc:
        logger.debug("Hostname resolution failed: %s", exc)

    # Strategy 2: UDP connect trick (IPv4 then IPv6); no data is sent.
    if not host_ip:
        for family, target, label in [
            (socket.AF_INET, ("8.8.8.8", 80), "outbound interface detection (IPv4)"),
            (
                socket.AF_INET6,
                ("2001:4860:4860::8888", 80),
                "outbound interface detection (IPv6)",
            ),
        ]:
            try:
                with socket.socket(family, socket.SOCK_DGRAM) as s:
                    s.connect(target)
                    candidate = s.getsockname()[0]
            except OSError:
                continue
            if is_routable(candidate):
                host_ip = candidate
                detection_method = label
                break

    if not host_ip:
        raise RuntimeError(
            "Unable to determine a routable host IP for NIXL side-channel. "
            "Please set the VLLM_NIXL_SIDE_CHANNEL_HOST environment variable to "
            "the IP address that peer nodes can reach this host on."
        )

    os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = host_ip
    logger.info(
        "Set VLLM_NIXL_SIDE_CHANNEL_HOST=%s (detected via %s)",
        host_ip,
        detection_method,
    )
def configure_ports(config: Config): def configure_ports(config: Config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment