"docs/vscode:/vscode.git/clone" did not exist on "b8cf84a3f902550937255c5b28b39827ba52beb6"
Unverified commit 5f0b61b8, authored by Rhett Ying and committed by GitHub
Browse files

[Dist] enable access DistGraph.edges via canonical etype (#4814)

* [Dist] enable access DistGraph.edges via canonical etype

* refine code

* refine test

* refine code
parent 889798fe
...@@ -157,14 +157,16 @@ class HeteroNodeView(object): ...@@ -157,14 +157,16 @@ class HeteroNodeView(object):
return NodeSpace(data=NodeDataView(self._graph, key)) return NodeSpace(data=NodeDataView(self._graph, key))
class HeteroEdgeView(object): class HeteroEdgeView(object):
"""A NodeView class to act as G.nodes for a DistGraph.""" """An EdgeView class to act as G.edges for a DistGraph."""
__slots__ = ['_graph'] __slots__ = ['_graph']
def __init__(self, graph): def __init__(self, graph):
self._graph = graph self._graph = graph
def __getitem__(self, key): def __getitem__(self, key):
assert isinstance(key, str) assert isinstance(key, str) or (
isinstance(key, tuple) and len(key) == 3
), f"Expect edge type in string or triplet of string, but got {key}."
return EdgeSpace(data=EdgeDataView(self._graph, key)) return EdgeSpace(data=EdgeDataView(self._graph, key))
class NodeDataView(MutableMapping): class NodeDataView(MutableMapping):
......
...@@ -549,12 +549,10 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb): ...@@ -549,12 +549,10 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb):
data_dict = dict() data_dict = dict()
edge_ids = {} edge_ids = {}
for etid in range(len(g.etypes)): for etid, etype in enumerate(g.canonical_etypes):
etype = g.etypes[etid]
canonical_etype = g.canonical_etypes[etid]
type_idx = etype_ids == etid type_idx = etype_ids == etid
if F.sum(type_idx, 0) > 0: if F.sum(type_idx, 0) > 0:
data_dict[canonical_etype] = ( data_dict[etype] = (
F.boolean_mask(src, type_idx), F.boolean_mask(src, type_idx),
F.boolean_mask(dst, type_idx), F.boolean_mask(dst, type_idx),
) )
...@@ -638,9 +636,19 @@ def sample_etype_neighbors( ...@@ -638,9 +636,19 @@ def sample_etype_neighbors(
A sampled subgraph containing only the sampled neighboring edges. It is on CPU. A sampled subgraph containing only the sampled neighboring edges. It is on CPU.
""" """
if isinstance(fanout, int): if isinstance(fanout, int):
fanout = F.full_1d(len(g.etypes), fanout, F.int64, F.cpu()) fanout = F.full_1d(len(g.canonical_etypes), fanout, F.int64, F.cpu())
else: else:
fanout = F.tensor([fanout[etype] for etype in g.etypes], dtype=F.int64) etype_ids = {etype: i for i, etype in enumerate(g.canonical_etypes)}
fanout_array = [None] * len(g.canonical_etypes)
for etype, v in fanout.items():
c_etype = g.to_canonical_etype(etype)
fanout_array[etype_ids[c_etype]] = v
assert all(v is not None for v in fanout_array), (
"Not all etypes have valid fanout. Please make sure passed-in "
"fanout in dict includes all the etypes in graph. Passed-in "
f"fanout: {fanout}, graph etypes: {g.canonical_etypes}."
)
fanout = F.tensor(fanout_array, dtype=F.int64)
gpb = g.get_partition_book() gpb = g.get_partition_book()
if isinstance(nodes, dict): if isinstance(nodes, dict):
...@@ -667,7 +675,7 @@ def sample_etype_neighbors( ...@@ -667,7 +675,7 @@ def sample_etype_neighbors(
g.edges[etype].data[prob].kvstore_key g.edges[etype].data[prob].kvstore_key
if prob in g.edges[etype].data if prob in g.edges[etype].data
else "" else ""
for etype in g.etypes for etype in g.canonical_etypes
] ]
else: else:
_prob = None _prob = None
...@@ -690,7 +698,7 @@ def sample_etype_neighbors( ...@@ -690,7 +698,7 @@ def sample_etype_neighbors(
g.edges[etype].data[prob].local_partition g.edges[etype].data[prob].local_partition
if prob in g.edges[etype].data if prob in g.edges[etype].data
else None else None
for etype in g.etypes for etype in g.canonical_etypes
] ]
return _sample_etype_neighbors( return _sample_etype_neighbors(
local_g, local_g,
......
...@@ -730,8 +730,14 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges): ...@@ -730,8 +730,14 @@ def check_dist_graph_hetero(g, num_clients, num_nodes, num_edges):
# Test reading edge data # Test reading edge data
eids = F.arange(0, int(g.number_of_edges("r1") / 2)) eids = F.arange(0, int(g.number_of_edges("r1") / 2))
feats1 = g.edges["r1"].data["feat"][eids] # access via etype
feats = F.squeeze(feats1, 1) feats = g.edges["r1"].data["feat"][eids]
feats = F.squeeze(feats, 1)
assert np.all(F.asnumpy(feats == eids))
# access via canonical etype
c_etype = g.to_canonical_etype("r1")
feats = g.edges[c_etype].data["feat"][eids]
feats = F.squeeze(feats, 1)
assert np.all(F.asnumpy(feats == eids)) assert np.all(F.asnumpy(feats == eids))
# Test edge_subgraph # Test edge_subgraph
......
...@@ -470,7 +470,7 @@ def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server, graph_formats=No ...@@ -470,7 +470,7 @@ def check_rpc_hetero_etype_sampling_shuffle(tmpdir, num_server, graph_formats=No
time.sleep(1) time.sleep(1)
pserver_list.append(p) pserver_list.append(p)
fanout = 3 fanout = {etype: 3 for etype in g.canonical_etypes}
etype_sorted = False etype_sorted = False
if graph_formats is not None: if graph_formats is not None:
etype_sorted = 'csc' in graph_formats or 'csr' in graph_formats etype_sorted = 'csc' in graph_formats or 'csr' in graph_formats
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment