"tests/vscode:/vscode.git/clone" did not exist on "8ad68c13938dd534c4888b884f2747063317d2cf"
Unverified commit ceaa85c9, authored by shangmingc, committed by GitHub
Browse files

[PD] Support get local ip from NIC for PD disaggregation (#7237)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
parent 0650e517
...@@ -35,12 +35,7 @@ from sglang.srt.disaggregation.common.utils import ( ...@@ -35,12 +35,7 @@ from sglang.srt.disaggregation.common.utils import (
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto
get_free_port,
get_int_env_var,
get_ip,
get_local_ip_by_remote,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -130,8 +125,9 @@ class MooncakeKVManager(BaseKVManager): ...@@ -130,8 +125,9 @@ class MooncakeKVManager(BaseKVManager):
is_mla_backend: Optional[bool] = False, is_mla_backend: Optional[bool] = False,
): ):
self.kv_args = args self.kv_args = args
self.local_ip = get_local_ip_auto()
self.engine = MooncakeTransferEngine( self.engine = MooncakeTransferEngine(
hostname=get_local_ip_by_remote(), hostname=self.local_ip,
gpu_id=self.kv_args.gpu_id, gpu_id=self.kv_args.gpu_id,
ib_device=self.kv_args.ib_device, ib_device=self.kv_args.ib_device,
) )
...@@ -432,7 +428,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -432,7 +428,7 @@ class MooncakeKVManager(BaseKVManager):
def start_prefill_thread(self): def start_prefill_thread(self):
self.rank_port = get_free_port() self.rank_port = get_free_port()
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}") self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
def bootstrap_thread(): def bootstrap_thread():
"""This thread recvs pre-alloc notification from the decode engine""" """This thread recvs pre-alloc notification from the decode engine"""
...@@ -471,7 +467,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -471,7 +467,7 @@ class MooncakeKVManager(BaseKVManager):
def start_decode_thread(self): def start_decode_thread(self):
self.rank_port = get_free_port() self.rank_port = get_free_port()
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}") self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
def decode_thread(): def decode_thread():
while True: while True:
...@@ -620,7 +616,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -620,7 +616,7 @@ class MooncakeKVManager(BaseKVManager):
"role": "Prefill", "role": "Prefill",
"tp_size": self.tp_size, "tp_size": self.tp_size,
"dp_size": self.dp_size, "dp_size": self.dp_size,
"rank_ip": get_local_ip_by_remote(), "rank_ip": self.local_ip,
"rank_port": self.rank_port, "rank_port": self.rank_port,
"engine_rank": self.kv_args.engine_rank, "engine_rank": self.kv_args.engine_rank,
} }
...@@ -953,7 +949,7 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -953,7 +949,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
sock.send_multipart( sock.send_multipart(
[ [
"None".encode("ascii"), "None".encode("ascii"),
get_local_ip_by_remote().encode("ascii"), self.kv_mgr.local_ip.encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"), str(self.kv_mgr.rank_port).encode("ascii"),
self.session_id.encode("ascii"), self.session_id.encode("ascii"),
packed_kv_data_ptrs, packed_kv_data_ptrs,
...@@ -983,7 +979,7 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -983,7 +979,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
sock.send_multipart( sock.send_multipart(
[ [
str(self.bootstrap_room).encode("ascii"), str(self.bootstrap_room).encode("ascii"),
get_local_ip_by_remote().encode("ascii"), self.kv_mgr.local_ip.encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"), str(self.kv_mgr.rank_port).encode("ascii"),
self.session_id.encode("ascii"), self.session_id.encode("ascii"),
kv_indices.tobytes() if not is_dummy else b"", kv_indices.tobytes() if not is_dummy else b"",
......
...@@ -2141,6 +2141,44 @@ def get_free_port(): ...@@ -2141,6 +2141,44 @@ def get_free_port():
return s.getsockname()[1] return s.getsockname()[1]
def get_local_ip_auto() -> str:
    """Pick the local IP, honoring the ``SGLANG_LOCAL_IP_NIC`` environment variable.

    When the variable is set, its value is handed to ``get_local_ip_by_nic``
    as the interface name. When it is not set, ``get_local_ip_by_remote`` is
    used instead.

    Returns:
        The IP address string produced by whichever helper was selected.
    """
    nic_name = os.getenv("SGLANG_LOCAL_IP_NIC")
    if nic_name is None:
        return get_local_ip_by_remote()
    return get_local_ip_by_nic(nic_name)
def get_local_ip_by_nic(interface: str) -> str:
    """Return an IP address bound to the network interface ``interface``.

    IPv4 addresses are tried first, skipping ``127.0.0.1`` and ``0.0.0.0``.
    If none qualifies, IPv6 addresses are tried, skipping ``::1`` and any
    address starting with ``fe80::``; a ``%zone`` suffix is stripped from the
    returned IPv6 address. If the interface carries no qualifying address,
    the result of ``get_local_ip_by_remote()`` is returned instead.

    Args:
        interface: Name of the NIC to inspect.

    Returns:
        The selected IP address as a string.

    Raises:
        ImportError: If the optional ``netifaces`` package is not installed.
        ValueError: If looking up the interface raises ``ValueError`` or
            ``OSError``; the original error is attached as ``__cause__``.
    """
    # netifaces is an optional dependency, so import it only when this
    # code path is actually requested.
    try:
        import netifaces
    except ImportError as e:
        raise ImportError(
            "Environment variable SGLANG_LOCAL_IP_NIC requires package netifaces, please install it through 'pip install netifaces'"
        ) from e

    try:
        addresses = netifaces.ifaddresses(interface)

        # Prefer IPv4; ignore the loopback and unspecified addresses.
        for addr_info in addresses.get(netifaces.AF_INET, ()):
            ip = addr_info.get("addr")
            if ip and ip != "127.0.0.1" and ip != "0.0.0.0":
                return ip

        # Otherwise take an IPv6 address, ignoring fe80:: and ::1.
        for addr_info in addresses.get(netifaces.AF_INET6, ()):
            ip = addr_info.get("addr")
            if ip and not ip.startswith("fe80::") and ip != "::1":
                # Drop a trailing "%<zone>" scope identifier if present.
                return ip.split("%")[0]
    except (ValueError, OSError) as e:
        # Chain the underlying error so the real cause stays visible
        # in the traceback.
        raise ValueError(
            "Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
        ) from e

    # Fallback
    return get_local_ip_by_remote()
def get_local_ip_by_remote() -> str: def get_local_ip_by_remote() -> str:
# try ipv4 # try ipv4
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment