Unverified commit c459536b, authored by dongmao zhang, committed by GitHub
Browse files

[PD] bug fix: Update status if nixl receiver sends a dummy req. (#6720)

parent 535c8386
...@@ -53,26 +53,10 @@ class TransferInfo: ...@@ -53,26 +53,10 @@ class TransferInfo:
required_dst_info_num: int required_dst_info_num: int
def is_dummy(self): def is_dummy(self):
return self.endpoint == "" return self.dst_kv_indices.size == 0
@classmethod @classmethod
def from_zmq(cls, msg: List[bytes]): def from_zmq(cls, msg: List[bytes]):
if len(msg) == 1:
# dummy msg
return cls(
room=int(msg[0].decode("ascii")),
endpoint="",
dst_port=0,
agent_metadata=b"",
agent_name="",
dst_kv_ptrs=[],
dst_kv_indices=np.array([], dtype=np.int64),
dst_aux_ptrs=[],
dst_aux_index=0,
dst_gpu_id=0,
required_dst_info_num=0,
)
else:
return cls( return cls(
room=int(msg[0].decode("ascii")), room=int(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"), endpoint=msg[1].decode("ascii"),
...@@ -278,7 +262,7 @@ class NixlKVManager(CommonKVManager): ...@@ -278,7 +262,7 @@ class NixlKVManager(CommonKVManager):
for req in reqs_to_be_processed: for req in reqs_to_be_processed:
assert bootstrap_room == req.room assert bootstrap_room == req.room
if req.is_dummy(): if req.is_dummy():
return [] continue
peer_name = self._add_remote(req.agent_name, req.agent_metadata) peer_name = self._add_remote(req.agent_name, req.agent_metadata)
chunked_dst_kv_indice = req.dst_kv_indices[index_slice] chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
...@@ -346,8 +330,7 @@ class NixlKVManager(CommonKVManager): ...@@ -346,8 +330,7 @@ class NixlKVManager(CommonKVManager):
), f"First message should be {GUARD}. Foreign traffic?" ), f"First message should be {GUARD}. Foreign traffic?"
waiting_req_bytes = waiting_req_bytes[1:] waiting_req_bytes = waiting_req_bytes[1:]
room = waiting_req_bytes[0].decode("ascii") room = waiting_req_bytes[0].decode("ascii")
if room == "None":
continue
required_dst_info_num = int(waiting_req_bytes[10].decode("ascii")) required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
room = int(room) room = int(room)
agent_name = waiting_req_bytes[4].decode("ascii") agent_name = waiting_req_bytes[4].decode("ascii")
...@@ -438,19 +421,6 @@ class NixlKVReceiver(CommonKVReceiver): ...@@ -438,19 +421,6 @@ class NixlKVReceiver(CommonKVReceiver):
) )
is_dummy = bootstrap_info["is_dummy"] is_dummy = bootstrap_info["is_dummy"]
# TODO: just send "" for indices for dummy
if is_dummy:
# TODO: need to set success??
sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock:
sock.send_multipart(
[
GUARD,
str(self.bootstrap_room).encode("ascii"),
]
)
continue
# TODO: send_kv_args earlier # TODO: send_kv_args earlier
packed_kv_data_ptrs = b"".join( packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
...@@ -473,7 +443,7 @@ class NixlKVReceiver(CommonKVReceiver): ...@@ -473,7 +443,7 @@ class NixlKVReceiver(CommonKVReceiver):
self.kv_mgr.agent.get_agent_metadata(), self.kv_mgr.agent.get_agent_metadata(),
self.kv_mgr.agent.name.encode("ascii"), self.kv_mgr.agent.name.encode("ascii"),
packed_kv_data_ptrs, packed_kv_data_ptrs,
kv_indices.tobytes(), kv_indices.tobytes() if not is_dummy else b"",
packed_aux_data_ptrs, packed_aux_data_ptrs,
str(aux_index).encode("ascii"), str(aux_index).encode("ascii"),
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"), str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment