"tests/python/common/vscode:/vscode.git/clone" did not exist on "690f37bbe96a301cec8709a55a9d5e716f515683"
Unverified Commit 10253a5c authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

[BUGFIX] Fix is_multigraph in the construction from scipy coo matrix (#1357)

* fix is_multigraph in from_coo.

* add tests for partition.

* fix.

* Revert "add tests for partition."

This reverts commit cb8c8555da3e0c70a482c2d639adce2943475bfc.

* fix everywhere from_scipy_sparse_matrix is used.
parent 856de790
...@@ -1805,7 +1805,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -1805,7 +1805,7 @@ class DGLGraph(DGLBaseGraph):
raise DGLError('Not all edges have attribute {}.'.format(attr)) raise DGLError('Not all edges have attribute {}.'.format(attr))
self._edge_frame[attr] = _batcher(attr_dict[attr]) self._edge_frame[attr] = _batcher(attr_dict[attr])
def from_scipy_sparse_matrix(self, spmat): def from_scipy_sparse_matrix(self, spmat, multigraph=False):
""" Convert from scipy sparse matrix. """ Convert from scipy sparse matrix.
Parameters Parameters
...@@ -1813,6 +1813,10 @@ class DGLGraph(DGLBaseGraph): ...@@ -1813,6 +1813,10 @@ class DGLGraph(DGLBaseGraph):
spmat : scipy sparse matrix spmat : scipy sparse matrix
The graph's adjacency matrix The graph's adjacency matrix
multigraph : bool, optional
Whether the graph would be a multigraph. If the input scipy sparse matrix is CSR,
this argument is ignored.
Examples Examples
-------- --------
>>> from scipy.sparse import coo_matrix >>> from scipy.sparse import coo_matrix
...@@ -1824,7 +1828,7 @@ class DGLGraph(DGLBaseGraph): ...@@ -1824,7 +1828,7 @@ class DGLGraph(DGLBaseGraph):
>>> g.from_scipy_sparse_matrix(a) >>> g.from_scipy_sparse_matrix(a)
""" """
self.clear() self.clear()
self._graph = graph_index.from_scipy_sparse_matrix(spmat, self.is_readonly) self._graph = graph_index.from_scipy_sparse_matrix(spmat, multigraph, self.is_readonly)
self._node_frame.add_rows(self.number_of_nodes()) self._node_frame.add_rows(self.number_of_nodes())
self._edge_frame.add_rows(self.number_of_edges()) self._edge_frame.add_rows(self.number_of_edges())
self._msg_frame.add_rows(self.number_of_edges()) self._msg_frame.add_rows(self.number_of_edges())
......
...@@ -1139,12 +1139,15 @@ def from_networkx(nx_graph, readonly): ...@@ -1139,12 +1139,15 @@ def from_networkx(nx_graph, readonly):
dst = utils.toindex(dst) dst = utils.toindex(dst)
return from_coo(num_nodes, src, dst, is_multigraph, readonly) return from_coo(num_nodes, src, dst, is_multigraph, readonly)
def from_scipy_sparse_matrix(adj, readonly): def from_scipy_sparse_matrix(adj, multigraph, readonly):
"""Convert from scipy sparse matrix. """Convert from scipy sparse matrix.
Parameters Parameters
---------- ----------
adj : scipy sparse matrix adj : scipy sparse matrix
multigraph : bool
Whether the graph would be a multigraph. If none, the flag will be determined
by the data.
readonly : bool readonly : bool
True if the returned graph is readonly. True if the returned graph is readonly.
...@@ -1156,8 +1159,9 @@ def from_scipy_sparse_matrix(adj, readonly): ...@@ -1156,8 +1159,9 @@ def from_scipy_sparse_matrix(adj, readonly):
if adj.getformat() != 'csr' or not readonly: if adj.getformat() != 'csr' or not readonly:
num_nodes = max(adj.shape[0], adj.shape[1]) num_nodes = max(adj.shape[0], adj.shape[1])
adj_coo = adj.tocoo() adj_coo = adj.tocoo()
return from_coo(num_nodes, adj_coo.row, adj_coo.col, False, readonly) return from_coo(num_nodes, adj_coo.row, adj_coo.col, multigraph, readonly)
else: else:
# If the input matrix is csr, it's guaranteed to be a simple graph.
return from_csr(adj.indptr, adj.indices, False, "out") return from_csr(adj.indptr, adj.indices, False, "out")
def from_edge_list(elist, is_multigraph, readonly): def from_edge_list(elist, is_multigraph, readonly):
...@@ -1298,7 +1302,7 @@ def create_graph_index(graph_data, multigraph, readonly): ...@@ -1298,7 +1302,7 @@ def create_graph_index(graph_data, multigraph, readonly):
return from_edge_list(graph_data, multigraph, readonly) return from_edge_list(graph_data, multigraph, readonly)
elif isinstance(graph_data, scipy.sparse.spmatrix): elif isinstance(graph_data, scipy.sparse.spmatrix):
# scipy format # scipy format
return from_scipy_sparse_matrix(graph_data, readonly) return from_scipy_sparse_matrix(graph_data, multigraph, readonly)
else: else:
# networkx - any format # networkx - any format
try: try:
......
...@@ -202,7 +202,8 @@ def create_large_graph_index(num_nodes): ...@@ -202,7 +202,8 @@ def create_large_graph_index(num_nodes):
row = np.random.choice(num_nodes, num_nodes * 10) row = np.random.choice(num_nodes, num_nodes * 10)
col = np.random.choice(num_nodes, num_nodes * 10) col = np.random.choice(num_nodes, num_nodes * 10)
spm = spsp.coo_matrix((np.ones(len(row)), (row, col))) spm = spsp.coo_matrix((np.ones(len(row)), (row, col)))
return from_scipy_sparse_matrix(spm, True) # It's possible that we generate a multigraph.
return from_scipy_sparse_matrix(spm, True, True)
def get_nodeflow(g, node_ids, num_layers): def get_nodeflow(g, node_ids, num_layers):
batch_size = len(node_ids) batch_size = len(node_ids)
......
...@@ -92,7 +92,8 @@ def create_large_graph_index(num_nodes): ...@@ -92,7 +92,8 @@ def create_large_graph_index(num_nodes):
row = np.random.choice(num_nodes, num_nodes * 10) row = np.random.choice(num_nodes, num_nodes * 10)
col = np.random.choice(num_nodes, num_nodes * 10) col = np.random.choice(num_nodes, num_nodes * 10)
spm = spsp.coo_matrix((np.ones(len(row)), (row, col))) spm = spsp.coo_matrix((np.ones(len(row)), (row, col)))
return from_scipy_sparse_matrix(spm, True) # It's possible that we generate a multigraph.
return from_scipy_sparse_matrix(spm, True, True)
def test_node_subgraph_with_halo(): def test_node_subgraph_with_halo():
gi = create_large_graph_index(1000) gi = create_large_graph_index(1000)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment