Unverified Commit 02fe316d authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[Bugfix][Hetero] fix to_hetero when metagraph is given (#873)

* fix to_hetero when metagraph is given

* minor fix

* add more check
parent f9c0217d
...@@ -272,6 +272,13 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph ...@@ -272,6 +272,13 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
the type id, which which can be used to retrieve the type names stored the type id, which which can be used to retrieve the type names stored
in the given ``ntypes`` and ``etypes`` arguments. in the given ``ntypes`` and ``etypes`` arguments.
The function will automatically distinguish edge types that have the same given
type IDs but different src and dst type IDs. For example, we allow both edges A and B
to have the same type ID 3, but one has (0, 1) and the other as (2, 3) as the
(src, dst) type IDs. In this case, the function will "split" edge type 3 into two types:
(0, ty_A, 1) and (2, ty_B, 3). In another word, these two edges share the same edge
type name, but can be distinguished by a canonical edge type tuple.
Examples Examples
-------- --------
TBD TBD
...@@ -339,27 +346,30 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph ...@@ -339,27 +346,30 @@ def to_hetero(G, ntypes, etypes, ntype_field=NTYPE, etype_field=ETYPE, metagraph
dst = F.asnumpy(dst) dst = F.asnumpy(dst)
src_local = ntype_local_ids[src] src_local = ntype_local_ids[src]
dst_local = ntype_local_ids[dst] dst_local = ntype_local_ids[dst]
srctype_ids = ntype_ids[src] # a 2D tensor of shape (E, 3). Each row represents the (stid, etid, dtid) tuple.
dsttype_ids = ntype_ids[dst] edge_ctids = np.stack([ntype_ids[src], etype_ids, ntype_ids[dst]], 1)
canon_etype_ids = np.stack([srctype_ids, etype_ids, dsttype_ids], 1)
# infer metagraph and canonical edge types
# infer metagraph # No matter which branch it takes, the code will generate a 2D tensor of shape (E_m, 3),
# E_m is the set of all possible canonical edge tuples. Each row represents the
# (stid, dtid, dtid) tuple. We then compute a 2D tensor of shape (E, E_m) using the
# above ``edge_ctids`` matrix. Each element i,j indicates whether the edge i is of the
# canonical edge type j. We can then group the edges of the same type together.
if metagraph is None: if metagraph is None:
canonical_etids, _, etype_remapped = \ canonical_etids, _, etype_remapped = \
utils.make_invmap(list(tuple(_) for _ in canon_etype_ids), False) utils.make_invmap(list(tuple(_) for _ in edge_ctids), False)
etype_mask = (etype_remapped[None, :] == np.arange(len(canonical_etids))[:, None]) etype_mask = (etype_remapped[None, :] == np.arange(len(canonical_etids))[:, None])
else: else:
ntypes_invmap = {nt: i for i, nt in enumerate(ntypes)} ntypes_invmap = {nt: i for i, nt in enumerate(ntypes)}
etypes_invmap = {et: i for i, et in enumerate(etypes)} etypes_invmap = {et: i for i, et in enumerate(etypes)}
canonical_etids = [] canonical_etids = []
etype_remapped = np.zeros(etype_ids)
for i, (srctype, dsttype, etype) in enumerate(metagraph.edges(keys=True)): for i, (srctype, dsttype, etype) in enumerate(metagraph.edges(keys=True)):
srctype_id = ntypes_invmap[srctype] srctype_id = ntypes_invmap[srctype]
etype_id = etypes_invmap[etype] etype_id = etypes_invmap[etype]
dsttype_id = ntypes_invmap[dsttype] dsttype_id = ntypes_invmap[dsttype]
canonical_etids.append((srctype_id, etype_id, dsttype_id)) canonical_etids.append((srctype_id, etype_id, dsttype_id))
canonical_etids = np.array(canonical_etids) canonical_etids = np.array(canonical_etids)
etype_mask = (canon_etype_ids[None, :] == canonical_etids[:, None]).all(2) etype_mask = (edge_ctids[None, :] == canonical_etids[:, None]).all(2)
edge_groups = [etype_mask[i].nonzero()[0] for i in range(len(canonical_etids))] edge_groups = [etype_mask[i].nonzero()[0] for i in range(len(canonical_etids))]
rel_graphs = [] rel_graphs = []
......
...@@ -60,6 +60,9 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -60,6 +60,9 @@ class UnitGraph::COO : public BaseHeteroGraph {
public: public:
COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src, IdArray dst) COO(GraphPtr metagraph, int64_t num_src, int64_t num_dst, IdArray src, IdArray dst)
: BaseHeteroGraph(metagraph) { : BaseHeteroGraph(metagraph) {
CHECK(aten::IsValidIdArray(src));
CHECK(aten::IsValidIdArray(dst));
CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length.";
adj_ = aten::COOMatrix{num_src, num_dst, src, dst}; adj_ = aten::COOMatrix{num_src, num_dst, src, dst};
} }
...@@ -67,6 +70,9 @@ class UnitGraph::COO : public BaseHeteroGraph { ...@@ -67,6 +70,9 @@ class UnitGraph::COO : public BaseHeteroGraph {
IdArray src, IdArray dst, bool is_multigraph) IdArray src, IdArray dst, bool is_multigraph)
: BaseHeteroGraph(metagraph), : BaseHeteroGraph(metagraph),
is_multigraph_(is_multigraph) { is_multigraph_(is_multigraph) {
CHECK(aten::IsValidIdArray(src));
CHECK(aten::IsValidIdArray(dst));
CHECK_EQ(src->shape[0], dst->shape[0]) << "Input arrays should have the same length.";
adj_ = aten::COOMatrix{num_src, num_dst, src, dst}; adj_ = aten::COOMatrix{num_src, num_dst, src, dst};
} }
...@@ -326,12 +332,22 @@ class UnitGraph::CSR : public BaseHeteroGraph { ...@@ -326,12 +332,22 @@ class UnitGraph::CSR : public BaseHeteroGraph {
CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst, CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids) IdArray indptr, IdArray indices, IdArray edge_ids)
: BaseHeteroGraph(metagraph) { : BaseHeteroGraph(metagraph) {
CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices));
CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0])
<< "indices and edge id arrays should have the same length";
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids}; adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
} }
CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst, CSR(GraphPtr metagraph, int64_t num_src, int64_t num_dst,
IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph) IdArray indptr, IdArray indices, IdArray edge_ids, bool is_multigraph)
: BaseHeteroGraph(metagraph), is_multigraph_(is_multigraph) { : BaseHeteroGraph(metagraph), is_multigraph_(is_multigraph) {
CHECK(aten::IsValidIdArray(indptr));
CHECK(aten::IsValidIdArray(indices));
CHECK(aten::IsValidIdArray(edge_ids));
CHECK_EQ(indices->shape[0], edge_ids->shape[0])
<< "indices and edge id arrays should have the same length";
adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids}; adj_ = aten::CSRMatrix{num_src, num_dst, indptr, indices, edge_ids};
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment