"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "17088a689e66b5d8f678725fb1f4c7b97f0b4e77"
Unverified Commit 9ce800d2 authored by kylasa's avatar kylasa Committed by GitHub
Browse files

[DistDGL][Optimizations]Rehash code to optimize for loop (#5224)

* Rehash code to optimize for loop

Reduced number of instructions in for loop, which exchanging edge features. This will reduce the number of times numpy's intersect1d is invoked (saving the runtime and memory overhead needs of numpy).

* Applying lintrunner patch to data_shuffle.py
parent e25f47de
...@@ -390,32 +390,33 @@ def exchange_feature( ...@@ -390,32 +390,33 @@ def exchange_feature(
tokens = feat_key.split("/") tokens = feat_key.split("/")
assert len(tokens) == 3 assert len(tokens) == 3
local_feat_key = "/".join(tokens[:-1]) + "/" + str(local_part_id) local_feat_key = "/".join(tokens[:-1]) + "/" + str(local_part_id)
for idx in range(world_size): # Get the partition ids for the range of global nids.
# Get the partition ids for the range of global nids. if feat_type == constants.STR_NODE_FEATURES:
if feat_type == constants.STR_NODE_FEATURES: # Retrieve the partition ids for the node features.
# Retrieve the partition ids for the node features. # Each partition id will be in the range [0, num_parts).
# Each partition id will be in the range [0, num_parts). partid_slice = id_lookup.get_partition_ids(
partid_slice = id_lookup.get_partition_ids( np.arange(gid_start, gid_end, dtype=np.int64)
np.arange(gid_start, gid_end, dtype=np.int64) )
) else:
else: # Edge data case.
# Edge data case. # Ownership is determined by the destination node.
# Ownership is determined by the destination node. assert data is not None
assert data is not None global_eids = np.arange(gid_start, gid_end, dtype=np.int64)
global_eids = np.arange(gid_start, gid_end, dtype=np.int64)
# Now use `data` to extract destination nodes' global id
# Now use `data` to extract destination nodes' global id # and use that to get the ownership
# and use that to get the ownership common, idx1, idx2 = np.intersect1d(
common, idx1, idx2 = np.intersect1d( data[constants.GLOBAL_EID], global_eids, return_indices=True
data[constants.GLOBAL_EID], global_eids, return_indices=True )
) assert common.shape[0] == idx2.shape[0]
assert common.shape[0] == idx2.shape[0]
global_dst_nids = data[constants.GLOBAL_DST_ID][idx1] global_dst_nids = data[constants.GLOBAL_DST_ID][idx1]
assert np.all(global_eids == data[constants.GLOBAL_EID][idx1]) assert np.all(global_eids == data[constants.GLOBAL_EID][idx1])
partid_slice = id_lookup.get_partition_ids(global_dst_nids) partid_slice = id_lookup.get_partition_ids(global_dst_nids)
for idx in range(world_size):
cond = partid_slice == (idx + local_part_id * world_size) cond = partid_slice == (idx + local_part_id * world_size)
gids_per_partid = gids_feat[cond] gids_per_partid = gids_feat[cond]
tids_per_partid = tids_feat[cond] tids_per_partid = tids_feat[cond]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment