Unverified Commit cf752077 authored by kylasa's avatar kylasa Committed by GitHub
Browse files

Distributed Lookup Service Robustness (#5387)

Handle corner cases in the distributed lookup service, in particular when the get_partition_ids function is invoked with an empty request. This is needed because get_partition_ids uses the alltoall function.
parent 999c6245
......@@ -185,8 +185,13 @@ class DistLookupService:
]
# Find the process where global_nid --> partition-id(owner) is stored.
if len(global_nids) > 0:
ntype_ids, type_nids = self.id_map(global_nids)
ntype_ids, type_nids = ntype_ids.numpy(), type_nids.numpy()
else:
ntype_ids = np.array([], dtype=np.int64)
type_nids = np.array([], dtype=np.int64)
assert len(ntype_ids) == len(global_nids)
# For each node-type, the per-type-node-id <-> partition-id mappings are
......@@ -294,9 +299,11 @@ class DistLookupService:
# Order according to the requesting order.
# Owner_resp_list is the list of owner-ids for global_nids (function argument).
owner_ids = torch.cat(
[x for x in owner_resp_list if x is not None]
).numpy()
owner_ids = [x for x in owner_resp_list if x is not None]
if len(owner_ids) > 0:
owner_ids = torch.cat(owner_ids).numpy()
else:
owner_ids = np.array([], dtype=np.int64)
assert len(owner_ids) == len(global_nids)
global_nids_order = np.concatenate(indices_list)
......@@ -305,6 +312,7 @@ class DistLookupService:
global_nids_order = global_nids_order[sort_order_idx]
assert np.all(np.arange(len(global_nids)) == global_nids_order)
if len(owner_ids) > 0:
# Store the partition-ids for the current split
agg_partition_ids.append(owner_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment