Unverified Commit bbc538d9 authored by kylasa, committed by GitHub

[DistDGL][Robustness] Uneven distribution of input graph files for nodes/edges and features (#5227)

* Uneven distribution of nodes/edges/features

To handle unevenly sized node/edge files and node/edge feature files, all ranks have to synchronize before starting a large number of messages (either one large message or a burst of messages); a sketch of this pattern follows the commit metadata below.

* Applying lintrunner patch.

* Removing tab characters for lintrunner.

* lintrunner patch.

* Removed issues introduced by the merge conflicts; a lot of code was repeated.
parent 61b6edab
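
For context, here is a minimal sketch (not the DistDGL implementation) of the pattern this commit applies: every rank reaches a `dist.barrier()` before it launches a burst of point-to-point messages, so ranks that finished reading small files wait for ranks still working through large ones. The helper name `exchange_chunks` and the fixed-size-chunk assumption are illustrative only.

```python
# Minimal sketch of the barrier-before-burst pattern, assuming
# torch.distributed is already initialized (e.g. gloo backend) and that
# all chunks share one shape; real code would exchange sizes first.
import torch
import torch.distributed as dist


def exchange_chunks(chunks, rank, world_size):
    """Send chunks[p] to rank p and receive one chunk from every peer."""
    # Ranks with small input files arrive here early and wait, so the
    # burst of messages below starts at roughly the same time everywhere.
    dist.barrier()
    recvd = [torch.empty_like(c) for c in chunks]
    recvd[rank] = chunks[rank]
    reqs = []
    for peer in range(world_size):
        if peer != rank:
            reqs.append(dist.isend(chunks[peer], dst=peer))
            reqs.append(dist.irecv(recvd[peer], src=peer))
    # Wait for the whole burst to drain before returning.
    for req in reqs:
        req.wait()
    return recvd
```
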
@@ -127,6 +127,10 @@ def gen_node_data(
)
for ntype_id, ntype_name in ntid_ntype_map.items():
# No. of nodes in each process can differ significantly in lopsided distributions
# Synchronize on a per ntype basis
dist.barrier()
type_start, type_end = (
type_nid_dict[ntype_name][0][0],
type_nid_dict[ntype_name][-1][1],
@@ -188,6 +192,9 @@ def exchange_edge_data(rank, world_size, num_parts, edge_data, id_lookup):
in the world.
"""
# Synchronize at the beginning of this function
dist.barrier()
# Prepare data for each rank in the cluster.
start = timer()
@@ -576,6 +583,9 @@ def exchange_features(
for local_part_id in range(num_parts // world_size):
featdata_key = feature_data[feat_key]
# Synchronize for each feature
dist.barrier()
own_features, own_global_ids = exchange_feature(
rank,
data,
@@ -694,6 +704,7 @@ def exchange_graph_data(
constants.STR_NODE_FEATURES,
None,
)
dist.barrier()
memory_snapshot("ShuffleNodeFeaturesComplete: ", rank)
logging.info(f"[Rank: {rank}] Done with node features exchange.")
@@ -708,16 +719,19 @@ def exchange_graph_data(
constants.STR_EDGE_FEATURES,
edge_data,
)
dist.barrier()
logging.info(f"[Rank: {rank}] Done with edge features exchange.")
node_data = gen_node_data(
rank, world_size, num_parts, id_lookup, ntid_ntype_map, schema_map
)
dist.barrier()
memory_snapshot("NodeDataGenerationComplete: ", rank)
edge_data = exchange_edge_data(
rank, world_size, num_parts, edge_data, id_lookup
)
dist.barrier()
memory_snapshot("ShuffleEdgeDataComplete: ", rank)
return (
node_data,
@@ -778,7 +792,6 @@ def read_dataset(rank, world_size, id_lookup, params, schema_map):
read by the current process. Note that each edge-type may have several edge-features.
"""
edge_features = {}
# node_tids, node_features, edge_datadict, edge_tids
(
node_tids,
node_features,
@@ -795,6 +808,8 @@ def read_dataset(rank, world_size, id_lookup, params, schema_map):
params.num_parts,
schema_map,
)
# Synchronize so that everybody completes reading dataset from disk
dist.barrier()
logging.info(f"[Rank: {rank}] Done reading dataset {params.input_dir}")
edge_data = augment_edge_data(
@@ -1065,6 +1080,8 @@ def gen_dist_partitions(rank, world_size, params):
memory_snapshot("NodeDataSortComplete: ", rank)
# resolve global_ids for nodes
# Synchronize before assigning shuffle-global-ids to nodes
dist.barrier()
assign_shuffle_global_nids_nodes(
rank, world_size, params.num_parts, node_data
)
@@ -1106,10 +1123,13 @@ def gen_dist_partitions(rank, world_size, params):
logging.info(f"[Rank: {rank}] Sorted edge_data by edge_type")
memory_snapshot("EdgeDataSortComplete: ", rank)
# Synchronize before assigning shuffle-global-nids for edge end points.
dist.barrier()
shuffle_global_eid_offsets = assign_shuffle_global_nids_edges(
rank, world_size, params.num_parts, edge_data
)
logging.info(f"[Rank: {rank}] Done assigning global_ids to edges ...")
memory_snapshot("ShuffleGlobalID_Edges_Complete: ", rank)
# Shuffle edge features according to the edge order on each rank.
@@ -1138,6 +1158,8 @@ def gen_dist_partitions(rank, world_size, params):
][feature_idx]
# determine global-ids for edge end-points
# Synchronize before retrieving shuffle-global-nids for edge end points.
dist.barrier()
edge_data = lookup_shuffle_global_nids_edges(
rank, world_size, params.num_parts, edge_data, id_lookup, node_data
)
@@ -1163,6 +1185,9 @@ def gen_dist_partitions(rank, world_size, params):
graph_formats = params.graph_formats.split(",")
for local_part_id in range(params.num_parts // world_size):
# Synchronize for each local partition of the graph object.
dist.barrier()
num_edges = shuffle_global_eid_offsets[local_part_id]
node_count = len(
node_data[constants.NTYPE_ID + "/" + str(local_part_id)]
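
The per-ntype, per-feature, and per-local-partition barriers above apply the same idea inside loops: when each iteration triggers its own burst of messages, the synchronization point moves into the loop body. A rough, hypothetical sketch:

```python
# Hypothetical sketch of the per-iteration variant used above: data for
# a single ntype (or feature, or local partition) can be heavily skewed
# across ranks, so ranks re-align before every per-type burst rather
# than synchronizing only once up front.
for ntype_name in ntype_names:  # illustrative loop, not the DistDGL code
    dist.barrier()
    exchange_one_type(ntype_name)  # hypothetical per-type exchange
```
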