"tests/compute/test_partition.py" did not exist on "75fca8e4e8bdeaffc6f2173d46eae42c4e242977"
Unverified Commit 59dd090f authored by ybyang's avatar ybyang Committed by GitHub
Browse files

[PD] Fix no cache connect for receiver (#5534)

parent 569b032c
...@@ -387,6 +387,10 @@ class MooncakeKVSender(BaseKVSender): ...@@ -387,6 +387,10 @@ class MooncakeKVSender(BaseKVSender):
class MooncakeKVReceiver(BaseKVReceiver): class MooncakeKVReceiver(BaseKVReceiver):
_ctx = zmq.Context()
_socket_cache = {}
_socket_locks = {}
_global_lock = threading.Lock()
def __init__( def __init__(
self, self,
...@@ -436,11 +440,15 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -436,11 +440,15 @@ class MooncakeKVReceiver(BaseKVReceiver):
logger.error(f"Error fetching prefill info from bootstrap: {e}") logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None return None
@cache @classmethod
def _connect(self, endpoint: str): def _connect(cls, endpoint: str):
socket = zmq.Context().socket(zmq.PUSH) with cls._global_lock:
socket.connect(endpoint) if endpoint not in cls._socket_cache:
return socket sock = cls._ctx.socket(zmq.PUSH)
sock.connect(endpoint)
cls._socket_cache[endpoint] = sock
cls._socket_locks[endpoint] = threading.Lock()
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
self.prefill_server_url = ( self.prefill_server_url = (
...@@ -456,18 +464,20 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -456,18 +464,20 @@ class MooncakeKVReceiver(BaseKVReceiver):
packed_aux_data_ptrs = b"".join( packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
) )
self._connect("tcp://" + self.prefill_server_url).send_multipart( sock, lock = self._connect("tcp://" + self.prefill_server_url)
[ with lock:
str(self.bootstrap_room).encode("ascii"), sock.send_multipart(
get_local_ip_by_remote().encode("ascii"), [
str(self.kv_mgr.rank_port).encode("ascii"), str(self.bootstrap_room).encode("ascii"),
self.session_id.encode("ascii"), get_local_ip_by_remote().encode("ascii"),
packed_kv_data_ptrs, str(self.kv_mgr.rank_port).encode("ascii"),
kv_indices.tobytes(), self.session_id.encode("ascii"),
packed_aux_data_ptrs, packed_kv_data_ptrs,
str(aux_index).encode("ascii"), kv_indices.tobytes(),
] packed_aux_data_ptrs,
) str(aux_index).encode("ascii"),
]
)
def poll(self) -> KVPoll: def poll(self) -> KVPoll:
return self.kv_mgr.check_status(self.bootstrap_room) return self.kv_mgr.check_status(self.bootstrap_room)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment