"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "22b45304bf85a3c5281753d6b3259ccaf96e5085"
Unverified Commit 5ea04713 authored by kylasa's avatar kylasa Committed by GitHub
Browse files

[DistDGL] Memory optimization to reduce memory footprint of the Dist Graph...

[DistDGL] Memory optimization to reduce memory footprint of the Dist Graph partitioning pipeline. (#5130)

* Wrap np.argsort() in a function.

Use a Python wrapper around the np.argsort() call to make better use of system memory.

* lintrunner patch.

* lintrunner patch.

* Changes to address code review comments.
parent 7ff04152
......@@ -815,6 +815,41 @@ def read_dataset(rank, world_size, id_lookup, params, schema_map):
)
def reorder_data(num_parts, world_size, data, key):
    """
    Reorder every column of each local partition by the argsort of one column.

    The dictionary keys are expected to have the form ``<column>/<part-id>``.
    For each local partition id in ``range(num_parts // world_size)`` the
    column ``<key>/<part-id>`` is argsorted, and the resulting permutation is
    applied to every entry whose key ends in that same partition id
    (including the sort column itself).

    Parameters:
    -----------
    num_parts : int
        total no. of partitions
    world_size : int
        total number of nodes used in this execution
    data : dictionary
        holds the node or edge data of the input graph; its values are
        replaced in place with their reordered versions
    key : string
        name of the column whose sort order is applied to the remaining
        columns of the same partition

    Returns:
    --------
    dictionary
        the very same dictionary object that was passed in as ``data``,
        with the values of each local partition permuted as per the
        argsort results on the ``key`` column
    """
    parts_per_rank = num_parts // world_size
    for part_suffix in map(str, range(parts_per_rank)):
        # Permutation that sorts the key column of this local partition.
        permutation = data[key + "/" + part_suffix].argsort()

        # Only values of existing keys are rebound below, so iterating the
        # dictionary while assigning into it is safe.
        for column_name, column in data.items():
            pieces = column_name.split("/")
            assert len(pieces) == 2
            if pieces[1] != part_suffix:
                continue
            data[column_name] = column[permutation]

        # Drop the index array before moving on to the next partition and
        # trigger a collection pass.
        del permutation
        gc.collect()

    return data
def gen_dist_partitions(rank, world_size, params):
"""
Function which will be executed by all Gloo processes to begin execution of the pipeline.
......@@ -1022,16 +1057,11 @@ def gen_dist_partitions(rank, world_size, params):
memory_snapshot("DataShuffleComplete: ", rank)
# sort node_data by ntype
for local_part_id in range(params.num_parts // world_size):
idx = node_data[constants.NTYPE_ID + "/" + str(local_part_id)].argsort()
for k, v in node_data.items():
tokens = k.split("/")
assert len(tokens) == 2
if tokens[1] == str(local_part_id):
node_data[k] = v[idx]
idx = None
gc.collect()
node_data = reorder_data(
params.num_parts, world_size, node_data, constants.NTYPE_ID
)
logging.info(f"[Rank: {rank}] Sorted node_data by node_type")
memory_snapshot("NodeDataSortComplete: ", rank)
# resolve global_ids for nodes
assign_shuffle_global_nids_nodes(
......@@ -1068,18 +1098,12 @@ def gen_dist_partitions(rank, world_size, params):
][feature_idx]
memory_snapshot("ReorderNodeFeaturesComplete: ", rank)
# sort edge_data by etype
for local_part_id in range(params.num_parts // world_size):
sorted_idx = edge_data[
constants.ETYPE_ID + "/" + str(local_part_id)
].argsort()
for k, v in edge_data.items():
tokens = k.split("/")
assert len(tokens) == 2
if tokens[1] == str(local_part_id):
edge_data[k] = v[sorted_idx]
sorted_idx = None
gc.collect()
# Sort edge_data by etype
edge_data = reorder_data(
params.num_parts, world_size, edge_data, constants.ETYPE_ID
)
logging.info(f"[Rank: {rank}] Sorted edge_data by edge_type")
memory_snapshot("EdgeDataSortComplete: ", rank)
shuffle_global_eid_offsets = assign_shuffle_global_nids_edges(
rank, world_size, params.num_parts, edge_data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment