Unverified Commit 56b991b1 authored by Jimmy's avatar Jimmy Committed by GitHub
Browse files

[Feature]feat(get_ip): unify get_ip_xxx (#10081)

parent 780d6a22
...@@ -13,7 +13,7 @@ from sglang.srt.disaggregation.mooncake.conn import ( ...@@ -13,7 +13,7 @@ from sglang.srt.disaggregation.mooncake.conn import (
MooncakeKVReceiver, MooncakeKVReceiver,
MooncakeKVSender, MooncakeKVSender,
) )
from sglang.srt.utils import get_local_ip_by_remote from sglang.srt.utils import get_local_ip_auto
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) ...@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class AscendKVManager(MooncakeKVManager): class AscendKVManager(MooncakeKVManager):
def init_engine(self): def init_engine(self):
# TransferEngine initialized on ascend. # TransferEngine initialized on ascend.
local_ip = get_local_ip_by_remote() local_ip = get_local_ip_auto()
self.engine = AscendTransferEngine( self.engine = AscendTransferEngine(
hostname=local_ip, hostname=local_ip,
npu_id=self.kv_args.gpu_id, npu_id=self.kv_args.gpu_id,
......
...@@ -18,7 +18,7 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore ...@@ -18,7 +18,7 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
from sglang.srt.utils import ( from sglang.srt.utils import (
format_tcp_address, format_tcp_address,
get_ip, get_local_ip_auto,
get_open_port, get_open_port,
is_valid_ipv6_address, is_valid_ipv6_address,
) )
...@@ -191,7 +191,9 @@ class MessageQueue: ...@@ -191,7 +191,9 @@ class MessageQueue:
self.n_remote_reader = n_remote_reader self.n_remote_reader = n_remote_reader
if connect_ip is None: if connect_ip is None:
connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1" connect_ip = (
get_local_ip_auto("0.0.0.0") if n_remote_reader > 0 else "127.0.0.1"
)
context = Context() context = Context()
......
...@@ -2005,48 +2005,11 @@ def set_uvicorn_logging_configs(): ...@@ -2005,48 +2005,11 @@ def set_uvicorn_logging_configs():
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S" LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
def get_ip() -> str: def get_ip() -> Optional[str]:
# SGLANG_HOST_IP env can be ignore
host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "") host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
if host_ip: if host_ip:
return host_ip return host_ip
return None
# IP is not set, try to get it from the network interface
# try ipv4
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
pass
# try ipv6
try:
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# Google's public DNS server, see
# https://developers.google.com/speed/public-dns/docs/using#addresses
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
except Exception:
pass
# try using hostname
hostname = socket.gethostname()
try:
ip_addr = socket.gethostbyname(hostname)
warnings.warn("using local ip address: {}".format(ip_addr))
return ip_addr
except Exception:
pass
warnings.warn(
"Failed to get the IP address, using 0.0.0.0 by default."
"The value can be set by the environment variable"
" SGLANG_HOST_IP or HOST_IP.",
stacklevel=2,
)
return "0.0.0.0"
def get_open_port() -> int: def get_open_port() -> int:
...@@ -2305,16 +2268,9 @@ def bind_or_assign(target, source): ...@@ -2305,16 +2268,9 @@ def bind_or_assign(target, source):
return source return source
def get_local_ip_auto() -> str: def get_local_ip_by_nic(interface: str = None) -> Optional[str]:
interface = os.environ.get("SGLANG_LOCAL_IP_NIC", None) if not (interface := interface or os.environ.get("SGLANG_LOCAL_IP_NIC", None)):
return ( return None
get_local_ip_by_nic(interface)
if interface is not None
else get_local_ip_by_remote()
)
def get_local_ip_by_nic(interface: str) -> str:
try: try:
import netifaces import netifaces
except ImportError as e: except ImportError as e:
...@@ -2335,15 +2291,13 @@ def get_local_ip_by_nic(interface: str) -> str: ...@@ -2335,15 +2291,13 @@ def get_local_ip_by_nic(interface: str) -> str:
if ip and not ip.startswith("fe80::") and ip != "::1": if ip and not ip.startswith("fe80::") and ip != "::1":
return ip.split("%")[0] return ip.split("%")[0]
except (ValueError, OSError) as e: except (ValueError, OSError) as e:
raise ValueError( logger.warning(
"Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly." f"{e} Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
) )
return None
# Fallback
return get_local_ip_by_remote()
def get_local_ip_by_remote() -> str: def get_local_ip_by_remote() -> Optional[str]:
# try ipv4 # try ipv4
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try: try:
...@@ -2368,7 +2322,49 @@ def get_local_ip_by_remote() -> str: ...@@ -2368,7 +2322,49 @@ def get_local_ip_by_remote() -> str:
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
return s.getsockname()[0] return s.getsockname()[0]
except Exception: except Exception:
raise ValueError("Can not get local ip") logger.warning("Can not get local ip by remote")
return None
def get_local_ip_auto(fallback: str = None) -> str:
"""
Automatically detect the local IP address using multiple fallback strategies.
This function attempts to obtain the local IP address through several methods.
If all methods fail, it returns the specified fallback value or raises an exception.
Args:
fallback (str, optional): Fallback IP address to return if all detection
methods fail. For server applications, explicitly set this to
"0.0.0.0" (IPv4) or "::" (IPv6) to bind to all available interfaces.
Defaults to None.
Returns:
str: The detected local IP address, or the fallback value if detection fails.
Raises:
ValueError: If IP detection fails and no fallback value is provided.
Note:
The function tries detection methods in the following order:
1. Direct IP detection via get_ip()
2. Network interface enumeration via get_local_ip_by_nic()
3. Remote connection method via get_local_ip_by_remote()
"""
if ip := get_ip():
return ip
logger.debug("get_ip failed")
# Fallback
if ip := get_local_ip_by_nic():
return ip
logger.debug("get_local_ip_by_nic failed")
# Fallback
if ip := get_local_ip_by_remote():
return ip
logger.debug("get_local_ip_by_remote failed")
if fallback:
return fallback
raise ValueError("Can not get local ip")
def is_page_size_one(server_args): def is_page_size_one(server_args):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment