"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "f5dca445a2a22945a3a34fbf9abe409e23f83fc5"
Commit 3a868eb0 authored by Mufei Li's avatar Mufei Li Committed by Minjie Wang
Browse files

[Feature] Max readout and consecutive labeling for networkx (#341)

* Max readout and consecutive labeling

* Delete test_readout.py

* Delete test_basics.py

* Test case and fix

* Recover accidentally removed file

* Fix import order

* Fix test case

* Fix

* Fix

* Fix

* Fix

* Fix

* revert

* Fix

* Fix

* Fix
parent 707334ce
...@@ -35,3 +35,5 @@ Graph Readout ...@@ -35,3 +35,5 @@ Graph Readout
sum_edges sum_edges
mean_nodes mean_nodes
mean_edges mean_edges
max_nodes
max_edges
...@@ -12,7 +12,8 @@ from . import backend as F ...@@ -12,7 +12,8 @@ from . import backend as F
from . import utils from . import utils
__all__ = ['BatchedDGLGraph', 'batch', 'unbatch', 'split', __all__ = ['BatchedDGLGraph', 'batch', 'unbatch', 'split',
'sum_nodes', 'sum_edges', 'mean_nodes', 'mean_edges'] 'sum_nodes', 'sum_edges', 'mean_nodes', 'mean_edges',
'max_nodes', 'max_edges']
class BatchedDGLGraph(DGLGraph): class BatchedDGLGraph(DGLGraph):
"""Class for batched DGL graphs. """Class for batched DGL graphs.
...@@ -725,3 +726,93 @@ def mean_edges(graph, feat, weight=None): ...@@ -725,3 +726,93 @@ def mean_edges(graph, feat, weight=None):
sum_edges sum_edges
""" """
return _mean_on(graph, 'edges', feat, weight) return _mean_on(graph, 'edges', feat, weight)
def _max_on(graph, typestr, feat):
"""Internal function to take elementwise maximum
over node or edge features.
Parameters
----------
graph : DGLGraph
The graph.
typestr : str
'nodes' or 'edges'
feat : str
The feature field name.
Returns
-------
Tensor
The (weighted) summed node or edge features.
"""
data_attr, batch_num_objs_attr, _ = READOUT_ON_ATTRS[typestr]
data = getattr(graph, data_attr)
feat = data[feat]
if isinstance(graph, BatchedDGLGraph):
batch_num_objs = getattr(graph, batch_num_objs_attr)
max_readout_list = []
first = 0
for num_obj in batch_num_objs:
if num_obj == 0:
max_readout_list.append(F.zeros(F.shape(feat)[1:],
F.dtype(feat),
F.context(feat)))
continue
max_readout_list.append(F.max(feat[first:first+num_obj], 0))
first += num_obj
return F.stack(max_readout_list, 0)
else:
return F.max(feat, 0)
def max_nodes(graph, feat):
"""Take elementwise maximum over all the values of node field
:attr:`feat` in :attr:`graph`
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : str
The feature field.
Returns
-------
tensor
The tensor obtained.
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
returned instead, i.e. having an extra first dimension.
Each row of the stacked tensor contains the readout result of
corresponding example in the batch. If an example has no nodes,
a zero tensor with the same shape is returned at the corresponding row.
"""
return _max_on(graph, 'nodes', feat)
def max_edges(graph, feat):
"""Take elementwise maximum over all the values of edge field
:attr:`feat` in :attr:`graph`
Parameters
----------
graph : DGLGraph or BatchedDGLGraph
The graph.
feat : str
The feature field.
Returns
-------
tensor
The tensor obtained.
Notes
-----
If graph is a :class:`BatchedDGLGraph` object, a stacked tensor is
returned instead, i.e. having an extra first dimension.
Each row of the stacked tensor contains the readout result of
corresponding example in the batch. If an example has no edges,
a zero tensor with the same shape is returned at the corresponding row.
"""
return _max_on(graph, 'edges', feat)
...@@ -4,6 +4,7 @@ from __future__ import absolute_import ...@@ -4,6 +4,7 @@ from __future__ import absolute_import
from collections import defaultdict from collections import defaultdict
import dgl import dgl
import networkx as nx
from .base import ALL, is_all, DGLError from .base import ALL, is_all, DGLError
from . import backend as F from . import backend as F
from . import init from . import init
...@@ -1137,7 +1138,9 @@ class DGLGraph(object): ...@@ -1137,7 +1138,9 @@ class DGLGraph(object):
Parameters Parameters
---------- ----------
nx_graph : networkx.DiGraph nx_graph : networkx.DiGraph
The nx graph If the node labels of ``nx_graph`` are not consecutive
integers, its nodes will be relabeled using consecutive integers.
The new node ordering will inherit that of ``sorted(nx_graph.nodes())``
node_attrs : iterable of str, optional node_attrs : iterable of str, optional
The node attributes needs to be copied. The node attributes needs to be copied.
edge_attrs : iterable of str, optional edge_attrs : iterable of str, optional
...@@ -1165,6 +1168,16 @@ class DGLGraph(object): ...@@ -1165,6 +1168,16 @@ class DGLGraph(object):
[2., 2., 2., 2.], [2., 2., 2., 2.],
[1., 1., 1., 1.]]) [1., 1., 1., 1.]])
""" """
# Relabel nodes using consecutive integers
nx_graph = nx.convert_node_labels_to_integers(nx_graph, ordering='sorted')
# With to_directed we will get a directed version of the original networkx
# graph, with the original nodes, edges and their attributes preserved.
# This is particularly helpful when we are also converting the edge attributes
# as the reversed edges (u, v) will be created with the same attributes as the
# original edges (v, u).
if not nx_graph.is_directed():
nx_graph = nx_graph.to_directed()
self.clear() self.clear()
self._graph.from_networkx(nx_graph) self._graph.from_networkx(nx_graph)
self._node_frame.add_rows(self.number_of_nodes()) self._node_frame.add_rows(self.number_of_nodes())
...@@ -1194,7 +1207,12 @@ class DGLGraph(object): ...@@ -1194,7 +1207,12 @@ class DGLGraph(object):
# None here serves as placeholder to be replaced by feature with # None here serves as placeholder to be replaced by feature with
# corresponding edge id # corresponding edge id
if has_edge_id: if has_edge_id:
num_edges = self.number_of_edges()
for _, _, attrs in nx_graph.edges(data=True): for _, _, attrs in nx_graph.edges(data=True):
if attrs['id'] >= num_edges:
raise DGLError('Expect the pre-specified edge ids to be'
' smaller than the number of edges --'
' {}, got {}.'.format(num_edges, attrs['id']))
for key in edge_attrs: for key in edge_attrs:
attr_dict[key][attrs['id']] = attrs[key] attr_dict[key][attrs['id']] = attrs[key]
else: else:
...@@ -1204,6 +1222,9 @@ class DGLGraph(object): ...@@ -1204,6 +1222,9 @@ class DGLGraph(object):
for key in edge_attrs: for key in edge_attrs:
attr_dict[key][eid] = attrs[key] attr_dict[key][eid] = attrs[key]
for attr in edge_attrs: for attr in edge_attrs:
for val in attr_dict[attr]:
if val is None:
raise DGLError('Not all edges have attribute {}.'.format(attr))
self._edge_frame[attr] = _batcher(attr_dict[attr]) self._edge_frame[attr] = _batcher(attr_dict[attr])
def from_scipy_sparse_matrix(self, spmat): def from_scipy_sparse_matrix(self, spmat):
......
...@@ -667,7 +667,10 @@ class GraphIndex(object): ...@@ -667,7 +667,10 @@ class GraphIndex(object):
nx_graph = (nx.MultiDiGraph(nx_graph) if self.is_multigraph() nx_graph = (nx.MultiDiGraph(nx_graph) if self.is_multigraph()
else nx.DiGraph(nx_graph)) else nx.DiGraph(nx_graph))
else: else:
nx_graph = nx_graph.to_directed() if not nx_graph.is_directed():
# to_directed creates a deep copy of the networkx graph even if
# the original graph is already directed and we do not want to do it.
nx_graph = nx_graph.to_directed()
num_nodes = nx_graph.number_of_nodes() num_nodes = nx_graph.number_of_nodes()
self.add_nodes(num_nodes) self.add_nodes(num_nodes)
......
...@@ -587,7 +587,10 @@ class ImmutableGraphIndex(object): ...@@ -587,7 +587,10 @@ class ImmutableGraphIndex(object):
nx_graph = (nx.MultiDiGraph(nx_graph) if self.is_multigraph() nx_graph = (nx.MultiDiGraph(nx_graph) if self.is_multigraph()
else nx.DiGraph(nx_graph)) else nx.DiGraph(nx_graph))
else: else:
nx_graph = nx_graph.to_directed() if not nx_graph.is_directed():
# to_directed creates a deep copy of the networkx graph even if
# the original graph is already directed and we do not want to do it.
nx_graph = nx_graph.to_directed()
assert nx_graph.number_of_edges() > 0, "can't create an empty immutable graph" assert nx_graph.number_of_edges() > 0, "can't create an empty immutable graph"
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# readonly graph support. # readonly graph support.
import backend as F import backend as F
import dgl import dgl
import networkx as nx
from dgl import DGLGraph from dgl import DGLGraph
from collections import defaultdict as ddict from collections import defaultdict as ddict
...@@ -242,6 +243,25 @@ def test_nx_conversion(): ...@@ -242,6 +243,25 @@ def test_nx_conversion():
edge_feat = F.cat(edge_feat, 0) edge_feat = F.cat(edge_feat, 0)
assert F.allclose(g.edata['e1'], edge_feat) assert F.allclose(g.edata['e1'], edge_feat)
# Test converting from a networkx graph whose nodes are
# not labeled with consecutive-integers.
nxg = nx.cycle_graph(5)
nxg.remove_nodes_from([0, 4])
for u in nxg.nodes():
nxg.node[u]['h'] = F.tensor([u])
for u, v, d in nxg.edges(data=True):
d['h'] = F.tensor([u, v])
g = dgl.DGLGraph()
g.from_networkx(nxg, node_attrs=['h'], edge_attrs=['h'])
assert g.number_of_nodes() == 3
assert g.number_of_edges() == 4
assert g.has_edge_between(0, 1)
assert g.has_edge_between(1, 2)
assert F.allclose(g.ndata['h'], F.tensor([[1.], [2.], [3.]]))
assert F.allclose(g.edata['h'], F.tensor([[1., 2.], [1., 2.],
[2., 3.], [2., 3.]]))
def test_batch_send(): def test_batch_send():
g = generate_graph() g = generate_graph()
def _fmsg(edges): def _fmsg(edges):
......
...@@ -19,6 +19,9 @@ def test_simple_readout(): ...@@ -19,6 +19,9 @@ def test_simple_readout():
me1 = F.mean(e1, 0) # edge means me1 = F.mean(e1, 0) # edge means
w1 = F.randn((3,)) w1 = F.randn((3,))
w2 = F.randn((4,)) w2 = F.randn((4,))
max1 = F.max(n1, 0)
max2 = F.max(n2, 0)
maxe1 = F.max(e1, 0)
ws1 = F.sum(n1 * F.unsqueeze(w1, 1), 0) ws1 = F.sum(n1 * F.unsqueeze(w1, 1), 0)
ws2 = F.sum(n2 * F.unsqueeze(w2, 1), 0) ws2 = F.sum(n2 * F.unsqueeze(w2, 1), 0)
wm1 = F.sum(n1 * F.unsqueeze(w1, 1), 0) / F.sum(F.unsqueeze(w1, 1), 0) wm1 = F.sum(n1 * F.unsqueeze(w1, 1), 0) / F.sum(F.unsqueeze(w1, 1), 0)
...@@ -35,20 +38,26 @@ def test_simple_readout(): ...@@ -35,20 +38,26 @@ def test_simple_readout():
assert F.allclose(dgl.mean_nodes(g1, 'x'), m1) assert F.allclose(dgl.mean_nodes(g1, 'x'), m1)
assert F.allclose(dgl.mean_nodes(g1, 'x', 'w'), wm1) assert F.allclose(dgl.mean_nodes(g1, 'x', 'w'), wm1)
assert F.allclose(dgl.mean_edges(g1, 'x'), me1) assert F.allclose(dgl.mean_edges(g1, 'x'), me1)
assert F.allclose(dgl.max_nodes(g1, 'x'), max1)
assert F.allclose(dgl.max_edges(g1, 'x'), maxe1)
g = dgl.batch([g1, g2]) g = dgl.batch([g1, g2])
s = dgl.sum_nodes(g, 'x') s = dgl.sum_nodes(g, 'x')
m = dgl.mean_nodes(g, 'x') m = dgl.mean_nodes(g, 'x')
max_bg = dgl.max_nodes(g, 'x')
assert F.allclose(s, F.stack([s1, s2], 0)) assert F.allclose(s, F.stack([s1, s2], 0))
assert F.allclose(m, F.stack([m1, m2], 0)) assert F.allclose(m, F.stack([m1, m2], 0))
assert F.allclose(max_bg, F.stack([max1, max2], 0))
ws = dgl.sum_nodes(g, 'x', 'w') ws = dgl.sum_nodes(g, 'x', 'w')
wm = dgl.mean_nodes(g, 'x', 'w') wm = dgl.mean_nodes(g, 'x', 'w')
assert F.allclose(ws, F.stack([ws1, ws2], 0)) assert F.allclose(ws, F.stack([ws1, ws2], 0))
assert F.allclose(wm, F.stack([wm1, wm2], 0)) assert F.allclose(wm, F.stack([wm1, wm2], 0))
s = dgl.sum_edges(g, 'x') s = dgl.sum_edges(g, 'x')
m = dgl.mean_edges(g, 'x') m = dgl.mean_edges(g, 'x')
max_bg_e = dgl.max_edges(g, 'x')
assert F.allclose(s, F.stack([se1, F.zeros(5)], 0)) assert F.allclose(s, F.stack([se1, F.zeros(5)], 0))
assert F.allclose(m, F.stack([me1, F.zeros(5)], 0)) assert F.allclose(m, F.stack([me1, F.zeros(5)], 0))
assert F.allclose(max_bg_e, F.stack([maxe1, F.zeros(5)], 0))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment