Unverified Commit c459536b authored by dongmao zhang's avatar dongmao zhang Committed by GitHub
Browse files

[PD] bug fix: Update status if nixl receiver send a a dummy req. (#6720)

parent 535c8386
......@@ -53,39 +53,23 @@ class TransferInfo:
required_dst_info_num: int
def is_dummy(self):
return self.endpoint == ""
return self.dst_kv_indices.size == 0
@classmethod
def from_zmq(cls, msg: List[bytes]):
if len(msg) == 1:
# dummy msg
return cls(
room=int(msg[0].decode("ascii")),
endpoint="",
dst_port=0,
agent_metadata=b"",
agent_name="",
dst_kv_ptrs=[],
dst_kv_indices=np.array([], dtype=np.int64),
dst_aux_ptrs=[],
dst_aux_index=0,
dst_gpu_id=0,
required_dst_info_num=0,
)
else:
return cls(
room=int(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")),
agent_metadata=msg[3],
agent_name=msg[4].decode("ascii"),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
dst_aux_index=int(msg[8].decode("ascii")),
dst_gpu_id=int(msg[9].decode("ascii")),
required_dst_info_num=int(msg[10].decode("ascii")),
)
return cls(
room=int(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")),
agent_metadata=msg[3],
agent_name=msg[4].decode("ascii"),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
dst_aux_index=int(msg[8].decode("ascii")),
dst_gpu_id=int(msg[9].decode("ascii")),
required_dst_info_num=int(msg[10].decode("ascii")),
)
@dataclasses.dataclass
......@@ -278,7 +262,7 @@ class NixlKVManager(CommonKVManager):
for req in reqs_to_be_processed:
assert bootstrap_room == req.room
if req.is_dummy():
return []
continue
peer_name = self._add_remote(req.agent_name, req.agent_metadata)
chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
......@@ -346,8 +330,7 @@ class NixlKVManager(CommonKVManager):
), f"First message should be {GUARD}. Foreign traffic?"
waiting_req_bytes = waiting_req_bytes[1:]
room = waiting_req_bytes[0].decode("ascii")
if room == "None":
continue
required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
room = int(room)
agent_name = waiting_req_bytes[4].decode("ascii")
......@@ -438,19 +421,6 @@ class NixlKVReceiver(CommonKVReceiver):
)
is_dummy = bootstrap_info["is_dummy"]
# TODO: just send "" for indices for dummy
if is_dummy:
# TODO: need to set success??
sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock:
sock.send_multipart(
[
GUARD,
str(self.bootstrap_room).encode("ascii"),
]
)
continue
# TODO: send_kv_args earlier
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
......@@ -473,7 +443,7 @@ class NixlKVReceiver(CommonKVReceiver):
self.kv_mgr.agent.get_agent_metadata(),
self.kv_mgr.agent.name.encode("ascii"),
packed_kv_data_ptrs,
kv_indices.tobytes(),
kv_indices.tobytes() if not is_dummy else b"",
packed_aux_data_ptrs,
str(aux_index).encode("ascii"),
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment