Commit 79f62400 authored by Minjie Wang's avatar Minjie Wang
Browse files

WIP: batch graphs

parent 4aebfd7b
...@@ -66,7 +66,9 @@ class SST(object): ...@@ -66,7 +66,9 @@ class SST(object):
# add root # add root
g.add_node(0, x=SST.PAD_WORD, y=int(root.label())) g.add_node(0, x=SST.PAD_WORD, y=int(root.label()))
_rec_build(0, root) _rec_build(0, root)
return dgl.DGLGraph(g) ret = DGLGraph()
ret.from_networkx(g)
return ret
def __getitem__(self, idx): def __getitem__(self, idx):
return self.trees[idx] return self.trees[idx]
......
...@@ -45,9 +45,10 @@ class DGLGraph(object): ...@@ -45,9 +45,10 @@ class DGLGraph(object):
# frame # frame
self._node_frame = node_frame if node_frame is not None else FrameRef() self._node_frame = node_frame if node_frame is not None else FrameRef()
self._edge_frame = edge_frame if edge_frame is not None else FrameRef() self._edge_frame = edge_frame if edge_frame is not None else FrameRef()
# other class members # msg graph & frame
self._msg_graph = create_graph_index() self._msg_graph = create_graph_index()
self._msg_frame = FrameRef() self._msg_frame = FrameRef()
# registered functions
self._message_func = (None, None) self._message_func = (None, None)
self._reduce_func = (None, None) self._reduce_func = (None, None)
self._edge_func = (None, None) self._edge_func = (None, None)
...@@ -111,6 +112,12 @@ class DGLGraph(object): ...@@ -111,6 +112,12 @@ class DGLGraph(object):
self._msg_graph.clear() self._msg_graph.clear()
self._msg_frame.clear() self._msg_frame.clear()
def clear_messages(self):
"""Clear all messages."""
self._msg_graph.clear()
self._msg_frame.clear()
self._msg_graph.add_nodes(self.number_of_nodes())
def number_of_nodes(self): def number_of_nodes(self):
"""Return the number of nodes. """Return the number of nodes.
...@@ -422,6 +429,23 @@ class DGLGraph(object): ...@@ -422,6 +429,23 @@ class DGLGraph(object):
#TODO: attributes #TODO: attributes
return nx_graph return nx_graph
def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
"""Convert from networkx graph.
If 'id' edge attribute exists, the edge will be added follows
the edge id order. Otherwise, order is undefined.
Parameters
----------
nx_graph : networkx.DiGraph
The nx graph
"""
self.clear()
self._graph.from_networkx(nx_graph)
self._msg_graph.add_nodes(self._graph.number_of_nodes())
#TODO: attributes
pass
def node_attr_schemes(self): def node_attr_schemes(self):
"""Return the node attribute schemes. """Return the node attribute schemes.
...@@ -1303,12 +1327,6 @@ class DGLGraph(object): ...@@ -1303,12 +1327,6 @@ class DGLGraph(object):
self._edge_frame.num_rows, self._edge_frame.num_rows,
reduce_func) reduce_func)
def clear_messages(self):
"""Clear all messages."""
self._msg_graph.clear()
self._msg_frame.clear()
self._msg_graph.add_nodes(self.number_of_nodes())
def _get_repr(attr_dict): def _get_repr(attr_dict):
if len(attr_dict) == 1 and __REPR__ in attr_dict: if len(attr_dict) == 1 and __REPR__ in attr_dict:
return attr_dict[__REPR__] return attr_dict[__REPR__]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment