Unverified Commit 9ce800d2 authored by kylasa's avatar kylasa Committed by GitHub
Browse files

[DistDGL][Optimizations]Rehash code to optimize for loop (#5224)

* Rehash code to optimize for loop

Reduced the number of instructions in the for loop that exchanges edge features. This reduces the number of times numpy's intersect1d is invoked (saving the runtime and memory overhead of numpy).

* Applying lintrunner patch to data_shuffle.py
parent e25f47de
...@@ -390,8 +390,8 @@ def exchange_feature( ...@@ -390,8 +390,8 @@ def exchange_feature(
tokens = feat_key.split("/") tokens = feat_key.split("/")
assert len(tokens) == 3 assert len(tokens) == 3
local_feat_key = "/".join(tokens[:-1]) + "/" + str(local_part_id) local_feat_key = "/".join(tokens[:-1]) + "/" + str(local_part_id)
for idx in range(world_size):
# Get the partition ids for the range of global nids. # Get the partition ids for the range of global nids.
if feat_type == constants.STR_NODE_FEATURES: if feat_type == constants.STR_NODE_FEATURES:
# Retrieve the partition ids for the node features. # Retrieve the partition ids for the node features.
...@@ -416,6 +416,7 @@ def exchange_feature( ...@@ -416,6 +416,7 @@ def exchange_feature(
assert np.all(global_eids == data[constants.GLOBAL_EID][idx1]) assert np.all(global_eids == data[constants.GLOBAL_EID][idx1])
partid_slice = id_lookup.get_partition_ids(global_dst_nids) partid_slice = id_lookup.get_partition_ids(global_dst_nids)
for idx in range(world_size):
cond = partid_slice == (idx + local_part_id * world_size) cond = partid_slice == (idx + local_part_id * world_size)
gids_per_partid = gids_feat[cond] gids_per_partid = gids_feat[cond]
tids_per_partid = tids_feat[cond] tids_per_partid = tids_feat[cond]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment