Unverified Commit b355d1ed authored by Minjie Wang, committed by GitHub

[API] Apply nodes & apply edges (#117)

parent 1a2b306f
......@@ -68,7 +68,7 @@ class GCN(nn.Module):
self.g.apply_nodes(apply_node_func=
lambda nodes: {'h': self.dropout(nodes.data['h'])})
self.g.update_all(gcn_msg, gcn_reduce, layer)
return self.g.pop_n_repr('h')
return self.g.ndata.pop('h')
def main(args):
# load and preprocess dataset
......
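For context, a minimal sketch of how a layer's forward pass reads with the renamed node-data API; gcn_msg, gcn_reduce, layer, self.g and self.dropout are assumed to be defined as in the GCN example above, and assignment through the ndata view is assumed to work as a dict-like setter.

def forward(self, h):
    # Set node features through the ndata view (assumed dict-like setter).
    self.g.ndata['h'] = h
    if self.dropout:
        # apply_nodes now takes the UDF as its first argument.
        self.g.apply_nodes(lambda nodes: {'h': self.dropout(nodes.data['h'])})
    self.g.update_all(gcn_msg, gcn_reduce, layer)
    # ndata.pop replaces the old pop_n_repr.
    return self.g.ndata.pop('h')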
......@@ -58,8 +58,8 @@ class DGLGraph(object):
# registered functions
self._message_func = None
self._reduce_func = None
self._edge_func = None
self._apply_node_func = None
self._apply_edge_func = None
def add_nodes(self, num, reprs=None):
"""Add nodes.
......@@ -815,110 +815,79 @@ class DGLGraph(object):
"""
return self._edge_frame.pop(key)
def register_edge_func(self, edge_func):
"""Register global edge update function.
Parameters
----------
edge_func : callable
Message function on the edge.
"""
self._edge_func = edge_func
def register_message_func(self, message_func):
def register_message_func(self, func):
"""Register global message function.
Parameters
----------
message_func : callable
func : callable
Message function on the edge.
"""
self._message_func = message_func
self._message_func = func
def register_reduce_func(self, reduce_func):
def register_reduce_func(self, func):
"""Register global message reduce function.
Parameters
----------
reduce_func : str or callable
func : str or callable
Reduce function on incoming edges.
"""
self._reduce_func = reduce_func
self._reduce_func = func
def register_apply_node_func(self, apply_node_func):
def register_apply_node_func(self, func):
"""Register global node apply function.
Parameters
----------
apply_node_func : callable
Apply function on the node.
func : callable
Apply function on the node.
"""
self._apply_node_func = apply_node_func
self._apply_node_func = func
def apply_nodes(self, v=ALL, apply_node_func="default"):
"""Apply the function on node representations.
Applying a None function will be ignored.
def register_apply_edge_func(self, func):
"""Register global edge apply function.
Parameters
----------
v : int, iterable of int, tensor, optional
The node id(s).
apply_node_func : callable
The apply node function.
func : callable
Apply function on the edge.
"""
self._apply_nodes(v, apply_node_func)
self._apply_edge_func = func
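A minimal usage sketch of the register_* hooks, assuming a small DGLGraph with a node feature 'h' and an edge feature 'w'; the feature names, the add_edges(u, v) form, and assignment through the ndata/edata views are illustrative assumptions.

import torch as th
import dgl

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 0], [1, 2])                      # assumed add_edges(u, v) form
g.ndata['h'] = th.ones(3, 4)
g.edata['w'] = th.ones(2, 4)

# Register the UDFs once ...
g.register_apply_node_func(lambda nodes: {'h': nodes.data['h'] * 2})
g.register_apply_edge_func(lambda edges: {'w': edges.data['w'] + 1})

# ... then call without arguments; func="default" resolves to the registered UDFs.
g.apply_nodes()
g.apply_edges()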
def _apply_nodes(self, v, apply_node_func="default", reduce_accum=None):
"""Internal apply nodes
def apply_nodes(self, func="default", v=ALL):
"""Apply the function on the node features.
If the function is None, the call is a no-op.
Parameters
----------
reduce_accum: dict-like
The output of reduce func
func : callable, optional
The UDF applied on the node features.
v : int, iterable of int, tensor, optional
The node id(s).
"""
if apply_node_func == "default":
apply_node_func = self._apply_node_func
if not apply_node_func:
# Skip none function call.
if reduce_accum is not None:
# write reduce result back
self.set_n_repr(reduce_accum, v)
return
# take out current node repr
curr_repr = self.get_n_repr(v)
if reduce_accum is not None:
# merge current node_repr with reduce output
curr_repr = utils.HybridDict(reduce_accum, curr_repr)
nb = NodeBatch(self, v, curr_repr)
new_repr = apply_node_func(nb)
if reduce_accum is not None:
# merge new node_repr with reduce output
reduce_accum.update(new_repr)
new_repr = reduce_accum
self.set_n_repr(new_repr, v)
def send(self, edges=ALL, message_func="default"):
"""Send messages along the given edges.
self._internal_apply_nodes(v, func)
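With the new argument order (func first, then v), applying a UDF to a subset of nodes reads as below; g and the 'h' feature carry over from the sketch above.

# Double 'h' only on nodes 0 and 2; the node ids are now the second argument.
g.apply_nodes(lambda nodes: {'h': nodes.data['h'] * 2}, [0, 2])
# A None UDF (e.g. when no default is registered) is silently skipped.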
def apply_edges(self, func="default", edges=ALL):
"""Apply the function on the edge features.
Parameters
----------
func : callable, optional
The UDF applied on the edge features.
edges : edges, optional
Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
message_func : callable
The message function.
Notes
-----
On multigraphs, if u and v are specified, then the messages will be sent
along all edges between u and v.
On multigraphs, if u and v are specified, then all the edges
between u and v will be updated.
"""
if message_func == "default":
message_func = self._message_func
assert message_func is not None
if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func)
if func == "default":
func = self._apply_edge_func
assert func is not None
if is_all(edges):
eid = ALL
......@@ -938,29 +907,29 @@ class DGLGraph(object):
dst_data = self.get_n_repr(v)
eb = EdgeBatch(self, (u, v, eid),
src_data, edge_data, dst_data)
msgs = message_func(eb)
self._msg_graph.add_edges(u, v)
self._msg_frame.append(msgs)
self.set_e_repr(func(eb), eid)
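apply_edges mirrors apply_nodes: the UDF comes first, and edges may be a (u, v) pair of endpoints or a tensor of edge ids (continuing the toy graph above).

u = th.tensor([0, 0])
v = th.tensor([1, 2])
# Zero out 'w' on the edges 0->1 and 0->2, addressed by endpoints ...
g.apply_edges(lambda edges: {'w': edges.data['w'] * 0.}, (u, v))
# ... or equivalently by explicit edge ids.
eid = g.edge_ids(u, v)
g.apply_edges(lambda edges: {'w': edges.data['w'] * 0.}, eid)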
def update_edges(self, edges=ALL, edge_func="default"):
"""Update features on the given edges.
def send(self, edges, message_func="default"):
"""Send messages along the given edges.
Parameters
----------
edges : edges, optional
Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges.
edge_func : callable
The update function.
tensor of edge ids.
message_func : callable
The message function.
Notes
-----
On multigraphs, if u and v are specified, then all the edges
between u and v will be updated.
On multigraphs, if u and v are specified, then the messages will be sent
along all edges between u and v.
"""
if edge_func == "default":
edge_func = self._edge_func
assert edge_func is not None
if message_func == "default":
message_func = self._message_func
assert message_func is not None
if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func)
if is_all(edges):
eid = ALL
......@@ -980,7 +949,9 @@ class DGLGraph(object):
dst_data = self.get_n_repr(v)
eb = EdgeBatch(self, (u, v, eid),
src_data, edge_data, dst_data)
self.set_e_repr(edge_func(eb), eid)
msgs = message_func(eb)
self._msg_graph.add_edges(u, v)
self._msg_frame.append(msgs)
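For comparison, the explicit send/recv pair on the same toy graph; the edges.src and nodes.mailbox accessors used in the UDFs are assumptions about the EdgeBatch/NodeBatch API and may differ slightly in this version.

# Phase 1: compute a message per edge from the source node feature
# (edges.src is assumed to expose source-node data, as edges.data does for edge data).
g.send((u, v), lambda edges: {'m': edges.src['h']})
# Phase 2: reduce the incoming messages on the destination nodes
# (nodes.mailbox is assumed to hold the batched messages).
g.recv(th.tensor([1, 2]), lambda nodes: {'h': nodes.mailbox['m'].sum(1)})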
def recv(self,
u,
......@@ -1008,7 +979,7 @@ class DGLGraph(object):
reduce_func = BundledReduceFunction(reduce_func)
self._batch_recv(u, reduce_func)
# optional apply nodes
self.apply_nodes(u, apply_node_func)
self.apply_nodes(apply_node_func, u)
def _batch_recv(self, v, reduce_func):
if self._msg_frame.num_rows == 0:
......@@ -1153,7 +1124,7 @@ class DGLGraph(object):
accum = executor.run()
unique_v = executor.recv_nodes
self._apply_nodes(unique_v, apply_node_func, reduce_accum=accum)
self._internal_apply_nodes(unique_v, apply_node_func, reduce_accum=accum)
def pull(self,
v,
......@@ -1179,7 +1150,7 @@ class DGLGraph(object):
uu, vv, _ = self._graph.in_edges(v)
self.send_and_recv((uu, vv), message_func, reduce_func, apply_node_func=None)
unique_v = F.unique(v.tousertensor())
self.apply_nodes(unique_v, apply_node_func)
self.apply_nodes(apply_node_func, unique_v)
def push(self,
u,
......@@ -1232,7 +1203,7 @@ class DGLGraph(object):
"update_all", self, message_func=message_func, reduce_func=reduce_func)
if executor:
new_reprs = executor.run()
self._apply_nodes(ALL, apply_node_func, reduce_accum=new_reprs)
self._internal_apply_nodes(ALL, apply_node_func, reduce_accum=new_reprs)
else:
self.send(ALL, message_func)
self.recv(ALL, reduce_func, apply_node_func)
......@@ -1474,3 +1445,32 @@ class DGLGraph(object):
else:
edges = F.Tensor(edges)
return edges[e_mask]
def _internal_apply_nodes(self, v, apply_node_func="default", reduce_accum=None):
"""Internal implementation of apply_nodes.
Parameters
----------
v : int, iterable of int, tensor
The node id(s).
apply_node_func : callable, optional
The UDF applied on the node features.
reduce_accum : dict-like, optional
The output of the reduce function; it is merged with the node features before being written back.
"""
if apply_node_func == "default":
apply_node_func = self._apply_node_func
if not apply_node_func:
# Skip none function call.
if reduce_accum is not None:
# write reduce result back
self.set_n_repr(reduce_accum, v)
return
# take out current node repr
curr_repr = self.get_n_repr(v)
if reduce_accum is not None:
# merge current node_repr with reduce output
curr_repr = utils.HybridDict(reduce_accum, curr_repr)
nb = NodeBatch(self, v, curr_repr)
new_repr = apply_node_func(nb)
if reduce_accum is not None:
# merge new node_repr with reduce output
reduce_accum.update(new_repr)
new_repr = reduce_accum
self.set_n_repr(new_repr, v)
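Conceptually, reduce_accum lets recv and update_all write the reduce output and the apply output back in a single pass: the apply UDF sees a merged view of the reduce result and the current node features, and its output is folded into the reduce result before one set_n_repr call. A toy illustration of that merge, assuming utils.HybridDict acts as a read-through merge of two dicts:

# Toy illustration only, not DGL code.
reduce_accum = {'h': 'reduced-h'}             # output of the reduce UDF
curr_repr = {'h': 'old-h', 'deg': 'old-deg'}  # current node features

# HybridDict(reduce_accum, curr_repr) is assumed to look up reduce_accum
# first and fall back to curr_repr, roughly:
merged = {**curr_repr, **reduce_accum}        # what the apply UDF reads from

new_repr = {'h': merged['h'] + '+applied'}    # what the apply UDF returns
reduce_accum.update(new_repr)                 # fold it into the reduce output
# set_n_repr(reduce_accum, v) then writes everything back once.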
......@@ -168,17 +168,17 @@ def test_batch_recv():
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear()
def test_update_edges():
def test_apply_edges():
def _upd(edges):
return {'w' : edges.data['w'] * 2}
g = generate_graph()
g.register_edge_func(_upd)
g.register_apply_edge_func(_upd)
old = g.edata['w']
g.update_edges()
g.apply_edges()
assert th.allclose(old * 2, g.edata['w'])
u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9])
g.update_edges((u, v), lambda edges : {'w' : edges.data['w'] * 0.})
g.apply_edges(lambda edges : {'w' : edges.data['w'] * 0.}, (u, v))
eid = g.edge_ids(u, v)
assert th.allclose(g.edata['w'][eid], th.zeros((6, D)))
......@@ -392,7 +392,7 @@ if __name__ == '__main__':
test_batch_setter_autograd()
test_batch_send()
test_batch_recv()
test_update_edges()
test_apply_edges()
test_update_routines()
test_reduce_0deg()
test_pull_0deg()
......