Commit baf5906b authored by GaiYu0's avatar GaiYu0
Browse files

(non)-backtracking line graph

parent 5119a504
...@@ -1247,7 +1247,7 @@ class DGLGraph(object): ...@@ -1247,7 +1247,7 @@ class DGLGraph(object):
""" """
return self._graph.adjacency_matrix() return self._graph.adjacency_matrix()
def incidence_matrix(self, oriented=False): def incidence_matrix(self, oriented=False, sorted=False):
"""Return the incidence matrix representation of this graph. """Return the incidence matrix representation of this graph.
Returns Returns
...@@ -1255,9 +1255,9 @@ class DGLGraph(object): ...@@ -1255,9 +1255,9 @@ class DGLGraph(object):
utils.CtxCachedObject utils.CtxCachedObject
An object that returns tensor given context. An object that returns tensor given context.
""" """
return self._graph.incidence_matrix(oriented) return self._graph.incidence_matrix(oriented, sorted)
def line_graph(self): def line_graph(self, backtracking=True, sorted=False):
"""Return the line graph of this graph. """Return the line graph of this graph.
Returns Returns
...@@ -1265,9 +1265,9 @@ class DGLGraph(object): ...@@ -1265,9 +1265,9 @@ class DGLGraph(object):
DGLGraph DGLGraph
The line graph of this graph. The line graph of this graph.
""" """
return DGLGraph(self._graph.line_graph()) return DGLGraph(self._graph.line_graph(backtracking, sorted))
def _line_graph(self, backtracking=False): def _line_graph(self, backtracking=True, sorted=False):
"""Return the line graph of this graph. """Return the line graph of this graph.
Returns Returns
...@@ -1275,7 +1275,7 @@ class DGLGraph(object): ...@@ -1275,7 +1275,7 @@ class DGLGraph(object):
DGLGraph DGLGraph
The line graph of this graph. The line graph of this graph.
""" """
return DGLGraph(self._graph._line_graph(backtracking)) return DGLGraph(self._graph._line_graph(backtracking, sorted))
def _get_repr(attr_dict): def _get_repr(attr_dict):
if len(attr_dict) == 1 and __REPR__ in attr_dict: if len(attr_dict) == 1 and __REPR__ in attr_dict:
......
...@@ -408,17 +408,26 @@ class GraphIndex(object): ...@@ -408,17 +408,26 @@ class GraphIndex(object):
self._cache['adj'] = utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx)) self._cache['adj'] = utils.CtxCachedObject(lambda ctx: F.to_context(mat, ctx))
return self._cache['adj'] return self._cache['adj']
def incidence_matrix(self, oriented=False): def incidence_matrix(self, oriented=False, sorted=False):
"""Return the incidence matrix representation of this graph. """Return the incidence matrix representation of this graph.
Parameters
----------
oriented : bool, optional (default=False)
Whether the returned incidence matrix is oriented.
sorted : bool, optional (default=False)
If true, nodes in L(G) are sorted as pairs.
If False, nodes in L(G) are ordered by their edge id's in G.
Returns Returns
------- -------
utils.CtxCachedObject utils.CtxCachedObject
An object that returns tensor given context. An object that returns tensor given context.
""" """
key = ('oriented ' if oriented else '') + 'incidence matrix' key = ('oriented ' if oriented else '') + \
('sorted ' if sorted else '') + 'incidence matrix'
if not key in self._cache: if not key in self._cache:
src, dst, _ = self.edges(sorted=True) src, dst, _ = self.edges(sorted=sorted)
src = src.tousertensor() src = src.tousertensor()
dst = dst.tousertensor() dst = dst.tousertensor()
m = self.number_of_edges() m = self.number_of_edges()
...@@ -501,7 +510,7 @@ class GraphIndex(object): ...@@ -501,7 +510,7 @@ class GraphIndex(object):
Parameters Parameters
---------- ----------
adj : adj : scipy sparse matrix
""" """
self.clear() self.clear()
self.add_nodes(adj.shape[0]) self.add_nodes(adj.shape[0])
...@@ -510,9 +519,18 @@ class GraphIndex(object): ...@@ -510,9 +519,18 @@ class GraphIndex(object):
dst = utils.toindex(adj_coo.col) dst = utils.toindex(adj_coo.col)
self.add_edges(src, dst) self.add_edges(src, dst)
def line_graph(self): def line_graph(self, backtracking=True, sorted=False):
"""Return the line graph of this graph. """Return the line graph of this graph.
Parameters
----------
backtracking : bool, optional (default=False)
Whether (i, j) ~ (j, i) in L(G).
(i, j) ~ (j, i) is the behavior of networkx.line_graph.
sorted : bool, optional (default=False)
If true, nodes in L(G) are sorted as pairs.
If False, nodes in L(G) are ordered by their edge id's in G.
Returns Returns
------- -------
GraphIndex GraphIndex
...@@ -520,12 +538,15 @@ class GraphIndex(object): ...@@ -520,12 +538,15 @@ class GraphIndex(object):
""" """
m = self.number_of_edges() m = self.number_of_edges()
ctx = F.get_context(F.ones(1)) ctx = F.get_context(F.ones(1))
inc = F.to_scipy_sparse(self.incidence_matrix(oriented=True).get(ctx)) inc = F.to_scipy_sparse(self.incidence_matrix(True, sorted).get(ctx))
adj = inc.transpose().dot(inc).tocoo() adj = inc.transpose().dot(inc).tocoo()
if backtracking:
adj.data[adj.data >= 0] = 0
else:
adj.data[adj.data != -1] = 0 adj.data[adj.data != -1] = 0
adj.eliminate_zeros() adj.eliminate_zeros()
u, v, _ = self.edges(sorted=True) # TODO(gaiyu): sorted u, v, _ = self.edges(sorted=sorted)
u = u.tousertensor() u = u.tousertensor()
v = v.tousertensor() v = v.tousertensor()
src = F.gather_row(v, F.tensor(adj.row, dtype=F.int64)) src = F.gather_row(v, F.tensor(adj.row, dtype=F.int64))
...@@ -539,7 +560,7 @@ class GraphIndex(object): ...@@ -539,7 +560,7 @@ class GraphIndex(object):
lg.from_scipy_sparse_matrix(adj) lg.from_scipy_sparse_matrix(adj)
return lg return lg
def _line_graph(self, backtracking): def _line_graph(self, backtracking=True, sorted=False):
handle = _CAPI_DGLGraphLineGraph(self._handle, backtracking) handle = _CAPI_DGLGraphLineGraph(self._handle, backtracking)
return GraphIndex(handle) return GraphIndex(handle)
......
...@@ -25,7 +25,7 @@ Graph GraphOp::LineGraph(const Graph* g, bool backtracking){ ...@@ -25,7 +25,7 @@ Graph GraphOp::LineGraph(const Graph* g, bool backtracking){
auto j = adj.find(v); auto j = adj.find(v);
if (j != adj.end()) { if (j != adj.end()) {
for (size_t k = 0; k != j->second.size(); ++k) { for (size_t k = 0; k != j->second.size(); ++k) {
if (j->second[k].first != u) { if (backtracking || (!backtracking && j->second[k].first != u)) {
lg_src.push_back(i); lg_src.push_back(i);
lg_dst.push_back(j->second[k].second); lg_dst.push_back(j->second[k].second);
} }
......
...@@ -10,8 +10,9 @@ b = sp.sparse.triu(a, 1) + sp.sparse.triu(a, 1).transpose() ...@@ -10,8 +10,9 @@ b = sp.sparse.triu(a, 1) + sp.sparse.triu(a, 1).transpose()
g = dgl.DGLGraph() g = dgl.DGLGraph()
g.from_scipy_sparse_matrix(b) g.from_scipy_sparse_matrix(b)
lg_sparse = g.line_graph() backtracking = True
lg_cpp = g._line_graph() lg_sparse = g.line_graph(backtracking)
lg_cpp = g._line_graph(backtracking)
assert lg_sparse.number_of_nodes() == lg_cpp.number_of_nodes() assert lg_sparse.number_of_nodes() == lg_cpp.number_of_nodes()
assert lg_sparse.number_of_edges() == lg_cpp.number_of_edges() assert lg_sparse.number_of_edges() == lg_cpp.number_of_edges()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment