Commit 86f28d65 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by Zihao Ye
Browse files

[Misc] Provide a "return 1s instead of edge IDs" option in scipy adjacency matrix (#730)

* return 1 option in scipy adjacency matrix

* lint

* use dgl warning

* i'm an idiot

* lint x2

* rename
parent 7ad663c3
...@@ -13,8 +13,8 @@ def is_all(arg): ...@@ -13,8 +13,8 @@ def is_all(arg):
"""Return true if the argument is a special symbol for all nodes or edges.""" """Return true if the argument is a special symbol for all nodes or edges."""
return isinstance(arg, str) and arg == ALL return isinstance(arg, str) and arg == ALL
def dgl_warning(msg): def dgl_warning(msg, warn_type=UserWarning):
"""Print out warning messages.""" """Print out warning messages."""
warnings.warn(msg) warnings.warn(msg, warn_type)
_init_internal_api() _init_internal_api()
...@@ -3040,7 +3040,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -3040,7 +3040,7 @@ class DGLGraph(DGLBaseGraph):
sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=preserve_nodes) sgi = self._graph.edge_subgraph(induced_edges, preserve_nodes=preserve_nodes)
return subgraph.DGLSubGraph(self, sgi) return subgraph.DGLSubGraph(self, sgi)
def adjacency_matrix_scipy(self, transpose=False, fmt='csr'): def adjacency_matrix_scipy(self, transpose=False, fmt='csr', return_edge_ids=None):
"""Return the scipy adjacency matrix representation of this graph. """Return the scipy adjacency matrix representation of this graph.
By default, a row of returned adjacency matrix represents the destination By default, a row of returned adjacency matrix represents the destination
...@@ -3049,14 +3049,16 @@ class DGLGraph(DGLBaseGraph): ...@@ -3049,14 +3049,16 @@ class DGLGraph(DGLBaseGraph):
When transpose is True, a row represents the source and a column represents When transpose is True, a row represents the source and a column represents
a destination. a destination.
The elements in the adajency matrix are edge ids.
Parameters Parameters
---------- ----------
transpose : bool, optional (default=False) transpose : bool, optional (default=False)
A flag to transpose the returned adjacency matrix. A flag to transpose the returned adjacency matrix.
fmt : str, optional (default='csr') fmt : str, optional (default='csr')
Indicates the format of returned adjacency matrix. Indicates the format of returned adjacency matrix.
return_edge_ids : bool, optional (default=True)
If True, the elements in the adjacency matrix are edge ids.
Note that one of the element is 0. Proceed with caution.
If False, the elements will be always 1.
Returns Returns
------- -------
...@@ -3064,7 +3066,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -3064,7 +3066,7 @@ class DGLGraph(DGLBaseGraph):
The scipy representation of adjacency matrix. The scipy representation of adjacency matrix.
""" """
return self._graph.adjacency_matrix_scipy(transpose, fmt) return self._graph.adjacency_matrix_scipy(transpose, fmt, return_edge_ids)
def adjacency_matrix(self, transpose=False, ctx=F.cpu()): def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
"""Return the adjacency matrix representation of this graph. """Return the adjacency matrix representation of this graph.
......
...@@ -7,7 +7,7 @@ import scipy ...@@ -7,7 +7,7 @@ import scipy
from ._ffi.object import register_object, ObjectBase from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .base import DGLError from .base import DGLError, dgl_warning
from . import backend as F from . import backend as F
from . import utils from . import utils
...@@ -577,7 +577,7 @@ class GraphIndex(ObjectBase): ...@@ -577,7 +577,7 @@ class GraphIndex(ObjectBase):
return SubgraphIndex(gidx, self, induced_nodes, e) return SubgraphIndex(gidx, self, induced_nodes, e)
@utils.cached_member(cache='_cache', prefix='scipy_adj') @utils.cached_member(cache='_cache', prefix='scipy_adj')
def adjacency_matrix_scipy(self, transpose, fmt): def adjacency_matrix_scipy(self, transpose, fmt, return_edge_ids=None):
"""Return the scipy adjacency matrix representation of this graph. """Return the scipy adjacency matrix representation of this graph.
By default, a row of returned adjacency matrix represents the destination By default, a row of returned adjacency matrix represents the destination
...@@ -586,14 +586,14 @@ class GraphIndex(ObjectBase): ...@@ -586,14 +586,14 @@ class GraphIndex(ObjectBase):
When transpose is True, a row represents the source and a column represents When transpose is True, a row represents the source and a column represents
a destination. a destination.
The elements in the adajency matrix are edge ids.
Parameters Parameters
---------- ----------
transpose : bool transpose : bool
A flag to transpose the returned adjacency matrix. A flag to transpose the returned adjacency matrix.
fmt : str fmt : str
Indicates the format of returned adjacency matrix. Indicates the format of returned adjacency matrix.
return_edge_ids : bool
Indicates whether to return edge IDs or 1 as elements.
Returns Returns
------- -------
...@@ -603,20 +603,30 @@ class GraphIndex(ObjectBase): ...@@ -603,20 +603,30 @@ class GraphIndex(ObjectBase):
if not isinstance(transpose, bool): if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,' raise DGLError('Expect bool value for "transpose" arg,'
' but got %s.' % (type(transpose))) ' but got %s.' % (type(transpose)))
if return_edge_ids is None:
dgl_warning(
"Adjacency matrix by default currently returns edge IDs."
" As a result there is one 0 entry which is not eliminated."
" In the next release it will return 1s by default,"
" and 0 will be eliminated otherwise.",
FutureWarning)
return_edge_ids = True
rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt) rst = _CAPI_DGLGraphGetAdj(self, transpose, fmt)
if fmt == "csr": if fmt == "csr":
indptr = utils.toindex(rst(0)).tonumpy() indptr = utils.toindex(rst(0)).tonumpy()
indices = utils.toindex(rst(1)).tonumpy() indices = utils.toindex(rst(1)).tonumpy()
shuffle = utils.toindex(rst(2)).tonumpy() data = utils.toindex(rst(2)).tonumpy() if return_edge_ids else np.ones_like(indices)
n = self.number_of_nodes() n = self.number_of_nodes()
return scipy.sparse.csr_matrix((shuffle, indices, indptr), shape=(n, n)) return scipy.sparse.csr_matrix((data, indices, indptr), shape=(n, n))
elif fmt == 'coo': elif fmt == 'coo':
idx = utils.toindex(rst(0)).tonumpy() idx = utils.toindex(rst(0)).tonumpy()
n = self.number_of_nodes() n = self.number_of_nodes()
m = self.number_of_edges() m = self.number_of_edges()
row, col = np.reshape(idx, (2, m)) row, col = np.reshape(idx, (2, m))
shuffle = np.arange(0, m) data = np.arange(0, m) if return_edge_ids else np.ones_like(row)
return scipy.sparse.coo_matrix((shuffle, (row, col)), shape=(n, n)) return scipy.sparse.coo_matrix((data, (row, col)), shape=(n, n))
else: else:
raise Exception("unknown format") raise Exception("unknown format")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment