Unverified commit 1b9cea5a, authored by Stepan Kargaltsev, committed by GitHub
Browse files

[P/D] Support ipv6 in P/D scenario (#7858)


Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
parent 9045cc1e
...@@ -23,7 +23,14 @@ from sglang.srt.disaggregation.base.conn import ( ...@@ -23,7 +23,14 @@ from sglang.srt.disaggregation.base.conn import (
) )
from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote from sglang.srt.utils import (
format_tcp_address,
get_free_port,
get_ip,
get_local_ip_by_remote,
is_valid_ipv6_address,
maybe_wrap_ipv6_address,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -65,11 +72,18 @@ class CommonKVManager(BaseKVManager): ...@@ -65,11 +72,18 @@ class CommonKVManager(BaseKVManager):
def _register_to_bootstrap(self): def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST.""" """Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr: if self.dist_init_addr:
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0]) if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
if self.dist_init_addr.endswith("]"):
host = self.dist_init_addr
else:
host, _ = self.dist_init_addr.rsplit(":", 1)
else:
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
else: else:
ip_address = get_ip() host = get_ip()
host = maybe_wrap_ipv6_address(host)
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}" bootstrap_server_url = f"{host}:{self.bootstrap_port}"
url = f"http://{bootstrap_server_url}/route" url = f"http://{bootstrap_server_url}/route"
payload = { payload = {
"role": "Prefill", "role": "Prefill",
...@@ -92,8 +106,10 @@ class CommonKVManager(BaseKVManager): ...@@ -92,8 +106,10 @@ class CommonKVManager(BaseKVManager):
logger.error(f"Prefill Failed to register to bootstrap server: {e}") logger.error(f"Prefill Failed to register to bootstrap server: {e}")
@cache @cache
def _connect(self, endpoint: str): def _connect(self, endpoint: str, is_ipv6: bool = False):
socket = zmq.Context().socket(zmq.PUSH) socket = zmq.Context().socket(zmq.PUSH)
if is_ipv6:
socket.setsockopt(zmq.IPV6, 1)
socket.connect(endpoint) socket.connect(endpoint)
return socket return socket
...@@ -263,15 +279,27 @@ class CommonKVReceiver(BaseKVReceiver): ...@@ -263,15 +279,27 @@ class CommonKVReceiver(BaseKVReceiver):
return None return None
@classmethod @classmethod
def _connect(cls, endpoint: str): def _connect(cls, endpoint: str, is_ipv6: bool = False):
with cls._global_lock: with cls._global_lock:
if endpoint not in cls._socket_cache: if endpoint not in cls._socket_cache:
sock = cls._ctx.socket(zmq.PUSH) sock = cls._ctx.socket(zmq.PUSH)
if is_ipv6:
sock.setsockopt(zmq.IPV6, 1)
sock.connect(endpoint) sock.connect(endpoint)
cls._socket_cache[endpoint] = sock cls._socket_cache[endpoint] = sock
cls._socket_locks[endpoint] = threading.Lock() cls._socket_locks[endpoint] = threading.Lock()
return cls._socket_cache[endpoint], cls._socket_locks[endpoint] return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
@classmethod
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
    """Get the socket and lock for the rank described by ``bootstrap_info``.

    Reads the ``rank_ip`` and ``rank_port`` entries, builds the endpoint
    with ``format_tcp_address`` and hands it to ``cls._connect``, passing
    ``is_ipv6`` according to ``is_valid_ipv6_address(rank_ip)``.

    Args:
        bootstrap_info: Mapping that provides ``rank_ip`` and ``rank_port``.

    Returns:
        The ``(socket, lock)`` pair obtained from ``cls._connect``.
    """
    rank_ip = bootstrap_info["rank_ip"]
    rank_port = bootstrap_info["rank_port"]
    uses_ipv6 = is_valid_ipv6_address(rank_ip)
    # format_tcp_address is responsible for bracketing IPv6 literals.
    endpoint = format_tcp_address(rank_ip, rank_port)
    connection, connection_lock = cls._connect(endpoint, is_ipv6=uses_ipv6)
    return connection, connection_lock
def _register_kv_args(self): def _register_kv_args(self):
pass pass
......
...@@ -17,6 +17,7 @@ from fastapi import FastAPI, HTTPException ...@@ -17,6 +17,7 @@ from fastapi import FastAPI, HTTPException
from fastapi.responses import ORJSONResponse, Response, StreamingResponse from fastapi.responses import ORJSONResponse, Response, StreamingResponse
from sglang.srt.disaggregation.utils import PDRegistryRequest from sglang.srt.disaggregation.utils import PDRegistryRequest
from sglang.srt.utils import maybe_wrap_ipv6_address
AIOHTTP_STREAM_READ_CHUNK_SIZE = ( AIOHTTP_STREAM_READ_CHUNK_SIZE = (
1024 * 64 1024 * 64
...@@ -271,7 +272,7 @@ async def handle_generate_request(request_data: dict): ...@@ -271,7 +272,7 @@ async def handle_generate_request(request_data: dict):
# Parse and transform prefill_server for bootstrap data # Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server) parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
modified_request = request_data.copy() modified_request = request_data.copy()
batch_size = _get_request_batch_size(modified_request) batch_size = _get_request_batch_size(modified_request)
...@@ -309,7 +310,7 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str): ...@@ -309,7 +310,7 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str):
# Parse and transform prefill_server for bootstrap data # Parse and transform prefill_server for bootstrap data
parsed_url = urllib.parse.urlparse(prefill_server) parsed_url = urllib.parse.urlparse(prefill_server)
hostname = parsed_url.hostname hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
modified_request = request_data.copy() modified_request = request_data.copy()
modified_request.update( modified_request.update(
{ {
......
...@@ -35,7 +35,15 @@ from sglang.srt.disaggregation.common.utils import ( ...@@ -35,7 +35,15 @@ from sglang.srt.disaggregation.common.utils import (
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto from sglang.srt.utils import (
format_tcp_address,
get_free_port,
get_int_env_var,
get_ip,
get_local_ip_auto,
is_valid_ipv6_address,
maybe_wrap_ipv6_address,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -148,6 +156,9 @@ class MooncakeKVManager(BaseKVManager): ...@@ -148,6 +156,9 @@ class MooncakeKVManager(BaseKVManager):
self.request_status: Dict[int, KVPoll] = {} self.request_status: Dict[int, KVPoll] = {}
self.rank_port = None self.rank_port = None
self.server_socket = zmq.Context().socket(zmq.PULL) self.server_socket = zmq.Context().socket(zmq.PULL)
if is_valid_ipv6_address(self.local_ip):
self.server_socket.setsockopt(zmq.IPV6, 1)
self.register_buffer_to_engine() self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {} self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
...@@ -240,8 +251,10 @@ class MooncakeKVManager(BaseKVManager): ...@@ -240,8 +251,10 @@ class MooncakeKVManager(BaseKVManager):
self.engine.register(aux_data_ptr, aux_data_len) self.engine.register(aux_data_ptr, aux_data_len)
@cache @cache
def _connect(self, endpoint: str): def _connect(self, endpoint: str, is_ipv6: bool = False):
socket = zmq.Context().socket(zmq.PUSH) socket = zmq.Context().socket(zmq.PUSH)
if is_ipv6:
socket.setsockopt(zmq.IPV6, 1)
socket.connect(endpoint) socket.connect(endpoint)
return socket return socket
...@@ -471,9 +484,9 @@ class MooncakeKVManager(BaseKVManager): ...@@ -471,9 +484,9 @@ class MooncakeKVManager(BaseKVManager):
def sync_status_to_decode_endpoint( def sync_status_to_decode_endpoint(
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
): ):
if ":" in remote: self._connect(
remote = remote.split(":")[0] format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart( ).send_multipart(
[ [
str(room).encode("ascii"), str(room).encode("ascii"),
str(status).encode("ascii"), str(status).encode("ascii"),
...@@ -616,9 +629,12 @@ class MooncakeKVManager(BaseKVManager): ...@@ -616,9 +629,12 @@ class MooncakeKVManager(BaseKVManager):
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead." f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
) )
def _bind_server_socket(self):
    """Bind ``self.server_socket`` to ``self.local_ip`` on ``self.rank_port``.

    The bind address is produced by ``format_tcp_address`` so that the
    same formatting is used for IPv4 and IPv6 local addresses.
    """
    bind_address = format_tcp_address(self.local_ip, self.rank_port)
    self.server_socket.bind(bind_address)
def start_prefill_thread(self): def start_prefill_thread(self):
self.rank_port = get_free_port() self.rank_port = get_free_port()
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}") self._bind_server_socket()
def bootstrap_thread(): def bootstrap_thread():
"""This thread recvs pre-alloc notification from the decode engine""" """This thread recvs pre-alloc notification from the decode engine"""
...@@ -657,7 +673,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -657,7 +673,7 @@ class MooncakeKVManager(BaseKVManager):
def start_decode_thread(self): def start_decode_thread(self):
self.rank_port = get_free_port() self.rank_port = get_free_port()
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}") self._bind_server_socket()
def decode_thread(): def decode_thread():
while True: while True:
...@@ -776,7 +792,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -776,7 +792,7 @@ class MooncakeKVManager(BaseKVManager):
# requests with the same dst_sessions will be added into the same # requests with the same dst_sessions will be added into the same
# queue, which enables early abort with failed sessions. # queue, which enables early abort with failed sessions.
dst_infos = self.transfer_infos[bootstrap_room].keys() dst_infos = self.transfer_infos[bootstrap_room].keys()
session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos) session_port_sum = sum(int(session.rsplit(":", 1)[1]) for session in dst_infos)
shard_idx = session_port_sum % len(self.transfer_queues) shard_idx = session_port_sum % len(self.transfer_queues)
self.transfer_queues[shard_idx].put( self.transfer_queues[shard_idx].put(
...@@ -814,11 +830,18 @@ class MooncakeKVManager(BaseKVManager): ...@@ -814,11 +830,18 @@ class MooncakeKVManager(BaseKVManager):
def _register_to_bootstrap(self): def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST.""" """Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr: if self.dist_init_addr:
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0]) if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
if self.dist_init_addr.endswith("]"):
host = self.dist_init_addr
else:
host, _ = self.dist_init_addr.rsplit(":", 1)
else:
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
else: else:
ip_address = get_ip() host = get_ip()
host = maybe_wrap_ipv6_address(host)
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}" bootstrap_server_url = f"{host}:{self.bootstrap_port}"
url = f"http://{bootstrap_server_url}/route" url = f"http://{bootstrap_server_url}/route"
payload = { payload = {
"role": "Prefill", "role": "Prefill",
...@@ -1163,9 +1186,6 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -1163,9 +1186,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
def _register_kv_args(self): def _register_kv_args(self):
for bootstrap_info in self.bootstrap_infos: for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
)
packed_kv_data_ptrs = b"".join( packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
) )
...@@ -1179,7 +1199,7 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -1179,7 +1199,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
dst_tp_size = str(tp_size).encode("ascii") dst_tp_size = str(tp_size).encode("ascii")
dst_kv_item_len = str(kv_item_len).encode("ascii") dst_kv_item_len = str(kv_item_len).encode("ascii")
sock, lock = self._connect("tcp://" + self.prefill_server_url) sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
with lock: with lock:
sock.send_multipart( sock.send_multipart(
[ [
...@@ -1196,23 +1216,32 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -1196,23 +1216,32 @@ class MooncakeKVReceiver(BaseKVReceiver):
) )
@classmethod @classmethod
def _connect(cls, endpoint: str): def _connect(cls, endpoint: str, is_ipv6: bool = False):
with cls._global_lock: with cls._global_lock:
if endpoint not in cls._socket_cache: if endpoint not in cls._socket_cache:
sock = cls._ctx.socket(zmq.PUSH) sock = cls._ctx.socket(zmq.PUSH)
if is_ipv6:
sock.setsockopt(zmq.IPV6, 1)
sock.connect(endpoint) sock.connect(endpoint)
cls._socket_cache[endpoint] = sock cls._socket_cache[endpoint] = sock
cls._socket_locks[endpoint] = threading.Lock() cls._socket_locks[endpoint] = threading.Lock()
return cls._socket_cache[endpoint], cls._socket_locks[endpoint] return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
@classmethod
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
    """Get the socket and lock for the rank described by ``bootstrap_info``.

    Reads the ``rank_ip`` and ``rank_port`` entries, builds the endpoint
    with ``format_tcp_address`` and hands it to ``cls._connect``, passing
    ``is_ipv6`` according to ``is_valid_ipv6_address(rank_ip)``.

    Args:
        bootstrap_info: Mapping that provides ``rank_ip`` and ``rank_port``.

    Returns:
        The ``(socket, lock)`` pair obtained from ``cls._connect``.
    """
    rank_ip = bootstrap_info["rank_ip"]
    rank_port = bootstrap_info["rank_port"]
    uses_ipv6 = is_valid_ipv6_address(rank_ip)
    # format_tcp_address is responsible for bracketing IPv6 literals.
    endpoint = format_tcp_address(rank_ip, rank_port)
    connection, connection_lock = cls._connect(endpoint, is_ipv6=uses_ipv6)
    return connection, connection_lock
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos: for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = ( sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
)
is_dummy = bootstrap_info["is_dummy"] is_dummy = bootstrap_info["is_dummy"]
sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock: with lock:
sock.send_multipart( sock.send_multipart(
[ [
......
import logging import logging
from typing import List, Optional from typing import List, Optional
from sglang.srt.utils import get_bool_env_var, get_free_port from sglang.srt.utils import get_bool_env_var, get_free_port, maybe_wrap_ipv6_address
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -27,7 +27,9 @@ class MooncakeTransferEngine: ...@@ -27,7 +27,9 @@ class MooncakeTransferEngine:
hostname=self.hostname, hostname=self.hostname,
device_name=self.ib_device, device_name=self.ib_device,
) )
self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}" self.session_id = (
f"{maybe_wrap_ipv6_address(self.hostname)}:{self.engine.get_rpc_port()}"
)
def register(self, ptr, length): def register(self, ptr, length):
try: try:
......
...@@ -27,7 +27,11 @@ from sglang.srt.disaggregation.common.conn import ( ...@@ -27,7 +27,11 @@ from sglang.srt.disaggregation.common.conn import (
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_local_ip_by_remote from sglang.srt.utils import (
format_tcp_address,
get_local_ip_auto,
is_valid_ipv6_address,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -124,7 +128,10 @@ class NixlKVManager(CommonKVManager): ...@@ -124,7 +128,10 @@ class NixlKVManager(CommonKVManager):
"to run SGLang with NixlTransferEngine." "to run SGLang with NixlTransferEngine."
) from e ) from e
self.agent = nixl_agent(str(uuid.uuid4())) self.agent = nixl_agent(str(uuid.uuid4()))
self.local_ip = get_local_ip_auto()
self.server_socket = zmq.Context().socket(zmq.PULL) self.server_socket = zmq.Context().socket(zmq.PULL)
if is_valid_ipv6_address(self.local_ip):
self.server_socket.setsockopt(zmq.IPV6, 1)
self.register_buffer_to_engine() self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
...@@ -337,8 +344,11 @@ class NixlKVManager(CommonKVManager): ...@@ -337,8 +344,11 @@ class NixlKVManager(CommonKVManager):
return False return False
return self.transfer_statuses[room].is_done() return self.transfer_statuses[room].is_done()
def _bind_server_socket(self):
    """Bind ``self.server_socket`` to ``self.local_ip`` on ``self.rank_port``.

    The bind address is produced by ``format_tcp_address`` so that the
    same formatting is used for IPv4 and IPv6 local addresses.
    """
    bind_address = format_tcp_address(self.local_ip, self.rank_port)
    self.server_socket.bind(bind_address)
def _start_bootstrap_thread(self): def _start_bootstrap_thread(self):
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}") self._bind_server_socket()
def bootstrap_thread(): def bootstrap_thread():
"""This thread recvs transfer info from the decode engine""" """This thread recvs transfer info from the decode engine"""
...@@ -452,23 +462,20 @@ class NixlKVReceiver(CommonKVReceiver): ...@@ -452,23 +462,20 @@ class NixlKVReceiver(CommonKVReceiver):
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos: for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
)
logger.debug( logger.debug(
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}" f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
) )
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
is_dummy = bootstrap_info["is_dummy"] is_dummy = bootstrap_info["is_dummy"]
logger.debug( logger.debug(
f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}" f"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}"
) )
sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock: with lock:
sock.send_multipart( sock.send_multipart(
[ [
GUARD, GUARD,
str(self.bootstrap_room).encode("ascii"), str(self.bootstrap_room).encode("ascii"),
get_local_ip_by_remote().encode("ascii"), self.kv_mgr.local_ip.encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"), str(self.kv_mgr.rank_port).encode("ascii"),
self.kv_mgr.agent.name.encode("ascii"), self.kv_mgr.agent.name.encode("ascii"),
kv_indices.tobytes() if not is_dummy else b"", kv_indices.tobytes() if not is_dummy else b"",
...@@ -494,9 +501,7 @@ class NixlKVReceiver(CommonKVReceiver): ...@@ -494,9 +501,7 @@ class NixlKVReceiver(CommonKVReceiver):
def _register_kv_args(self): def _register_kv_args(self):
for bootstrap_info in self.bootstrap_infos: for bootstrap_info in self.bootstrap_infos:
self.prefill_server_url = ( sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
)
packed_kv_data_ptrs = b"".join( packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
) )
...@@ -504,13 +509,12 @@ class NixlKVReceiver(CommonKVReceiver): ...@@ -504,13 +509,12 @@ class NixlKVReceiver(CommonKVReceiver):
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
) )
sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock: with lock:
sock.send_multipart( sock.send_multipart(
[ [
GUARD, GUARD,
"None".encode("ascii"), "None".encode("ascii"),
get_local_ip_by_remote().encode("ascii"), self.kv_mgr.local_ip.encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"), str(self.kv_mgr.rank_port).encode("ascii"),
self.kv_mgr.agent.name.encode("ascii"), self.kv_mgr.agent.name.encode("ascii"),
self.kv_mgr.agent.get_agent_metadata(), self.kv_mgr.agent.get_agent_metadata(),
......
...@@ -16,7 +16,12 @@ from torch.distributed import ProcessGroup ...@@ -16,7 +16,12 @@ from torch.distributed import ProcessGroup
from zmq import IPV6 # type: ignore from zmq import IPV6 # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address from sglang.srt.utils import (
format_tcp_address,
get_ip,
get_open_port,
is_valid_ipv6_address,
)
# SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60 # SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60
SGLANG_RINGBUFFER_WARNING_INTERVAL = int( SGLANG_RINGBUFFER_WARNING_INTERVAL = int(
...@@ -225,9 +230,9 @@ class MessageQueue: ...@@ -225,9 +230,9 @@ class MessageQueue:
remote_subscribe_port = get_open_port() remote_subscribe_port = get_open_port()
if is_valid_ipv6_address(connect_ip): if is_valid_ipv6_address(connect_ip):
self.remote_socket.setsockopt(IPV6, 1) self.remote_socket.setsockopt(IPV6, 1)
connect_ip = f"[{connect_ip}]" self.remote_socket.bind(
socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}" format_tcp_address(connect_ip, remote_subscribe_port)
self.remote_socket.bind(socket_addr) )
else: else:
remote_subscribe_port = None remote_subscribe_port = None
...@@ -288,7 +293,9 @@ class MessageQueue: ...@@ -288,7 +293,9 @@ class MessageQueue:
self.remote_socket.setsockopt_string(SUBSCRIBE, "") self.remote_socket.setsockopt_string(SUBSCRIBE, "")
if is_valid_ipv6_address(handle.connect_ip): if is_valid_ipv6_address(handle.connect_ip):
self.remote_socket.setsockopt(IPV6, 1) self.remote_socket.setsockopt(IPV6, 1)
socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}" socket_addr = format_tcp_address(
handle.connect_ip, handle.remote_subscribe_port
)
logger.debug("Connecting to %s", socket_addr) logger.debug("Connecting to %s", socket_addr)
self.remote_socket.connect(socket_addr) self.remote_socket.connect(socket_addr)
......
...@@ -2065,6 +2065,16 @@ def is_valid_ipv6_address(address: str) -> bool: ...@@ -2065,6 +2065,16 @@ def is_valid_ipv6_address(address: str) -> bool:
return False return False
def maybe_wrap_ipv6_address(address: str) -> str:
    """Return ``address`` in square brackets if it is an IPv6 address.

    Any value that ``is_valid_ipv6_address`` rejects is returned unchanged.
    """
    return f"[{address}]" if is_valid_ipv6_address(address) else address
def format_tcp_address(ip: str, port: int) -> str:
    """Build a ``tcp://host:port`` endpoint string.

    ``ip`` is passed through ``maybe_wrap_ipv6_address`` first, so an IPv6
    address appears in brackets ahead of the port separator.
    """
    host = maybe_wrap_ipv6_address(ip)
    return f"tcp://{host}:{port}"
def configure_ipv6(dist_init_addr): def configure_ipv6(dist_init_addr):
addr = dist_init_addr addr = dist_init_addr
end = addr.find("]") end = addr.find("]")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment