Unverified Commit 4cd0a685 authored by kylasa, committed by GitHub

[DIST] Message size when retrieving SHUFFLE_GLOBAL_NIDs results in very large messages and killed processes (#4790)

* Send the message to the distributed lookup service in batches.

* Update the allgather_sizes function signature.

* Remove an unnecessary if statement.

* Remove a logging.info message that is no longer needed.
parent 2db4928e
@@ -86,13 +86,43 @@ def lookup_shuffle_global_nids_edges(rank, world_size, edge_data, id_lookup, nod
         dictionary where keys are column names and values are numpy arrays representing all the
         edges present in the current graph partition
     '''
+    # Make sure that the outgoing message size does not exceed 2GB.
+    # Even though gloo can handle up to 10GB of data in an outgoing message,
+    # it needs additional memory to store temporary information in its buffers,
+    # which increases the memory needs of the process.
+    MILLION = 1000 * 1000
+    BATCH_SIZE = 250 * MILLION
+
     memory_snapshot("GlobalToShuffleIDMapBegin: ", rank)
-    node_list = np.concatenate([edge_data[constants.GLOBAL_SRC_ID], edge_data[constants.GLOBAL_DST_ID]])
-    shuffle_ids = id_lookup.get_shuffle_nids(node_list,
+    node_list = edge_data[constants.GLOBAL_SRC_ID]
+
+    # Determine the no. of times each process has to send alltoall messages.
+    all_sizes = allgather_sizes([node_list.shape[0]], world_size, return_sizes=True)
+    max_count = np.amax(all_sizes)
+    num_splits = max_count // BATCH_SIZE + 1
+
+    # Split the message into batches and send.
+    splits = np.array_split(node_list, num_splits)
+    shuffle_mappings = []
+    for item in splits:
+        shuffle_ids = id_lookup.get_shuffle_nids(item,
                                                  node_data[constants.GLOBAL_NID],
                                                  node_data[constants.SHUFFLE_GLOBAL_NID])
+        shuffle_mappings.append(shuffle_ids)
+
+    shuffle_ids = np.concatenate(shuffle_mappings)
+    assert shuffle_ids.shape[0] == node_list.shape[0]
+    edge_data[constants.SHUFFLE_GLOBAL_SRC_ID] = shuffle_ids
+
+    # Destination end points of edges are owned by the current node and therefore
+    # should have corresponding SHUFFLE_GLOBAL_NODE_IDs.
+    # Here, retrieve SHUFFLE_GLOBAL_NODE_IDs for the destination end points of local edges.
+    uniq_ids, inverse_idx = np.unique(edge_data[constants.GLOBAL_DST_ID], return_inverse=True)
+    common, idx1, idx2 = np.intersect1d(uniq_ids, node_data[constants.GLOBAL_NID], assume_unique=True, return_indices=True)
+    assert len(common) == len(uniq_ids)
+    edge_data[constants.SHUFFLE_GLOBAL_DST_ID] = node_data[constants.SHUFFLE_GLOBAL_NID][idx2][inverse_idx]
+    assert len(edge_data[constants.SHUFFLE_GLOBAL_DST_ID]) == len(edge_data[constants.GLOBAL_DST_ID])
-    edge_data[constants.SHUFFLE_GLOBAL_SRC_ID], edge_data[constants.SHUFFLE_GLOBAL_DST_ID] = np.split(shuffle_ids, 2)
 
     memory_snapshot("GlobalToShuffleIDMap_AfterLookupServiceCalls: ", rank)
     return edge_data
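For illustration, here is a minimal, runnable sketch of the batching pattern this hunk introduces. All names below are toy stand-ins: `BATCH_SIZE` is shrunk so the split is visible, and `fake_get_shuffle_nids` replaces the distributed `id_lookup.get_shuffle_nids` with a plain table lookup.

```python
import numpy as np

BATCH_SIZE = 4  # the real code uses 250 million to keep messages under 2GB

def fake_get_shuffle_nids(global_nids, global_ids, shuffle_ids):
    # Stand-in for the distributed lookup service: map each requested
    # global node id to its shuffle-global id via a sorted-table lookup.
    order = np.argsort(global_ids)
    idx = order[np.searchsorted(global_ids, global_nids, sorter=order)]
    return shuffle_ids[idx]

node_list = np.array([4, 0, 3, 1, 2, 4, 0, 1, 3, 2])  # edge source ids (toy)
global_nids = np.arange(5)                     # GLOBAL_NID column (toy)
shuffle_nids = np.array([50, 51, 52, 53, 54])  # SHUFFLE_GLOBAL_NID column (toy)

# Every rank must issue the same number of collective rounds, so the real
# code derives num_splits from the maximum send size across all ranks.
max_count = node_list.shape[0]
num_splits = max_count // BATCH_SIZE + 1

# Split the request into batches, look up each batch, then reassemble.
shuffle_mappings = [fake_get_shuffle_nids(item, global_nids, shuffle_nids)
                    for item in np.array_split(node_list, num_splits)]
shuffle_ids = np.concatenate(shuffle_mappings)
assert shuffle_ids.shape[0] == node_list.shape[0]
print(shuffle_ids)  # [54 50 53 51 52 54 50 51 53 52]
```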
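The destination end points take a different path because they are all locally owned. Below is a toy walk-through of the `np.unique` / `np.intersect1d` mapping, again with made-up arrays standing in for the `edge_data` / `node_data` columns.

```python
import numpy as np

global_dst = np.array([12, 10, 12, 11, 10])          # GLOBAL_DST_ID, with repeats
global_nid = np.array([10, 11, 12, 13])              # GLOBAL_NID, locally owned
shuffle_global_nid = np.array([100, 101, 102, 103])  # SHUFFLE_GLOBAL_NID

# Deduplicate the destinations; inverse_idx rebuilds the original order later.
uniq_ids, inverse_idx = np.unique(global_dst, return_inverse=True)

# idx2 locates each unique destination inside the local node table.
common, idx1, idx2 = np.intersect1d(
    uniq_ids, global_nid, assume_unique=True, return_indices=True)
assert len(common) == len(uniq_ids)  # every destination must be locally owned

shuffle_dst = shuffle_global_nid[idx2][inverse_idx]
print(shuffle_dst)  # [102 100 102 101 100]
```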
@@ -2,7 +2,7 @@ import numpy as np
 import torch
 import torch.distributed as dist
 
-def allgather_sizes(send_data, world_size):
+def allgather_sizes(send_data, world_size, return_sizes=False):
     """
     Perform all gather on list lengths, used to compute prefix sums
     to determine the offsets on each rank. This is used to allocate
@@ -14,6 +14,9 @@ def allgather_sizes(send_data, world_size):
         Data on which allgather is performed.
     world_size : integer
         No. of processes configured for execution
+    return_sizes : bool
+        Boolean flag to indicate whether to return the raw sizes from each process
+        or to perform a prefix sum on the raw sizes.
 
     Returns :
     ---------
@@ -30,6 +33,10 @@ def allgather_sizes(send_data, world_size):
     #all_gather message
     dist.all_gather(in_tensor, out_tensor)
 
+    # Return the raw sizes from each process
+    if return_sizes:
+        return torch.cat(in_tensor).numpy()
+
     #gather sizes in one array to return to the invoking function
     rank_sizes = np.zeros(world_size + 1, dtype=np.int64)
     count = rank_sizes[0]
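As a usage sketch, the snippet below mirrors the patched `allgather_sizes` in a single-process gloo group. The init boilerplate is illustrative only, and the prefix-sum tail is a reconstruction from the docstring, since the diff cuts off at `count = rank_sizes[0]`.

```python
import os
import numpy as np
import torch
import torch.distributed as dist

def allgather_sizes(send_data, world_size, return_sizes=False):
    # Gather the per-rank sizes from every process.
    out_tensor = torch.tensor(send_data, dtype=torch.int64)
    in_tensor = [torch.zeros_like(out_tensor) for _ in range(world_size)]
    dist.all_gather(in_tensor, out_tensor)
    if return_sizes:
        # Raw sizes from each process, e.g. [3, 5, 2] for three ranks.
        return torch.cat(in_tensor).numpy()
    # Otherwise return prefix sums, e.g. [0, 3, 8, 10], which give each rank
    # its offset when assigning global ids. (Assumed behavior; the diff above
    # ends before this part of the original function.)
    rank_sizes = np.zeros(world_size + 1, dtype=np.int64)
    rank_sizes[1:] = np.cumsum(torch.cat(in_tensor).numpy())
    return rank_sizes

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(allgather_sizes([7], 1, return_sizes=True))  # -> [7]
    print(allgather_sizes([7], 1))                     # -> [0 7]
    dist.destroy_process_group()
```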