Commit a5450f11 (unverified), authored by Russell Bryant, committed by GitHub
Browse files

[Security] Use safe serialization and fix zmq setup for mooncake pipe (#17192)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
parent 9d98ab5e
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import json import json
import os import os
import struct
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Union from typing import Optional, Union
...@@ -115,14 +116,14 @@ class MooncakeTransferEngine: ...@@ -115,14 +116,14 @@ class MooncakeTransferEngine:
p_rank_offset = int(p_port) + 8 + self.local_rank * 2 p_rank_offset = int(p_port) + 8 + self.local_rank * 2
d_rank_offset = int(d_port) + 8 + self.local_rank * 2 d_rank_offset = int(d_port) + 8 + self.local_rank * 2
if kv_rank == 0: if kv_rank == 0:
self.sender_socket.bind(f"tcp://*:{p_rank_offset + 1}") self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}") self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}") self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
self.receiver_ack.bind(f"tcp://*:{p_rank_offset + 2}") self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
else: else:
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}") self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
self.sender_socket.bind(f"tcp://*:{d_rank_offset + 1}") self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
self.receiver_ack.bind(f"tcp://*:{d_rank_offset + 2}") self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}") self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
def initialize(self, local_hostname: str, metadata_server: str, def initialize(self, local_hostname: str, metadata_server: str,
...@@ -176,7 +177,7 @@ class MooncakeTransferEngine: ...@@ -176,7 +177,7 @@ class MooncakeTransferEngine:
def wait_for_ack(self, src_ptr: int, length: int) -> None: def wait_for_ack(self, src_ptr: int, length: int) -> None:
"""Asynchronously wait for ACK from the receiver.""" """Asynchronously wait for ACK from the receiver."""
ack = self.sender_ack.recv_pyobj() ack = self.sender_ack.recv()
if ack != b'ACK': if ack != b'ACK':
logger.error("Failed to receive ACK from the receiver") logger.error("Failed to receive ACK from the receiver")
...@@ -187,18 +188,22 @@ class MooncakeTransferEngine: ...@@ -187,18 +188,22 @@ class MooncakeTransferEngine:
length = len(user_data) length = len(user_data)
src_ptr = self.allocate_managed_buffer(length) src_ptr = self.allocate_managed_buffer(length)
self.write_bytes_to_buffer(src_ptr, user_data, length) self.write_bytes_to_buffer(src_ptr, user_data, length)
self.sender_socket.send_pyobj((src_ptr, length)) self.sender_socket.send_multipart(
[struct.pack("!Q", src_ptr),
struct.pack("!Q", length)])
self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length) self.buffer_cleaner.submit(self.wait_for_ack, src_ptr, length)
def recv_bytes(self) -> bytes: def recv_bytes(self) -> bytes:
"""Receive bytes from the remote process.""" """Receive bytes from the remote process."""
src_ptr, length = self.receiver_socket.recv_pyobj() data = self.receiver_socket.recv_multipart()
src_ptr = struct.unpack("!Q", data[0])[0]
length = struct.unpack("!Q", data[1])[0]
dst_ptr = self.allocate_managed_buffer(length) dst_ptr = self.allocate_managed_buffer(length)
self.transfer_sync(dst_ptr, src_ptr, length) self.transfer_sync(dst_ptr, src_ptr, length)
ret = self.read_bytes_from_buffer(dst_ptr, length) ret = self.read_bytes_from_buffer(dst_ptr, length)
# Buffer cleanup # Buffer cleanup
self.receiver_ack.send_pyobj(b'ACK') self.receiver_ack.send(b'ACK')
self.free_managed_buffer(dst_ptr, length) self.free_managed_buffer(dst_ptr, length)
return ret return ret
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment