"tests/vscode:/vscode.git/clone" did not exist on "603dbf72e4529868bcefd68bd5f901b84093626e"
Commit fb6be9fb authored by Da Zheng's avatar Da Zheng Committed by Minjie Wang
Browse files

[GraphIndex] Create graph index directly from scipy matrix. (#87)

* create synthetic data with scipy.

* changes as comments
parent 4af3f8bc
...@@ -206,10 +206,33 @@ def get_gnp_generator(args): ...@@ -206,10 +206,33 @@ def get_gnp_generator(args):
return nx.fast_gnp_random_graph(n, p, seed, True) return nx.fast_gnp_random_graph(n, p, seed, True)
return _gen return _gen
class ScipyGraph(object):
"""A simple graph object that uses scipy matrix."""
def __init__(self, mat):
self._mat = mat
def get_graph(self):
return self._mat
def number_of_nodes(self):
return self._mat.shape[0]
def number_of_edges(self):
return self._mat.getnnz()
def get_scipy_generator(args):
n = args.syn_gnp_n
p = (2 * np.log(n) / n) if args.syn_gnp_p == 0. else args.syn_gnp_p
def _gen(seed):
return ScipyGraph(sp.random(n, n, p, format='coo'))
return _gen
def load_synthetic(args): def load_synthetic(args):
ty = args.syn_type ty = args.syn_type
if ty == 'gnp': if ty == 'gnp':
gen = get_gnp_generator(args) gen = get_gnp_generator(args)
elif ty == 'scipy':
gen = get_scipy_generator(args)
else: else:
raise ValueError('Unknown graph generator type: {}'.format(ty)) raise ValueError('Unknown graph generator type: {}'.format(ty))
return GCNSyntheticDataset( return GCNSyntheticDataset(
......
...@@ -3,7 +3,7 @@ from __future__ import absolute_import ...@@ -3,7 +3,7 @@ from __future__ import absolute_import
import ctypes import ctypes
import numpy as np import numpy as np
import networkx as nx import networkx as nx
import scipy.sparse as sp import scipy
from ._ffi.base import c_array from ._ffi.base import c_array
from ._ffi.function import _init_api from ._ffi.function import _init_api
...@@ -600,30 +600,59 @@ class GraphIndex(object): ...@@ -600,30 +600,59 @@ class GraphIndex(object):
return GraphIndex(handle) return GraphIndex(handle)
class SubgraphIndex(GraphIndex): class SubgraphIndex(GraphIndex):
def __init__(self, handle, parent, induced_nodes, induced_edges): """Graph index for subgraph.
super().__init__(handle)
Parameters
----------
handle : GraphIndexHandle
The capi handle.
paranet : GraphIndex
The parent graph index.
induced_nodes : utils.Index
The parent node ids in this subgraph.
induced_edges : utils.Index
The parent edge ids in this subgraph.
"""
def __init__(self, handle, parent, induced_nodes, induced_edges):
super(SubgraphIndex, self).__init__(handle)
self._parent = parent self._parent = parent
self._induced_nodes = induced_nodes self._induced_nodes = induced_nodes
self._induced_edges = induced_edges self._induced_edges = induced_edges
def add_nodes(self, num): def add_nodes(self, num):
"""Add nodes. Disabled because SubgraphIndex is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.') raise RuntimeError('Readonly graph. Mutation is not allowed.')
def add_edge(self, u, v): def add_edge(self, u, v):
"""Add edges. Disabled because SubgraphIndex is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.') raise RuntimeError('Readonly graph. Mutation is not allowed.')
def add_edges(self, u, v): def add_edges(self, u, v):
"""Add edges. Disabled because SubgraphIndex is read-only."""
raise RuntimeError('Readonly graph. Mutation is not allowed.') raise RuntimeError('Readonly graph. Mutation is not allowed.')
@property
def induced_edges(self):
return self._induced_edges
@property @property
def induced_nodes(self): def induced_nodes(self):
"""Return parent node ids.
Returns
-------
utils.Index
The parent node ids.
"""
return self._induced_nodes return self._induced_nodes
@property
def induced_edges(self):
"""Return parent edge ids.
Returns
-------
utils.Index
The parent edge ids.
"""
return self._induced_edges
def disjoint_union(graphs): def disjoint_union(graphs):
"""Return a disjoint union of the input graphs. """Return a disjoint union of the input graphs.
...@@ -697,8 +726,25 @@ def create_graph_index(graph_data=None, multigraph=False): ...@@ -697,8 +726,25 @@ def create_graph_index(graph_data=None, multigraph=False):
handle = _CAPI_DGLGraphCreate(multigraph) handle = _CAPI_DGLGraphCreate(multigraph)
gi = GraphIndex(handle) gi = GraphIndex(handle)
if graph_data is not None:
if graph_data is None:
return gi
# scipy format
if isinstance(graph_data, scipy.sparse.spmatrix):
try:
gi.from_scipy_sparse_matrix(graph_data)
return gi
except:
raise Exception('Graph data is not a valid scipy sparse matrix.')
# networkx - any format
try:
gi.from_networkx(graph_data) gi.from_networkx(graph_data)
except:
raise Exception('Error while creating graph from input of type "%s".'
% type(graph_data))
return gi return gi
_init_api("dgl.graph_index") _init_api("dgl.graph_index")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment