Unverified Commit 56b991b1 authored by Jimmy's avatar Jimmy Committed by GitHub
Browse files

[Feature]feat(get_ip): unify get_ip_xxx (#10081)

parent 780d6a22
......@@ -13,7 +13,7 @@ from sglang.srt.disaggregation.mooncake.conn import (
MooncakeKVReceiver,
MooncakeKVSender,
)
from sglang.srt.utils import get_local_ip_by_remote
from sglang.srt.utils import get_local_ip_auto
logger = logging.getLogger(__name__)
......@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class AscendKVManager(MooncakeKVManager):
def init_engine(self):
# TransferEngine initialized on ascend.
local_ip = get_local_ip_by_remote()
local_ip = get_local_ip_auto()
self.engine = AscendTransferEngine(
hostname=local_ip,
npu_id=self.kv_args.gpu_id,
......
......@@ -18,7 +18,7 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
from sglang.srt.utils import (
format_tcp_address,
get_ip,
get_local_ip_auto,
get_open_port,
is_valid_ipv6_address,
)
......@@ -191,7 +191,9 @@ class MessageQueue:
self.n_remote_reader = n_remote_reader
if connect_ip is None:
connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
connect_ip = (
get_local_ip_auto("0.0.0.0") if n_remote_reader > 0 else "127.0.0.1"
)
context = Context()
......
......@@ -2005,48 +2005,11 @@ def set_uvicorn_logging_configs():
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
def get_ip() -> str:
# SGLANG_HOST_IP env can be ignore
def get_ip() -> Optional[str]:
host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
if host_ip:
return host_ip
# IP is not set, try to get it from the network interface
# try ipv4
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
pass
# try ipv6
try:
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# Google's public DNS server, see
# https://developers.google.com/speed/public-dns/docs/using#addresses
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
pass
# try using hostname
hostname = socket.gethostname()
try:
ip_addr = socket.gethostbyname(hostname)
warnings.warn("using local ip address: {}".format(ip_addr))
return ip_addr
except Exception:
pass
warnings.warn(
"Failed to get the IP address, using 0.0.0.0 by default."
"The value can be set by the environment variable"
" SGLANG_HOST_IP or HOST_IP.",
stacklevel=2,
)
return "0.0.0.0"
return None
def get_open_port() -> int:
......@@ -2305,16 +2268,9 @@ def bind_or_assign(target, source):
return source
def get_local_ip_auto() -> str:
interface = os.environ.get("SGLANG_LOCAL_IP_NIC", None)
return (
get_local_ip_by_nic(interface)
if interface is not None
else get_local_ip_by_remote()
)
def get_local_ip_by_nic(interface: str) -> str:
def get_local_ip_by_nic(interface: str = None) -> Optional[str]:
if not (interface := interface or os.environ.get("SGLANG_LOCAL_IP_NIC", None)):
return None
try:
import netifaces
except ImportError as e:
......@@ -2335,15 +2291,13 @@ def get_local_ip_by_nic(interface: str) -> str:
if ip and not ip.startswith("fe80::") and ip != "::1":
return ip.split("%")[0]
except (ValueError, OSError) as e:
raise ValueError(
"Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
logger.warning(
f"{e} Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
)
# Fallback
return get_local_ip_by_remote()
return None
def get_local_ip_by_remote() -> str:
def get_local_ip_by_remote() -> Optional[str]:
# try ipv4
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
......@@ -2368,7 +2322,49 @@ def get_local_ip_by_remote() -> str:
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
raise ValueError("Can not get local ip")
logger.warning("Can not get local ip by remote")
return None
def get_local_ip_auto(fallback: str = None) -> str:
"""
Automatically detect the local IP address using multiple fallback strategies.
This function attempts to obtain the local IP address through several methods.
If all methods fail, it returns the specified fallback value or raises an exception.
Args:
fallback (str, optional): Fallback IP address to return if all detection
methods fail. For server applications, explicitly set this to
"0.0.0.0" (IPv4) or "::" (IPv6) to bind to all available interfaces.
Defaults to None.
Returns:
str: The detected local IP address, or the fallback value if detection fails.
Raises:
ValueError: If IP detection fails and no fallback value is provided.
Note:
The function tries detection methods in the following order:
1. Direct IP detection via get_ip()
2. Network interface enumeration via get_local_ip_by_nic()
3. Remote connection method via get_local_ip_by_remote()
"""
if ip := get_ip():
return ip
logger.debug("get_ip failed")
# Fallback
if ip := get_local_ip_by_nic():
return ip
logger.debug("get_local_ip_by_nic failed")
# Fallback
if ip := get_local_ip_by_remote():
return ip
logger.debug("get_local_ip_by_remote failed")
if fallback:
return fallback
raise ValueError("Can not get local ip")
def is_page_size_one(server_args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment