"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "66c04855da9b5918dcf05b67304453d5f78e3df9"
Unverified commit e25f47de authored by kylasa, committed by GitHub

[DistDGL][Mem_Optimizations] Edge ownership processes are computed on the fly when required. (#5225)

* Edge Ownership processes are computed on the fly when required.

Earlier we stored the edge-ownership process IDs after the dataset was read from disk. For massively large datasets, each node can handle up to 5 billion edges, so storing the owner process IDs (8 bytes each) consumes 5 billion * 8 bytes = 40 GB. This memory hangs around until the edges are exchanged.
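For reference, the arithmetic behind that figure as a quick back-of-the-envelope check (not code from this PR):

```python
# Back-of-the-envelope check of the memory figure quoted above.
edges_per_node = 5 * 10**9   # up to 5 billion edges handled by one node
bytes_per_owner_id = 8       # owner process-id stored as a 64-bit integer
print(edges_per_node * bytes_per_owner_id / 10**9)  # 40.0 -> ~40 GB held until the shuffle
```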

To reduce the memory footprint of the pipeline, we no longer store the ownership process IDs in the 'edge_data' dictionary after reading the dataset from disk. Instead, we compute them on the fly at the time of exchanging edges.
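As a rough illustration of the new flow (a standalone sketch, not the PR's code: the lookup service is stubbed out and the partition mapping is made up), ownership is now derived per chunk and dropped right after use:

```python
import numpy as np

class FakeLookupService:
    """Hypothetical stand-in for the distributed lookup service used in the diff below."""

    def get_partition_ids(self, global_dst_ids):
        # Made-up mapping: assign node IDs to 4 partitions round-robin.
        return global_dst_ids % 4

id_lookup = FakeLookupService()
dst_ids_chunk = np.arange(10, dtype=np.int64)            # one chunk of GLOBAL_DST_IDs
owner_ids = id_lookup.get_partition_ids(dst_ids_chunk)   # computed on the fly, never cached
print(owner_ids)  # [0 1 2 3 0 1 2 3 0 1]
```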

Another optimization is to avoid sending/receiving all the edges in one single large message. Instead, we now split the total number of edges into chunks, limited to roughly 8 GB per node, and iterate until all the chunks are exchanged.
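A standalone sketch of the chunk-count arithmetic (the CHUNK_SIZE constant mirrors the diff below; the edge counts in the example are invented):

```python
CHUNK_SIZE = 100 * 1000 * 1000  # 100M edges per chunk, as in the diff below

def chunking(num_edges, max_edges):
    # Every rank runs the same number of rounds, driven by the busiest rank,
    # so that all ranks participate in every alltoallv round.
    num_chunks = (max_edges // CHUNK_SIZE) + (0 if max_edges % CHUNK_SIZE == 0 else 1)
    # Each rank then spreads its own edges evenly over those rounds.
    local_chunk_size = (num_edges // num_chunks) + (0 if num_edges % num_chunks == 0 else 1)
    return num_chunks, local_chunk_size

# Invented example: this rank holds 450M edges, the busiest rank holds 1.2B.
print(chunking(450 * 1000 * 1000, 1200 * 1000 * 1000))  # (12, 37500000)
```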

Once all the edges are exchanged, as a sanity check, we compute the total number of edges in the system and compare it with the original value from before the edge shuffle, in a final assert statement before returning the result to the caller.
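The check amounts to the invariant below (a toy sketch with made-up per-rank counts; in the PR the counts come from allgather_sizes):

```python
import numpy as np

all_edges = 1000                                      # global edge count before shuffling
shuffle_edge_counts = np.array([260, 240, 255, 245])  # made-up per-rank counts after shuffling
shuffle_edge_total = np.sum(shuffle_edge_counts)
# Shuffling only moves edges between ranks, so the global total must be unchanged.
assert shuffle_edge_total == all_edges
```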

* Applying lintrunner patch.
parent 1329be96
@@ -164,7 +164,7 @@ def gen_node_data(
     return local_node_data


-def exchange_edge_data(rank, world_size, num_parts, edge_data):
+def exchange_edge_data(rank, world_size, num_parts, edge_data, id_lookup):
     """
     Exchange edge_data among processes in the world.
     Prepare list of sliced data targeting each process and trigger
@@ -179,6 +179,7 @@ def exchange_edge_data(rank, world_size, num_parts, edge_data):
     edge_data : dictionary
         edge information, as a dicitonary which stores column names as keys and values
         as column data. This information is read from the edges.txt file.
+    id_lookup : DistLookService object

     Returns:
     --------
@@ -189,52 +190,108 @@ def exchange_edge_data(rank, world_size, num_parts, edge_data):
     # Prepare data for each rank in the cluster.
     start = timer()
-    for local_part_id in range(num_parts // world_size):
-        input_list = []
-        for idx in range(world_size):
-            send_idx = edge_data[constants.OWNER_PROCESS] == (
-                idx + local_part_id * world_size
-            )
-            send_idx = send_idx.reshape(
-                edge_data[constants.GLOBAL_SRC_ID].shape[0]
-            )
-            filt_data = np.column_stack(
-                (
-                    edge_data[constants.GLOBAL_SRC_ID][send_idx == 1],
-                    edge_data[constants.GLOBAL_DST_ID][send_idx == 1],
-                    edge_data[constants.GLOBAL_TYPE_EID][send_idx == 1],
-                    edge_data[constants.ETYPE_ID][send_idx == 1],
-                    edge_data[constants.GLOBAL_EID][send_idx == 1],
-                )
-            )
-            if filt_data.shape[0] <= 0:
-                input_list.append(torch.empty((0, 5), dtype=torch.int64))
-            else:
-                input_list.append(torch.from_numpy(filt_data))
-
-        dist.barrier()
-        output_list = alltoallv_cpu(
-            rank, world_size, input_list, retain_nones=False
-        )
-
-        # Replace the values of the edge_data, with the received data from all the other processes.
-        rcvd_edge_data = torch.cat(output_list).numpy()
-        edge_data[
-            constants.GLOBAL_SRC_ID + "/" + str(local_part_id)
-        ] = rcvd_edge_data[:, 0]
-        edge_data[
-            constants.GLOBAL_DST_ID + "/" + str(local_part_id)
-        ] = rcvd_edge_data[:, 1]
-        edge_data[
-            constants.GLOBAL_TYPE_EID + "/" + str(local_part_id)
-        ] = rcvd_edge_data[:, 2]
-        edge_data[
-            constants.ETYPE_ID + "/" + str(local_part_id)
-        ] = rcvd_edge_data[:, 3]
-        edge_data[
-            constants.GLOBAL_EID + "/" + str(local_part_id)
-        ] = rcvd_edge_data[:, 4]
+    CHUNK_SIZE = 100 * 1000 * 1000  # 100 * 8 * 5 = 1 * 4 = 8 GB/message/node
+    num_edges = edge_data[constants.GLOBAL_SRC_ID].shape[0]
+    all_counts = allgather_sizes(
+        [num_edges], world_size, num_parts, return_sizes=True
+    )
+    max_edges = np.amax(all_counts)
+    all_edges = np.sum(all_counts)
+    num_chunks = (max_edges // CHUNK_SIZE) + (
+        0 if (max_edges % CHUNK_SIZE == 0) else 1
+    )
+    LOCAL_CHUNK_SIZE = (num_edges // num_chunks) + (
+        0 if (num_edges % num_chunks == 0) else 1
+    )
+    logging.info(
+        f"[Rank: {rank} Edge Data Shuffle - max_edges: {max_edges}, \
+        local_edges: {num_edges} and num_chunks: {num_chunks} \
+        Total edges: {all_edges} Local_CHUNK_SIZE: {LOCAL_CHUNK_SIZE}"
+    )
+
+    # Start sending the chunks to the rest of the processes
+    for local_part_id in range(num_parts // world_size):
+        local_src_ids = []
+        local_dst_ids = []
+        local_type_eids = []
+        local_etype_ids = []
+        local_eids = []
+        for chunk in range(num_chunks):
+            start = chunk * LOCAL_CHUNK_SIZE
+            end = (chunk + 1) * LOCAL_CHUNK_SIZE
+            logging.info(
+                f"[Rank: {rank}] EdgeData Shuffle: processing \
+                local_part_id: {local_part_id} and chunkid: {chunk}"
+            )
+            cur_src_id = edge_data[constants.GLOBAL_SRC_ID][start:end]
+            cur_dst_id = edge_data[constants.GLOBAL_DST_ID][start:end]
+            cur_type_eid = edge_data[constants.GLOBAL_TYPE_EID][start:end]
+            cur_etype_id = edge_data[constants.ETYPE_ID][start:end]
+            cur_eid = edge_data[constants.GLOBAL_EID][start:end]
+
+            input_list = []
+            owner_ids = id_lookup.get_partition_ids(cur_dst_id)
+            for idx in range(world_size):
+                send_idx = owner_ids == (idx + local_part_id * world_size)
+                send_idx = send_idx.reshape(cur_src_id.shape[0])
+                filt_data = np.column_stack(
+                    (
+                        cur_src_id[send_idx == 1],
+                        cur_dst_id[send_idx == 1],
+                        cur_type_eid[send_idx == 1],
+                        cur_etype_id[send_idx == 1],
+                        cur_eid[send_idx == 1],
+                    )
+                )
+                if filt_data.shape[0] <= 0:
+                    input_list.append(torch.empty((0, 5), dtype=torch.int64))
+                else:
+                    input_list.append(torch.from_numpy(filt_data))
+
+            # Now send newly formed chunk to others.
+            dist.barrier()
+            output_list = alltoallv_cpu(
+                rank, world_size, input_list, retain_nones=False
+            )
+
+            # Replace the values of the edge_data, with the received data from all the other processes.
+            rcvd_edge_data = torch.cat(output_list).numpy()
+            local_src_ids.append(rcvd_edge_data[:, 0])
+            local_dst_ids.append(rcvd_edge_data[:, 1])
+            local_type_eids.append(rcvd_edge_data[:, 2])
+            local_etype_ids.append(rcvd_edge_data[:, 3])
+            local_eids.append(rcvd_edge_data[:, 4])
+
+        edge_data[
+            constants.GLOBAL_SRC_ID + "/" + str(local_part_id)
+        ] = np.concatenate(local_src_ids)
+        edge_data[
+            constants.GLOBAL_DST_ID + "/" + str(local_part_id)
+        ] = np.concatenate(local_dst_ids)
+        edge_data[
+            constants.GLOBAL_TYPE_EID + "/" + str(local_part_id)
+        ] = np.concatenate(local_type_eids)
+        edge_data[
+            constants.ETYPE_ID + "/" + str(local_part_id)
+        ] = np.concatenate(local_etype_ids)
+        edge_data[
+            constants.GLOBAL_EID + "/" + str(local_part_id)
+        ] = np.concatenate(local_eids)
+
+    # Check if the data was exchanged correctly
+    local_edge_count = 0
+    for local_part_id in range(num_parts // world_size):
+        local_edge_count += edge_data[
+            constants.GLOBAL_SRC_ID + "/" + str(local_part_id)
+        ].shape[0]
+    shuffle_edge_counts = allgather_sizes(
+        [local_edge_count], world_size, num_parts, return_sizes=True
+    )
+    shuffle_edge_total = np.sum(shuffle_edge_counts)
+    assert shuffle_edge_total == all_edges

     end = timer()
     logging.info(
@@ -242,7 +299,6 @@ def exchange_edge_data(rank, world_size, num_parts, edge_data):
     )

     # Clean up.
-    edge_data.pop(constants.OWNER_PROCESS)
     edge_data.pop(constants.GLOBAL_SRC_ID)
     edge_data.pop(constants.GLOBAL_DST_ID)
     edge_data.pop(constants.GLOBAL_TYPE_EID)
@@ -659,7 +715,9 @@ def exchange_graph_data(
     )
     memory_snapshot("NodeDataGenerationComplete: ", rank)

-    edge_data = exchange_edge_data(rank, world_size, num_parts, edge_data)
+    edge_data = exchange_edge_data(
+        rank, world_size, num_parts, edge_data, id_lookup
+    )
     memory_snapshot("ShuffleEdgeDataComplete: ", rank)
     return (
         node_data,
@@ -329,10 +329,6 @@ def augment_edge_data(
     assert global_eids.shape[0] == edge_data[constants.ETYPE_ID].shape[0]
     edge_data[constants.GLOBAL_EID] = global_eids

-    # assign the owner process/rank for each edge
-    edge_data[constants.OWNER_PROCESS] = lookup_service.get_partition_ids(
-        edge_data[constants.GLOBAL_DST_ID]
-    )

     return edge_data