"...en/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "078df46bc9a99178a9a744b872899990353769a4"
Unverified Commit ab1b2811 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Dist] etype is not guaranteed to be sorted (#4156)

parent 4d3c01d6
...@@ -70,7 +70,7 @@ def _sample_neighbors(local_g, partition_book, seed_nodes, fan_out, edge_dir, pr ...@@ -70,7 +70,7 @@ def _sample_neighbors(local_g, partition_book, seed_nodes, fan_out, edge_dir, pr
return global_src, global_dst, global_eids return global_src, global_dst, global_eids
def _sample_etype_neighbors(local_g, partition_book, seed_nodes, etype_field, def _sample_etype_neighbors(local_g, partition_book, seed_nodes, etype_field,
fan_out, edge_dir, prob, replace): fan_out, edge_dir, prob, replace, etype_sorted=False):
""" Sample from local partition. """ Sample from local partition.
The input nodes use global IDs. We need to map the global node IDs to local node IDs, The input nodes use global IDs. We need to map the global node IDs to local node IDs,
...@@ -80,13 +80,10 @@ def _sample_etype_neighbors(local_g, partition_book, seed_nodes, etype_field, ...@@ -80,13 +80,10 @@ def _sample_etype_neighbors(local_g, partition_book, seed_nodes, etype_field,
""" """
local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid) local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
local_ids = F.astype(local_ids, local_g.idtype) local_ids = F.astype(local_ids, local_g.idtype)
# local_ids = self.seed_nodes
# DistGraph's edges are sorted by default according to
# graph partition mechanism.
sampled_graph = local_sample_etype_neighbors( sampled_graph = local_sample_etype_neighbors(
local_g, local_ids, etype_field, fan_out, edge_dir, prob, replace, local_g, local_ids, etype_field, fan_out, edge_dir, prob, replace,
etype_sorted=True, _dist_training=True) etype_sorted=etype_sorted, _dist_training=True)
global_nid_mapping = local_g.ndata[NID] global_nid_mapping = local_g.ndata[NID]
src, dst = sampled_graph.edges() src, dst = sampled_graph.edges()
global_src, global_dst = F.gather_row(global_nid_mapping, src), \ global_src, global_dst = F.gather_row(global_nid_mapping, src), \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment