Unverified Commit c459536b authored by dongmao zhang's avatar dongmao zhang Committed by GitHub
Browse files

[PD] bug fix: Update status if nixl receiver sends a dummy req. (#6720)

parent 535c8386
...@@ -53,39 +53,23 @@ class TransferInfo: ...@@ -53,39 +53,23 @@ class TransferInfo:
required_dst_info_num: int required_dst_info_num: int
def is_dummy(self): def is_dummy(self):
return self.endpoint == "" return self.dst_kv_indices.size == 0
@classmethod @classmethod
def from_zmq(cls, msg: List[bytes]): def from_zmq(cls, msg: List[bytes]):
if len(msg) == 1: return cls(
# dummy msg room=int(msg[0].decode("ascii")),
return cls( endpoint=msg[1].decode("ascii"),
room=int(msg[0].decode("ascii")), dst_port=int(msg[2].decode("ascii")),
endpoint="", agent_metadata=msg[3],
dst_port=0, agent_name=msg[4].decode("ascii"),
agent_metadata=b"", dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
agent_name="", dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
dst_kv_ptrs=[], dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
dst_kv_indices=np.array([], dtype=np.int64), dst_aux_index=int(msg[8].decode("ascii")),
dst_aux_ptrs=[], dst_gpu_id=int(msg[9].decode("ascii")),
dst_aux_index=0, required_dst_info_num=int(msg[10].decode("ascii")),
dst_gpu_id=0, )
required_dst_info_num=0,
)
else:
return cls(
room=int(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")),
agent_metadata=msg[3],
agent_name=msg[4].decode("ascii"),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
dst_aux_index=int(msg[8].decode("ascii")),
dst_gpu_id=int(msg[9].decode("ascii")),
required_dst_info_num=int(msg[10].decode("ascii")),
)
@dataclasses.dataclass @dataclasses.dataclass
...@@ -278,7 +262,7 @@ class NixlKVManager(CommonKVManager): ...@@ -278,7 +262,7 @@ class NixlKVManager(CommonKVManager):
for req in reqs_to_be_processed: for req in reqs_to_be_processed:
assert bootstrap_room == req.room assert bootstrap_room == req.room
if req.is_dummy(): if req.is_dummy():
return [] continue
peer_name = self._add_remote(req.agent_name, req.agent_metadata) peer_name = self._add_remote(req.agent_name, req.agent_metadata)
chunked_dst_kv_indice = req.dst_kv_indices[index_slice] chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
...@@ -346,8 +330,7 @@ class NixlKVManager(CommonKVManager): ...@@ -346,8 +330,7 @@ class NixlKVManager(CommonKVManager):
), f"First message should be {GUARD}. Foreign traffic?" ), f"First message should be {GUARD}. Foreign traffic?"
waiting_req_bytes = waiting_req_bytes[1:] waiting_req_bytes = waiting_req_bytes[1:]
room = waiting_req_bytes[0].decode("ascii") room = waiting_req_bytes[0].decode("ascii")
if room == "None":
continue
required_dst_info_num = int(waiting_req_bytes[10].decode("ascii")) required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
room = int(room) room = int(room)
agent_name = waiting_req_bytes[4].decode("ascii") agent_name = waiting_req_bytes[4].decode("ascii")
...@@ -438,19 +421,6 @@ class NixlKVReceiver(CommonKVReceiver): ...@@ -438,19 +421,6 @@ class NixlKVReceiver(CommonKVReceiver):
) )
is_dummy = bootstrap_info["is_dummy"] is_dummy = bootstrap_info["is_dummy"]
# TODO: just send "" for indices for dummy
if is_dummy:
# TODO: need to set success??
sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock:
sock.send_multipart(
[
GUARD,
str(self.bootstrap_room).encode("ascii"),
]
)
continue
# TODO: send_kv_args earlier # TODO: send_kv_args earlier
packed_kv_data_ptrs = b"".join( packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
...@@ -473,7 +443,7 @@ class NixlKVReceiver(CommonKVReceiver): ...@@ -473,7 +443,7 @@ class NixlKVReceiver(CommonKVReceiver):
self.kv_mgr.agent.get_agent_metadata(), self.kv_mgr.agent.get_agent_metadata(),
self.kv_mgr.agent.name.encode("ascii"), self.kv_mgr.agent.name.encode("ascii"),
packed_kv_data_ptrs, packed_kv_data_ptrs,
kv_indices.tobytes(), kv_indices.tobytes() if not is_dummy else b"",
packed_aux_data_ptrs, packed_aux_data_ptrs,
str(aux_index).encode("ascii"), str(aux_index).encode("ascii"),
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"), str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment