"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "986cc9b2f4cdabbb779c1991887d1d4f8e5880c5"
Commit 455ea485 authored by Lingfan Yu's avatar Lingfan Yu Committed by Minjie Wang
Browse files

[Bugfix] Conversion between networkx and DGLGraph (#244)

* copy feature to networkx

* fix to_network for multi-graph

* test case for nx conversion

* iterate over Index returns plain int instead of numpy.int64

* fix from_network multi-edge bug
parent 31800e71
...@@ -3,6 +3,7 @@ from __future__ import absolute_import ...@@ -3,6 +3,7 @@ from __future__ import absolute_import
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from collections import defaultdict
import dgl import dgl
from .base import ALL, is_all, DGLError, dgl_warning from .base import ALL, is_all, DGLError, dgl_warning
...@@ -14,6 +15,7 @@ from . import utils ...@@ -14,6 +15,7 @@ from . import utils
from .view import NodeView, EdgeView from .view import NodeView, EdgeView
from .udf import NodeBatch, EdgeBatch from .udf import NodeBatch, EdgeBatch
__all__ = ['DGLGraph'] __all__ = ['DGLGraph']
class DGLGraph(object): class DGLGraph(object):
...@@ -1014,9 +1016,15 @@ class DGLGraph(object): ...@@ -1014,9 +1016,15 @@ class DGLGraph(object):
The nx graph The nx graph
""" """
nx_graph = self._graph.to_networkx() nx_graph = self._graph.to_networkx()
#TODO(minjie): attributes if node_attrs is not None:
dgl_warning('to_networkx currently does not support converting' for nid, attr in nx_graph.nodes(data=True):
' node/edge features automatically.') nf = self.get_n_repr(nid)
attr.update({key: nf[key].squeeze(0) for key in node_attrs})
if edge_attrs is not None:
for u, v, attr in nx_graph.edges(data=True):
eid = attr['id']
ef = self.get_e_repr(eid)
attr.update({key: ef[key].squeeze(0) for key in edge_attrs})
return nx_graph return nx_graph
def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None): def from_networkx(self, nx_graph, node_attrs=None, edge_attrs=None):
...@@ -1046,18 +1054,24 @@ class DGLGraph(object): ...@@ -1046,18 +1054,24 @@ class DGLGraph(object):
else: else:
return F.tensor(lst) return F.tensor(lst)
if node_attrs is not None: if node_attrs is not None:
attr_dict = {attr : [] for attr in node_attrs} attr_dict = defaultdict(list)
for nid in range(self.number_of_nodes()): for nid in range(self.number_of_nodes()):
for attr in node_attrs: for attr in node_attrs:
attr_dict[attr].append(nx_graph.nodes[nid][attr]) attr_dict[attr].append(nx_graph.nodes[nid][attr])
for attr in node_attrs: for attr in node_attrs:
self._node_frame[attr] = _batcher(attr_dict[attr]) self._node_frame[attr] = _batcher(attr_dict[attr])
if edge_attrs is not None: if edge_attrs is not None:
attr_dict = {attr : [] for attr in edge_attrs} has_edge_id = 'id' in next(iter(nx_graph.edges(data=True)))[-1]
src, dst, _ = self._graph.edges() attr_dict = defaultdict(lambda: [None] * self.number_of_edges())
for u, v in zip(src.tolist(), dst.tolist()): if has_edge_id:
for attr in edge_attrs: for u, v, attrs in nx_graph.edges(data=True):
attr_dict[attr].append(nx_graph.edges[u, v][attr]) for key in edge_attrs:
attr_dict[key][attrs['id']] = attrs[key]
else:
# XXX: assuming networkx iteration order is deterministic
for eid, (_, _, attr) in enumerate(nx_graph.edges(data=True)):
for key in edge_attrs:
attr_dict[key][eid] = attrs[key]
for attr in edge_attrs: for attr in edge_attrs:
self._edge_frame[attr] = _batcher(attr_dict[attr]) self._edge_frame[attr] = _batcher(attr_dict[attr])
......
...@@ -21,7 +21,8 @@ class Index(object): ...@@ -21,7 +21,8 @@ class Index(object):
self._dispatch(data) self._dispatch(data)
def __iter__(self): def __iter__(self):
return iter(self.tolist()) for i in self.tolist():
yield int(i)
def __len__(self): def __len__(self):
if self._list_data is not None and isinstance(self._list_data, slice): if self._list_data is not None and isinstance(self._list_data, slice):
...@@ -39,7 +40,7 @@ class Index(object): ...@@ -39,7 +40,7 @@ class Index(object):
return len(self._dgl_tensor_data) return len(self._dgl_tensor_data)
def __getitem__(self, i): def __getitem__(self, i):
return self.tolist()[i] return int(self.tolist()[i])
def _dispatch(self, data): def _dispatch(self, data):
"""Store data based on its type.""" """Store data based on its type."""
......
...@@ -4,6 +4,7 @@ import numpy as np ...@@ -4,6 +4,7 @@ import numpy as np
import dgl import dgl
from dgl.graph import DGLGraph from dgl.graph import DGLGraph
import utils as U import utils as U
from collections import defaultdict as ddict
D = 5 D = 5
reduce_msg_shapes = set() reduce_msg_shapes = set()
...@@ -137,6 +138,78 @@ def test_batch_setter_autograd(): ...@@ -137,6 +138,78 @@ def test_batch_setter_autograd():
check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.])) check_eq(h1.grad[:,0], th.tensor([2., 0., 0., 2., 2., 2., 2., 2., 0., 2.]))
check_eq(hh.grad[:,0], th.tensor([2., 2., 2.])) check_eq(hh.grad[:,0], th.tensor([2., 2., 2.]))
def test_nx_conversion():
# check conversion between networkx and DGLGraph
def _check_nx_feature(nxg, nf, ef):
num_nodes = len(nxg)
num_edges = nxg.size()
if num_nodes > 0:
node_feat = ddict(list)
for nid, attr in nxg.nodes(data=True):
assert len(attr) == len(nf)
for k in nxg.nodes[nid]:
node_feat[k].append(attr[k].unsqueeze(0))
for k in node_feat:
feat = th.cat(node_feat[k], dim=0)
assert U.allclose(feat, nf[k])
else:
assert len(nf) == 0
if num_edges > 0:
edge_feat = ddict(lambda: [0] * num_edges)
for u, v, attr in nxg.edges(data=True):
assert len(attr) == len(ef) + 1 # extra id
eid = attr['id']
for k in ef:
edge_feat[k][eid] = attr[k].unsqueeze(0)
for k in edge_feat:
feat = th.cat(edge_feat[k], dim=0)
assert U.allclose(feat, ef[k])
else:
assert len(ef) == 0
n1 = th.randn(5, 3)
n2 = th.randn(5, 10)
n3 = th.randn(5, 4)
e1 = th.randn(4, 5)
e2 = th.randn(4, 7)
g = DGLGraph(multigraph=True)
g.add_nodes(5)
g.add_edges([0,1,3,4], [2,4,0,3])
g.ndata.update({'n1': n1, 'n2': n2, 'n3': n3})
g.edata.update({'e1': e1, 'e2': e2})
# convert to networkx
nxg = g.to_networkx(node_attrs=['n1', 'n3'], edge_attrs=['e1', 'e2'])
assert len(nxg) == 5
assert nxg.size() == 4
_check_nx_feature(nxg, {'n1': n1, 'n3': n3}, {'e1': e1, 'e2': e2})
# convert to DGLGraph
# use id feature to test non-tensor copy
g.from_networkx(nxg, node_attrs=['n1'], edge_attrs=['e1', 'id'])
assert g.number_of_nodes() == 5
assert g.number_of_edges() == 4
assert U.allclose(g.get_n_repr()['n1'], n1)
assert U.allclose(g.get_e_repr()['e1'], e1)
assert th.equal(g.get_e_repr()['id'], th.arange(4))
g.pop_e_repr('id')
# test modifying DGLGraph
new_n = th.randn(2, 3)
new_e = th.randn(3, 5)
g.add_nodes(2, data={'n1': new_n})
# add three edges, one is a multi-edge
g.add_edges([3, 6, 0], [4, 5, 2], data={'e1': new_e})
n1 = th.cat((n1, new_n), dim=0)
e1 = th.cat((e1, new_e), dim=0)
# convert to networkx again
nxg = g.to_networkx(node_attrs=['n1'], edge_attrs=['e1'])
assert len(nxg) == 7
assert nxg.size() == 7
_check_nx_feature(nxg, {'n1': n1}, {'e1': e1})
def test_batch_send(): def test_batch_send():
g = generate_graph() g = generate_graph()
def _fmsg(edges): def _fmsg(edges):
...@@ -340,7 +413,7 @@ def test_update_all_0deg(): ...@@ -340,7 +413,7 @@ def test_update_all_0deg():
# initializer and applied with UDF. # initializer and applied with UDF.
assert U.allclose(new_repr[1:], 2*(2+th.zeros((4,5)))) assert U.allclose(new_repr[1:], 2*(2+th.zeros((4,5))))
assert U.allclose(new_repr[0], 2 * old_repr.sum(0)) assert U.allclose(new_repr[0], 2 * old_repr.sum(0))
# test#2: graph with no edge # test#2: graph with no edge
g = DGLGraph() g = DGLGraph()
g.add_nodes(5) g.add_nodes(5)
...@@ -513,6 +586,7 @@ def test_dynamic_addition(): ...@@ -513,6 +586,7 @@ def test_dynamic_addition():
if __name__ == '__main__': if __name__ == '__main__':
test_nx_conversion()
test_batch_setter_getter() test_batch_setter_getter()
test_batch_setter_autograd() test_batch_setter_autograd()
test_batch_send() test_batch_send()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment