"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "6a363378d50c4c1d6c01b1942b170050286e2923"
Commit 882e2a7b authored by Minjie Wang's avatar Minjie Wang
Browse files

Graph batching. Support convert nx graph attrs

parent 314a75f3
......@@ -31,11 +31,11 @@ class BatchedDGLGraph(DGLGraph):
# NOTE: following code will materialize the columns of the input graphs.
batched_node_frame = FrameRef()
for gr in graph_list:
cols = {gr._node_frame[key] for key in node_attrs}
cols = {key : gr._node_frame[key] for key in node_attrs}
batched_node_frame.append(cols)
batched_edge_frame = FrameRef()
for gr in graph_list:
cols = {gr._edge_frame[key] for key in edge_attrs}
cols = {key : gr._edge_frame[key] for key in edge_attrs}
batched_edge_frame.append(cols)
super(BatchedDGLGraph, self).__init__(
graph_data=batched_index,
......@@ -169,12 +169,12 @@ def batch(graph_list, node_attrs=ALL, edge_attrs=ALL):
node_attrs = []
elif is_all(node_attrs):
node_attrs = graph_list[0].node_attr_schemes()
elif if isinstance(node_attrs, str):
elif isinstance(node_attrs, str):
node_attrs = [node_attrs]
if edge_attrs is None:
edge_attrs = []
elif is_all(edge_attrs):
edge_attrs = graph_list[0].edge_attr_schemes()
elif if isinstance(edge_attrs, str):
elif isinstance(edge_attrs, str):
edge_attrs = [edge_attrs]
return BatchedDGLGraph(graph_list, node_attrs, edge_attrs)
......@@ -67,7 +67,7 @@ class SST(object):
g.add_node(0, x=SST.PAD_WORD, y=int(root.label()))
_rec_build(0, root)
ret = DGLGraph()
ret.from_networkx(g)
ret.from_networkx(g, node_attrs=['x', 'y'])
return ret
def __getitem__(self, idx):
......
......@@ -439,12 +439,34 @@ class DGLGraph(object):
----------
nx_graph : networkx.DiGraph
The nx graph
node_attrs : iterable of str, optional
The node attributes needs to be copied.
edge_attrs : iterable of str, optional
The edge attributes needs to be copied.
"""
self.clear()
self._graph.from_networkx(nx_graph)
self._msg_graph.add_nodes(self._graph.number_of_nodes())
#TODO: attributes
pass
def _batcher(lst):
if isinstance(lst[0], Tensor):
return F.pack([F.unsqueeze(x, 0) for x in lst])
else:
return F.tensor(lst)
if node_attrs is not None:
attr_dict = {attr : [] for attr in node_attrs}
for nid in range(self.number_of_nodes()):
for attr in node_attrs:
attr_dict[attr].append(nx_graph.nodes[nid][attr])
for attr in node_attrs:
self._node_frame[attr] = _batcher(attr_dict[attr])
if edge_attrs is not None:
attr_dict = {attr : [] for attr in edge_attrs}
src, dst, _ = self._graph.edges()
for u, v in zip(src.tolist(), dst.tolist()):
for attr in edge_attrs:
attr_dict[attr].append(nx_graph.edges[u, v][attr])
for attr in edge_attrs:
self._edge_frame[attr] = _batcher(attr_dict[attr])
def node_attr_schemes(self):
"""Return the node attribute schemes.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment