Unverified Commit cf752077 authored by kylasa's avatar kylasa Committed by GitHub
Browse files

Distributed Lookup Service Robustness (#5387)

Handling corner cases in the distributed lookup service. When the get partition ids function is invoked with empty request. This is needed because we are using alltoall function in the get_partition_ids function.
parent 999c6245
...@@ -185,8 +185,13 @@ class DistLookupService: ...@@ -185,8 +185,13 @@ class DistLookupService:
] ]
# Find the process where global_nid --> partition-id(owner) is stored. # Find the process where global_nid --> partition-id(owner) is stored.
if len(global_nids) > 0:
ntype_ids, type_nids = self.id_map(global_nids) ntype_ids, type_nids = self.id_map(global_nids)
ntype_ids, type_nids = ntype_ids.numpy(), type_nids.numpy() ntype_ids, type_nids = ntype_ids.numpy(), type_nids.numpy()
else:
ntype_ids = np.array([], dtype=np.int64)
type_nids = np.array([], dtype=np.int64)
assert len(ntype_ids) == len(global_nids) assert len(ntype_ids) == len(global_nids)
# For each node-type, the per-type-node-id <-> partition-id mappings are # For each node-type, the per-type-node-id <-> partition-id mappings are
...@@ -294,9 +299,11 @@ class DistLookupService: ...@@ -294,9 +299,11 @@ class DistLookupService:
# Order according to the requesting order. # Order according to the requesting order.
# Owner_resp_list is the list of owner-ids for global_nids (function argument). # Owner_resp_list is the list of owner-ids for global_nids (function argument).
owner_ids = torch.cat( owner_ids = [x for x in owner_resp_list if x is not None]
[x for x in owner_resp_list if x is not None] if len(owner_ids) > 0:
).numpy() owner_ids = torch.cat(owner_ids).numpy()
else:
owner_ids = np.array([], dtype=np.int64)
assert len(owner_ids) == len(global_nids) assert len(owner_ids) == len(global_nids)
global_nids_order = np.concatenate(indices_list) global_nids_order = np.concatenate(indices_list)
...@@ -305,6 +312,7 @@ class DistLookupService: ...@@ -305,6 +312,7 @@ class DistLookupService:
global_nids_order = global_nids_order[sort_order_idx] global_nids_order = global_nids_order[sort_order_idx]
assert np.all(np.arange(len(global_nids)) == global_nids_order) assert np.all(np.arange(len(global_nids)) == global_nids_order)
if len(owner_ids) > 0:
# Store the partition-ids for the current split # Store the partition-ids for the current split
agg_partition_ids.append(owner_ids) agg_partition_ids.append(owner_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment