Unverified Commit 9ce800d2 authored by kylasa's avatar kylasa Committed by GitHub
Browse files

[DistDGL][Optimizations]Rehash code to optimize for loop (#5224)

* Rehash code to optimize for loop

Reduced the number of instructions executed inside the for loop that exchanges edge features, by moving the loop-invariant partition-id lookup ahead of the loop. This reduces the number of times numpy's intersect1d is invoked, saving both runtime and memory overhead.

* Applying lintrunner patch to data_shuffle.py
parent e25f47de
......@@ -390,32 +390,33 @@ def exchange_feature(
tokens = feat_key.split("/")
assert len(tokens) == 3
local_feat_key = "/".join(tokens[:-1]) + "/" + str(local_part_id)
for idx in range(world_size):
# Get the partition ids for the range of global nids.
if feat_type == constants.STR_NODE_FEATURES:
# Retrieve the partition ids for the node features.
# Each partition id will be in the range [0, num_parts).
partid_slice = id_lookup.get_partition_ids(
np.arange(gid_start, gid_end, dtype=np.int64)
)
else:
# Edge data case.
# Ownership is determined by the destination node.
assert data is not None
global_eids = np.arange(gid_start, gid_end, dtype=np.int64)
# Now use `data` to extract destination nodes' global id
# and use that to get the ownership
common, idx1, idx2 = np.intersect1d(
data[constants.GLOBAL_EID], global_eids, return_indices=True
)
assert common.shape[0] == idx2.shape[0]
# Get the partition ids for the range of global nids.
if feat_type == constants.STR_NODE_FEATURES:
# Retrieve the partition ids for the node features.
# Each partition id will be in the range [0, num_parts).
partid_slice = id_lookup.get_partition_ids(
np.arange(gid_start, gid_end, dtype=np.int64)
)
else:
# Edge data case.
# Ownership is determined by the destination node.
assert data is not None
global_eids = np.arange(gid_start, gid_end, dtype=np.int64)
# Now use `data` to extract destination nodes' global id
# and use that to get the ownership
common, idx1, idx2 = np.intersect1d(
data[constants.GLOBAL_EID], global_eids, return_indices=True
)
assert common.shape[0] == idx2.shape[0]
global_dst_nids = data[constants.GLOBAL_DST_ID][idx1]
assert np.all(global_eids == data[constants.GLOBAL_EID][idx1])
partid_slice = id_lookup.get_partition_ids(global_dst_nids)
global_dst_nids = data[constants.GLOBAL_DST_ID][idx1]
assert np.all(global_eids == data[constants.GLOBAL_EID][idx1])
partid_slice = id_lookup.get_partition_ids(global_dst_nids)
for idx in range(world_size):
cond = partid_slice == (idx + local_part_id * world_size)
gids_per_partid = gids_feat[cond]
tids_per_partid = tids_feat[cond]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment