"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "ad0daff1bbc8b148a6d96df6c9b1d9b9c1b6adad"
Unverified Commit 6f36dd63 authored by xiang song(charlie.song)'s avatar xiang song(charlie.song) Committed by GitHub
Browse files

Cleanup some code in graph_services.py (#3238)

* Fix bug

* Fix

* Fix

* upd

* Merge some code

* lint
parent 1b350e93
...@@ -382,6 +382,34 @@ def _distributed_access(g, nodes, issue_remote_req, local_access): ...@@ -382,6 +382,34 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
sampled_graph = merge_graphs(res_list, g.number_of_nodes()) sampled_graph = merge_graphs(res_list, g.number_of_nodes())
return sampled_graph return sampled_graph
def _frontier_to_heterogeneous_graph(g, frontier, gpb):
etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID])
src, dst = frontier.edges()
etype_ids, idx = F.sort_1d(etype_ids)
src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)
eid = F.gather_row(frontier.edata[EID], idx)
assert len(eid) > 0
_, src = gpb.map_to_per_ntype(src)
_, dst = gpb.map_to_per_ntype(dst)
data_dict = dict()
edge_ids = {}
for etid in range(len(g.etypes)):
etype = g.etypes[etid]
canonical_etype = g.canonical_etypes[etid]
type_idx = etype_ids == etid
if F.sum(type_idx, 0) > 0:
data_dict[canonical_etype] = (F.boolean_mask(src, type_idx), \
F.boolean_mask(dst, type_idx))
edge_ids[etype] = F.boolean_mask(eid, type_idx)
hg = heterograph(data_dict,
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
idtype=g.idtype)
for etype in edge_ids:
hg.edges[etype].data[EID] = edge_ids[etype]
return hg
def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=None, replace=False): def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=None, replace=False):
"""Sample from the neighbors of the given nodes from a distributed graph. """Sample from the neighbors of the given nodes from a distributed graph.
...@@ -460,31 +488,7 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No ...@@ -460,31 +488,7 @@ def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=No
etype_field, fanout, edge_dir, prob, replace) etype_field, fanout, edge_dir, prob, replace)
frontier = _distributed_access(g, nodes, issue_remote_req, local_access) frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if len(gpb.etypes) > 1: if len(gpb.etypes) > 1:
etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID]) return _frontier_to_heterogeneous_graph(g, frontier, gpb)
src, dst = frontier.edges()
etype_ids, idx = F.sort_1d(etype_ids)
src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)
eid = F.gather_row(frontier.edata[EID], idx)
_, src = gpb.map_to_per_ntype(src)
_, dst = gpb.map_to_per_ntype(dst)
data_dict = dict()
edge_ids = {}
for etid in range(len(g.etypes)):
etype = g.etypes[etid]
canonical_etype = g.canonical_etypes[etid]
type_idx = etype_ids == etid
if F.sum(type_idx, 0) > 0:
data_dict[canonical_etype] = (F.boolean_mask(src, type_idx), \
F.boolean_mask(dst, type_idx))
edge_ids[etype] = F.boolean_mask(eid, type_idx)
hg = heterograph(data_dict,
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
idtype=g.idtype)
for etype in edge_ids:
hg.edges[etype].data[EID] = edge_ids[etype]
return hg
else: else:
return frontier return frontier
...@@ -561,32 +565,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False): ...@@ -561,32 +565,7 @@ def sample_neighbors(g, nodes, fanout, edge_dir='in', prob=None, replace=False):
fanout, edge_dir, prob, replace) fanout, edge_dir, prob, replace)
frontier = _distributed_access(g, nodes, issue_remote_req, local_access) frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if len(gpb.etypes) > 1: if len(gpb.etypes) > 1:
etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID]) return _frontier_to_heterogeneous_graph(g, frontier, gpb)
src, dst = frontier.edges()
etype_ids, idx = F.sort_1d(etype_ids)
src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)
eid = F.gather_row(frontier.edata[EID], idx)
assert len(eid) > 0
_, src = gpb.map_to_per_ntype(src)
_, dst = gpb.map_to_per_ntype(dst)
data_dict = dict()
edge_ids = {}
for etid in range(len(g.etypes)):
etype = g.etypes[etid]
canonical_etype = g.canonical_etypes[etid]
type_idx = etype_ids == etid
if F.sum(type_idx, 0) > 0:
data_dict[canonical_etype] = (F.boolean_mask(src, type_idx), \
F.boolean_mask(dst, type_idx))
edge_ids[etype] = F.boolean_mask(eid, type_idx)
hg = heterograph(data_dict,
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
idtype=g.idtype)
for etype in edge_ids:
hg.edges[etype].data[EID] = edge_ids[etype]
return hg
else: else:
return frontier return frontier
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment