Unverified commit 1b9cea5a, authored by Stepan Kargaltsev, committed by GitHub
Browse files

[P/D] Support ipv6 in P/D scenario (#7858)


Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
parent 9045cc1e
......@@ -23,7 +23,14 @@ from sglang.srt.disaggregation.base.conn import (
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
from sglang.srt.utils import (
format_tcp_address,
get_free_port,
get_ip,
get_local_ip_by_remote,
is_valid_ipv6_address,
maybe_wrap_ipv6_address,
)
logger = logging.getLogger(__name__)
......@@ -65,11 +72,18 @@ class CommonKVManager(BaseKVManager):
def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr:
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
if self.dist_init_addr.endswith("]"):
host = self.dist_init_addr
else:
host, _ = self.dist_init_addr.rsplit(":", 1)
else:
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
else:
ip_address = get_ip()
host = get_ip()
host = maybe_wrap_ipv6_address(host)
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
url = f"http://{bootstrap_server_url}/route"
payload = {
"role": "Prefill",
......@@ -92,8 +106,10 @@ class CommonKVManager(BaseKVManager):
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
@cache
def _connect(self, endpoint: str):
def _connect(self, endpoint: str, is_ipv6: bool = False):
socket = zmq.Context().socket(zmq.PUSH)
if is_ipv6:
socket.setsockopt(zmq.IPV6, 1)
socket.connect(endpoint)
return socket
......@@ -263,15 +279,27 @@ class CommonKVReceiver(BaseKVReceiver):
return None
@classmethod
def _connect(cls, endpoint: str):
def _connect(cls, endpoint: str, is_ipv6: bool = False):
with cls._global_lock:
if endpoint not in cls._socket_cache:
sock = cls._ctx.socket(zmq.PUSH)
if is_ipv6:
sock.setsockopt(zmq.IPV6, 1)
sock.connect(endpoint)
cls._socket_cache[endpoint] = sock
cls._socket_locks[endpoint] = threading.Lock()
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
@classmethod
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
    """Return the cached PUSH socket and its lock for one prefill rank.

    Reads ``rank_ip`` and ``rank_port`` from ``bootstrap_info``, formats
    them as a TCP endpoint and delegates to ``cls._connect``, setting
    ``is_ipv6`` when ``rank_ip`` is an IPv6 literal.
    """
    rank_ip = bootstrap_info["rank_ip"]
    rank_port = bootstrap_info["rank_port"]
    use_ipv6 = is_valid_ipv6_address(rank_ip)
    endpoint = format_tcp_address(rank_ip, rank_port)
    # ``_connect`` already yields the (socket, lock) pair callers unpack.
    return cls._connect(endpoint, is_ipv6=use_ipv6)
def _register_kv_args(self):
pass
......
......@@ -17,6 +17,7 @@ from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.srt.disaggregation.utils import PDRegistryRequest
from sglang.srt.utils import maybe_wrap_ipv6_address
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
1024 * 64
......@@ -271,7 +272,7 @@ async def handle_generate_request(request_data: dict):
# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
modified_request = request_data.copy()
batch_size = _get_request_batch_size(modified_request)
......@@ -309,7 +310,7 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str):
# Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
modified_request = request_data.copy()
modified_request.update(
{
......
......@@ -35,7 +35,15 @@ from sglang.srt.disaggregation.common.utils import (
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto
from sglang.srt.utils import (
format_tcp_address,
get_free_port,
get_int_env_var,
get_ip,
get_local_ip_auto,
is_valid_ipv6_address,
maybe_wrap_ipv6_address,
)
logger = logging.getLogger(__name__)
......@@ -148,6 +156,9 @@ class MooncakeKVManager(BaseKVManager):
self.request_status: Dict[int, KVPoll] = {}
self.rank_port = None
self.server_socket = zmq.Context().socket(zmq.PULL)
if is_valid_ipv6_address(self.local_ip):
self.server_socket.setsockopt(zmq.IPV6, 1)
self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
......@@ -240,8 +251,10 @@ class MooncakeKVManager(BaseKVManager):
self.engine.register(aux_data_ptr, aux_data_len)
@cache
def _connect(self, endpoint: str):
def _connect(self, endpoint: str, is_ipv6: bool = False):
socket = zmq.Context().socket(zmq.PUSH)
if is_ipv6:
socket.setsockopt(zmq.IPV6, 1)
socket.connect(endpoint)
return socket
......@@ -471,9 +484,9 @@ class MooncakeKVManager(BaseKVManager):
def sync_status_to_decode_endpoint(
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
):
if ":" in remote:
remote = remote.split(":")[0]
self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
self._connect(
format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
).send_multipart(
[
str(room).encode("ascii"),
str(status).encode("ascii"),
......@@ -616,9 +629,12 @@ class MooncakeKVManager(BaseKVManager):
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
)
def _bind_server_socket(self):
    """Bind the PULL server socket to this rank's local IP and port."""
    # format_tcp_address brackets IPv6 literals so the endpoint parses.
    endpoint = format_tcp_address(self.local_ip, self.rank_port)
    self.server_socket.bind(endpoint)
def start_prefill_thread(self):
self.rank_port = get_free_port()
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
self._bind_server_socket()
def bootstrap_thread():
"""This thread recvs pre-alloc notification from the decode engine"""
......@@ -657,7 +673,7 @@ class MooncakeKVManager(BaseKVManager):
def start_decode_thread(self):
self.rank_port = get_free_port()
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
self._bind_server_socket()
def decode_thread():
while True:
......@@ -776,7 +792,7 @@ class MooncakeKVManager(BaseKVManager):
# requests with the same dst_sessions will be added into the same
# queue, which enables early abort with failed sessions.
dst_infos = self.transfer_infos[bootstrap_room].keys()
session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos)
session_port_sum = sum(int(session.rsplit(":", 1)[1]) for session in dst_infos)
shard_idx = session_port_sum % len(self.transfer_queues)
self.transfer_queues[shard_idx].put(
......@@ -814,11 +830,18 @@ class MooncakeKVManager(BaseKVManager):
def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr:
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
if self.dist_init_addr.endswith("]"):
host = self.dist_init_addr
else:
host, _ = self.dist_init_addr.rsplit(":", 1)
else:
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
else:
ip_address = get_ip()
host = get_ip()
host = maybe_wrap_ipv6_address(host)
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
url = f"http://{bootstrap_server_url}/route"
payload = {
"role": "Prefill",
......@@ -1163,9 +1186,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
def _register_kv_args(self):
for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
)
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
......@@ -1179,7 +1199,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
dst_tp_size = str(tp_size).encode("ascii")
dst_kv_item_len = str(kv_item_len).encode("ascii")
sock, lock = self._connect("tcp://" + self.prefill_server_url)
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
with lock:
sock.send_multipart(
[
......@@ -1196,23 +1216,32 @@ class MooncakeKVReceiver(BaseKVReceiver):
)
@classmethod
def _connect(cls, endpoint: str):
def _connect(cls, endpoint: str, is_ipv6: bool = False):
with cls._global_lock:
if endpoint not in cls._socket_cache:
sock = cls._ctx.socket(zmq.PUSH)
if is_ipv6:
sock.setsockopt(zmq.IPV6, 1)
sock.connect(endpoint)
cls._socket_cache[endpoint] = sock
cls._socket_locks[endpoint] = threading.Lock()
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
@classmethod
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
    """Look up (or open) the PUSH socket for the rank in ``bootstrap_info``.

    The endpoint is built from the ``rank_ip`` and ``rank_port`` entries;
    ``cls._connect`` is asked to enable IPv6 when ``rank_ip`` is an IPv6
    literal. Returns the ``(socket, lock)`` pair from ``cls._connect``.
    """
    host = bootstrap_info["rank_ip"]
    port = bootstrap_info["rank_port"]
    host_is_ipv6 = is_valid_ipv6_address(host)
    target = format_tcp_address(host, port)
    return cls._connect(target, is_ipv6=host_is_ipv6)
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
)
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
is_dummy = bootstrap_info["is_dummy"]
sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock:
sock.send_multipart(
[
......
import logging
from typing import List, Optional
from sglang.srt.utils import get_bool_env_var, get_free_port
from sglang.srt.utils import get_bool_env_var, get_free_port, maybe_wrap_ipv6_address
logger = logging.getLogger(__name__)
......@@ -27,7 +27,9 @@ class MooncakeTransferEngine:
hostname=self.hostname,
device_name=self.ib_device,
)
self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
self.session_id = (
f"{maybe_wrap_ipv6_address(self.hostname)}:{self.engine.get_rpc_port()}"
)
def register(self, ptr, length):
try:
......
......@@ -27,7 +27,11 @@ from sglang.srt.disaggregation.common.conn import (
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_local_ip_by_remote
from sglang.srt.utils import (
format_tcp_address,
get_local_ip_auto,
is_valid_ipv6_address,
)
logger = logging.getLogger(__name__)
......@@ -124,7 +128,10 @@ class NixlKVManager(CommonKVManager):
"to run SGLang with NixlTransferEngine."
) from e
self.agent = nixl_agent(str(uuid.uuid4()))
self.local_ip = get_local_ip_auto()
self.server_socket = zmq.Context().socket(zmq.PULL)
if is_valid_ipv6_address(self.local_ip):
self.server_socket.setsockopt(zmq.IPV6, 1)
self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
......@@ -337,8 +344,11 @@ class NixlKVManager(CommonKVManager):
return False
return self.transfer_statuses[room].is_done()
def _bind_server_socket(self):
    """Bind the PULL server socket to ``local_ip`` on ``rank_port``."""
    # IPv6 literals are bracketed by format_tcp_address.
    bind_endpoint = format_tcp_address(self.local_ip, self.rank_port)
    self.server_socket.bind(bind_endpoint)
def _start_bootstrap_thread(self):
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
self._bind_server_socket()
def bootstrap_thread():
"""This thread recvs transfer info from the decode engine"""
......@@ -452,23 +462,20 @@ class NixlKVReceiver(CommonKVReceiver):
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
)
logger.debug(
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
is_dummy = bootstrap_info["is_dummy"]
logger.debug(
f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}"
f"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}"
)
sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock:
sock.send_multipart(
[
GUARD,
str(self.bootstrap_room).encode("ascii"),
get_local_ip_by_remote().encode("ascii"),
self.kv_mgr.local_ip.encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.kv_mgr.agent.name.encode("ascii"),
kv_indices.tobytes() if not is_dummy else b"",
......@@ -494,9 +501,7 @@ class NixlKVReceiver(CommonKVReceiver):
def _register_kv_args(self):
for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
)
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
......@@ -504,13 +509,12 @@ class NixlKVReceiver(CommonKVReceiver):
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)
sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock:
sock.send_multipart(
[
GUARD,
"None".encode("ascii"),
get_local_ip_by_remote().encode("ascii"),
self.kv_mgr.local_ip.encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.kv_mgr.agent.name.encode("ascii"),
self.kv_mgr.agent.get_agent_metadata(),
......
......@@ -16,7 +16,12 @@ from torch.distributed import ProcessGroup
from zmq import IPV6 # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address
from sglang.srt.utils import (
format_tcp_address,
get_ip,
get_open_port,
is_valid_ipv6_address,
)
# SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
......@@ -225,9 +230,9 @@ class MessageQueue:
remote_subscribe_port = get_open_port()
if is_valid_ipv6_address(connect_ip):
self.remote_socket.setsockopt(IPV6, 1)
connect_ip = f"[{connect_ip}]"
socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
self.remote_socket.bind(socket_addr)
self.remote_socket.bind(
format_tcp_address(connect_ip, remote_subscribe_port)
)
else:
remote_subscribe_port = None
......@@ -288,7 +293,9 @@ class MessageQueue:
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
if is_valid_ipv6_address(handle.connect_ip):
self.remote_socket.setsockopt(IPV6, 1)
socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
socket_addr = format_tcp_address(
handle.connect_ip, handle.remote_subscribe_port
)
logger.debug("Connecting to %s", socket_addr)
self.remote_socket.connect(socket_addr)
......
......@@ -2065,6 +2065,16 @@ def is_valid_ipv6_address(address: str) -> bool:
return False
def maybe_wrap_ipv6_address(address: str) -> str:
    """Return ``address`` in square brackets if it is an IPv6 literal.

    Any value that ``is_valid_ipv6_address`` rejects is returned unchanged.
    """
    return f"[{address}]" if is_valid_ipv6_address(address) else address
def format_tcp_address(ip: str, port: int) -> str:
    """Build a ``tcp://host:port`` endpoint, bracketing IPv6 hosts."""
    host = maybe_wrap_ipv6_address(ip)
    return f"tcp://{host}:{port}"
def configure_ipv6(dist_init_addr):
addr = dist_init_addr
end = addr.find("]")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment