Commit 44db98c4 authored by Minjie Wang's avatar Minjie Wang
Browse files

remove nx inheritence

parent 916d375b
...@@ -25,9 +25,9 @@ class DGLGraph(object): ...@@ -25,9 +25,9 @@ class DGLGraph(object):
---------- ----------
graph_data : graph data graph_data : graph data
Data to initialize graph. Same as networkx's semantics. Data to initialize graph. Same as networkx's semantics.
node_frame : dgl.frame.Frame node_frame : FrameRef
Node feature storage. Node feature storage.
edge_frame : dgl.frame.Frame edge_frame : FrameRef
Edge feature storage. Edge feature storage.
attr : keyword arguments, optional attr : keyword arguments, optional
Attributes to add to graph as key=value pairs. Attributes to add to graph as key=value pairs.
...@@ -37,14 +37,7 @@ class DGLGraph(object): ...@@ -37,14 +37,7 @@ class DGLGraph(object):
node_frame=None, node_frame=None,
edge_frame=None, edge_frame=None,
**attr): **attr):
# TODO(minjie): maintaining node/edge list is costly when graph is large. # TODO: keyword attr
#nx_init(self,
# self._add_node_callback,
# self._add_edge_callback,
# self._del_node_callback,
# self._del_edge_callback,
# graph_data,
# **attr)
# graph # graph
self._graph = GraphIndex(graph_data) self._graph = GraphIndex(graph_data)
# frame # frame
...@@ -59,10 +52,380 @@ class DGLGraph(object): ...@@ -59,10 +52,380 @@ class DGLGraph(object):
self._apply_node_func = (None, None) self._apply_node_func = (None, None)
self._apply_edge_func = (None, None) self._apply_edge_func = (None, None)
def add_nodes(self, num, reprs=None):
"""Add nodes.
Parameters
----------
num : int
Number of nodes to be added.
reprs : dict
Optional node representations.
"""
self._graph.add_nodes(num)
#TODO(minjie): change frames
def add_edge(self, u, v, repr=None):
"""Add one edge.
Parameters
----------
u : int
The src node.
v : int
The dst node.
repr : dict
Optional edge representation.
"""
self._graph.add_edge(u, v)
#TODO(minjie): change frames
def add_edges(self, u, v, reprs=None):
"""Add many edges.
Parameters
----------
u : list, tensor
The src nodes.
v : list, tensor
The dst nodes.
reprs : dict
Optional node representations.
"""
u = utils.toindex(u)
v = utils.toindex(v)
self._graph.add_edges(u, v)
#TODO(minjie): change frames
def clear(self):
"""Clear the graph and its storage."""
self._graph.clear()
self._node_frame.clear()
self._edge_frame.clear()
self._msg_graph.clear()
self._msg_frame.clear()
def number_of_nodes(self):
"""Return the number of nodes.
Returns
-------
int
The number of nodes
"""
return self._graph.number_of_nodes()
def number_of_edges(self):
"""Return the number of edges.
Returns
-------
int
The number of edges
"""
return self._graph.number_of_edges()
def has_node(self, vid):
"""Return true if the node exists.
Parameters
----------
vid : int
The nodes
Returns
-------
bool
True if the node exists
"""
return self.has_node(vid)
def has_nodes(self, vids):
"""Return true if the nodes exist.
Parameters
----------
vid : list, tensor
The nodes
Returns
-------
tensor
0-1 array indicating existence
"""
vids = utils.toindex(vids)
rst = self._graph.has_nodes(vids)
return rst.tousertensor()
def has_edge(self, u, v):
"""Return true if the edge exists.
Parameters
----------
u : int
The src node.
v : int
The dst node.
Returns
-------
bool
True if the edge exists
"""
return self._graph.has_edge(u, v)
def has_edges(self, u, v):
"""Return true if the edge exists.
Parameters
----------
u : list, tensor
The src nodes.
v : list, tensor
The dst nodes.
Returns
-------
tensor
0-1 array indicating existence
"""
u = utils.toindex(u)
v = utils.toindex(v)
rst = self._graph.has_edges(u, v)
return rst.tousertensor()
def predecessors(self, v, radius=1):
"""Return the predecessors of the node.
Parameters
----------
v : int
The node.
radius : int, optional
The radius of the neighborhood.
Returns
-------
tensor
Array of predecessors
"""
return self._graph.predecessors(v).tousertensor()
def successors(self, v, radius=1):
"""Return the successors of the node.
Parameters
----------
v : int
The node.
radius : int, optional
The radius of the neighborhood.
Returns
-------
tensor
Array of successors
"""
return self._graph.successors(v).tousertensor()
def edge_id(self, u, v):
"""Return the id of the edge.
Parameters
----------
u : int
The src node.
v : int
The dst node.
Returns
-------
int
The edge id.
"""
return self._graph.edge_id(u, v)
def edge_ids(self, u, v):
"""Return the edge ids.
Parameters
----------
u : list, tensor
The src nodes.
v : list, tensor
The dst nodes.
Returns
-------
tensor
The edge id array.
"""
u = utils.toindex(u)
v = utils.toindex(v)
rst = self._graph.edge_ids(u, v)
return rst.tousertensor()
def in_edges(self, v):
"""Return the in edges of the node(s).
Parameters
----------
v : int, list, tensor
The node(s).
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
"""
v = utils.toindex(v)
src, dst, eid = self._graph.in_edges(v)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
def out_edges(self, v):
"""Return the out edges of the node(s).
Parameters
----------
v : int, list, tensor
The node(s).
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
"""
v = utils.toindex(v)
src, dst, eid = self._graph.out_edges(v)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
def edges(self, sorted=False):
"""Return all the edges.
Parameters
----------
sorted : bool
True if the returned edges are sorted by their ids.
Returns
-------
tensor
The src nodes.
tensor
The dst nodes.
tensor
The edge ids.
"""
src, dst, eid = self._graph.edges(sorted)
return src.tousertensor(), dst.tousertensor(), eid.tousertensor()
def in_degree(self, v):
"""Return the in degree of the node.
Parameters
----------
v : int
The node.
Returns
-------
int
The in degree.
"""
return self._graph.in_degree(v)
def in_degrees(self, v):
"""Return the in degrees of the nodes.
Parameters
----------
v : list, tensor
The nodes.
Returns
-------
tensor
The in degree array.
"""
return self._graph.in_degrees(v).tousertensor()
def out_degree(self, v):
"""Return the out degree of the node.
Parameters
----------
v : int
The node.
Returns
-------
int
The out degree.
"""
return self._graph.out_degree(v)
def out_degrees(self, v):
"""Return the out degrees of the nodes.
Parameters
----------
v : list, tensor
The nodes.
Returns
-------
tensor
The out degree array.
"""
return self._graph.out_degrees(v).tousertensor()
def to_networkx(self, node_attrs=None, edge_attrs=None):
"""Convert to networkx graph.
The edge id will be saved as the 'id' edge attribute.
Parameters
----------
node_attrs : iterable of str, optional
The node attributes to be copied.
edge_attrs : iterable of str, optional
The edge attributes to be copied.
Returns
-------
networkx.DiGraph
The nx graph
"""
nx_graph = self._graph.to_networkx()
#TODO: attributes
return nx_graph
def node_attr_schemes(self): def node_attr_schemes(self):
"""Return the node attribute schemes.
Returns
-------
iterable
The set of attribute names
"""
return self._node_frame.schemes return self._node_frame.schemes
def edge_attr_schemes(self): def edge_attr_schemes(self):
"""Return the edge attribute schemes.
Returns
-------
iterable
The set of attribute names
"""
return self._edge_frame.schemes return self._edge_frame.schemes
def set_n_repr(self, hu, u=ALL): def set_n_repr(self, hu, u=ALL):
...@@ -113,7 +476,12 @@ class DGLGraph(object): ...@@ -113,7 +476,12 @@ class DGLGraph(object):
Parameters Parameters
---------- ----------
u : node, container or tensor u : node, container or tensor
The node(s). The node(s).
Returns
-------
dict
Representation dict
""" """
if is_all(u): if is_all(u):
if len(self._node_frame) == 1 and __REPR__ in self._node_frame: if len(self._node_frame) == 1 and __REPR__ in self._node_frame:
...@@ -133,7 +501,12 @@ class DGLGraph(object): ...@@ -133,7 +501,12 @@ class DGLGraph(object):
Parameters Parameters
---------- ----------
key : str key : str
The attribute name. The attribute name.
Returns
-------
Tensor
The popped representation
""" """
return self._node_frame.pop(key) return self._node_frame.pop(key)
...@@ -229,6 +602,11 @@ class DGLGraph(object): ...@@ -229,6 +602,11 @@ class DGLGraph(object):
The source node(s). The source node(s).
v : node, container or tensor v : node, container or tensor
The destination node(s). The destination node(s).
Returns
-------
dict
Representation dict
""" """
u_is_all = is_all(u) u_is_all = is_all(u)
v_is_all = is_all(v) v_is_all = is_all(v)
...@@ -254,6 +632,11 @@ class DGLGraph(object): ...@@ -254,6 +632,11 @@ class DGLGraph(object):
---------- ----------
key : str key : str
The attribute name. The attribute name.
Returns
-------
Tensor
The popped representation
""" """
return self._edge_frame.pop(key) return self._edge_frame.pop(key)
...@@ -264,6 +647,11 @@ class DGLGraph(object): ...@@ -264,6 +647,11 @@ class DGLGraph(object):
---------- ----------
eid : int, container or tensor eid : int, container or tensor
The edge id(s). The edge id(s).
Returns
-------
dict
Representation dict
""" """
if is_all(eid): if is_all(eid):
if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame: if len(self._edge_frame) == 1 and __REPR__ in self._edge_frame:
...@@ -368,6 +756,7 @@ class DGLGraph(object): ...@@ -368,6 +756,7 @@ class DGLGraph(object):
new_repr = apply_node_func(self.get_n_repr(v)) new_repr = apply_node_func(self.get_n_repr(v))
self.set_n_repr(new_repr, v) self.set_n_repr(new_repr, v)
else: else:
raise RuntimeError('Disabled')
if is_all(v): if is_all(v):
v = self.nodes() v = self.nodes()
v = utils.toindex(v) v = utils.toindex(v)
...@@ -441,6 +830,7 @@ class DGLGraph(object): ...@@ -441,6 +830,7 @@ class DGLGraph(object):
self._nonbatch_send(u, v, message_func) self._nonbatch_send(u, v, message_func)
def _nonbatch_send(self, u, v, message_func): def _nonbatch_send(self, u, v, message_func):
raise RuntimeError('Disabled')
if is_all(u) and is_all(v): if is_all(u) and is_all(v):
u, v = self.cached_graph.edges() u, v = self.cached_graph.edges()
else: else:
...@@ -505,6 +895,7 @@ class DGLGraph(object): ...@@ -505,6 +895,7 @@ class DGLGraph(object):
self._nonbatch_update_edge(u, v, edge_func) self._nonbatch_update_edge(u, v, edge_func)
def _nonbatch_update_edge(self, u, v, edge_func): def _nonbatch_update_edge(self, u, v, edge_func):
raise RuntimeError('Disabled')
if is_all(u) and is_all(v): if is_all(u) and is_all(v):
u, v = self.cached_graph.edges() u, v = self.cached_graph.edges()
else: else:
...@@ -587,6 +978,7 @@ class DGLGraph(object): ...@@ -587,6 +978,7 @@ class DGLGraph(object):
self.apply_nodes(u, apply_node_func, batchable) self.apply_nodes(u, apply_node_func, batchable)
def _nonbatch_recv(self, u, reduce_func): def _nonbatch_recv(self, u, reduce_func):
raise RuntimeError('Disabled')
if is_all(u): if is_all(u):
u = list(range(0, self.number_of_nodes())) u = list(range(0, self.number_of_nodes()))
else: else:
...@@ -916,75 +1308,9 @@ class DGLGraph(object): ...@@ -916,75 +1308,9 @@ class DGLGraph(object):
self._edge_frame.num_rows, self._edge_frame.num_rows,
reduce_func) reduce_func)
def draw(self):
"""Plot the graph using dot."""
from networkx.drawing.nx_agraph import graphviz_layout
pos = graphviz_layout(self, prog='dot')
nx.draw(self, pos, with_labels=True)
@property
def msg_graph(self):
# TODO: dirty flag when mutated
if self._msg_graph is None:
self._msg_graph = CachedGraph()
self._msg_graph.add_nodes(self.number_of_nodes())
return self._msg_graph
def clear_messages(self): def clear_messages(self):
if self._msg_graph is not None: self._msg_graph.clear()
self._msg_graph = CachedGraph() self._msg_frame.clear()
self._msg_graph.add_nodes(self.number_of_nodes())
self._msg_frame.clear()
@property
def edge_list(self):
"""Return edges in the addition order."""
return self._edge_list
def get_edge_id(self, u, v):
"""Return the continuous edge id(s) assigned.
Parameters
----------
u : node, container or tensor
The source node(s).
v : node, container or tensor
The destination node(s).
Returns
-------
eid : tensor
The tensor contains edge id(s).
"""
u = utils.toindex(u)
v = utils.toindex(v)
return self.cached_graph.get_edge_id(u, v)
def _add_node_callback(self, node):
#print('New node:', node)
self._cached_graph = None
def _del_node_callback(self, node):
#print('Del node:', node)
raise RuntimeError('Node removal is not supported currently.')
node = utils.convert_to_id_tensor(node)
self._node_frame.delete_rows(node)
self._cached_graph = None
def _add_edge_callback(self, u, v):
#print('New edge:', u, v)
self._edge_list.append((u, v))
self._cached_graph = None
def _del_edge_callback(self, u, v):
#print('Del edge:', u, v)
raise RuntimeError('Edge removal is not supported currently.')
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
eid = self.get_edge_id(u, v)
self._edge_frame.delete_rows(eid)
self._cached_graph = None
def _get_repr(attr_dict): def _get_repr(attr_dict):
if len(attr_dict) == 1 and __REPR__ in attr_dict: if len(attr_dict) == 1 and __REPR__ in attr_dict:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment