"src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "08b60eb1628ef91a29a14de33b046f8f19808531"
Unverified Commit b355d1ed authored by Minjie Wang's avatar Minjie Wang Committed by GitHub
Browse files

[API] Apply nodes & apply edges (#117)

parent 1a2b306f
...@@ -68,7 +68,7 @@ class GCN(nn.Module): ...@@ -68,7 +68,7 @@ class GCN(nn.Module):
self.g.apply_nodes(apply_node_func= self.g.apply_nodes(apply_node_func=
lambda nodes: {'h': self.dropout(nodes.data['h'])}) lambda nodes: {'h': self.dropout(nodes.data['h'])})
self.g.update_all(gcn_msg, gcn_reduce, layer) self.g.update_all(gcn_msg, gcn_reduce, layer)
return self.g.pop_n_repr('h') return self.g.ndata.pop('h')
def main(args): def main(args):
# load and preprocess dataset # load and preprocess dataset
......
...@@ -58,8 +58,8 @@ class DGLGraph(object): ...@@ -58,8 +58,8 @@ class DGLGraph(object):
# registered functions # registered functions
self._message_func = None self._message_func = None
self._reduce_func = None self._reduce_func = None
self._edge_func = None
self._apply_node_func = None self._apply_node_func = None
self._apply_edge_func = None
def add_nodes(self, num, reprs=None): def add_nodes(self, num, reprs=None):
"""Add nodes. """Add nodes.
...@@ -815,110 +815,79 @@ class DGLGraph(object): ...@@ -815,110 +815,79 @@ class DGLGraph(object):
""" """
return self._edge_frame.pop(key) return self._edge_frame.pop(key)
def register_edge_func(self, edge_func): def register_message_func(self, func):
"""Register global edge update function.
Parameters
----------
edge_func : callable
Message function on the edge.
"""
self._edge_func = edge_func
def register_message_func(self, message_func):
"""Register global message function. """Register global message function.
Parameters Parameters
---------- ----------
message_func : callable func : callable
Message function on the edge. Message function on the edge.
""" """
self._message_func = message_func self._message_func = func
def register_reduce_func(self, reduce_func): def register_reduce_func(self, func):
"""Register global message reduce function. """Register global message reduce function.
Parameters Parameters
---------- ----------
reduce_func : str or callable func : str or callable
Reduce function on incoming edges. Reduce function on incoming edges.
""" """
self._reduce_func = reduce_func self._reduce_func = func
def register_apply_node_func(self, apply_node_func): def register_apply_node_func(self, func):
"""Register global node apply function. """Register global node apply function.
Parameters Parameters
---------- ----------
apply_node_func : callable func : callable
Apply function on the node. Apply function on the node.
""" """
self._apply_node_func = apply_node_func self._apply_node_func = func
def apply_nodes(self, v=ALL, apply_node_func="default"): def register_apply_edge_func(self, func):
"""Apply the function on node representations. """Register global edge apply function.
Applying a None function will be ignored.
Parameters Parameters
---------- ----------
v : int, iterable of int, tensor, optional edge_func : callable
The node id(s). Apply function on the edge.
apply_node_func : callable
The apply node function.
""" """
self._apply_nodes(v, apply_node_func) self._apply_edge_func = func
def _apply_nodes(self, v, apply_node_func="default", reduce_accum=None): def apply_nodes(self, func="default", v=ALL):
"""Internal apply nodes """Apply the function on the node features.
Applying a None function will be ignored.
Parameters Parameters
---------- ----------
reduce_accum: dict-like func : callable, optional
The output of reduce func The UDF applied on the node features.
v : int, iterable of int, tensor, optional
The node id(s).
""" """
if apply_node_func == "default": self._internal_apply_nodes(v, func)
apply_node_func = self._apply_node_func
if not apply_node_func: def apply_edges(self, func="default", edges=ALL):
# Skip none function call. """Apply the function on the edge features.
if reduce_accum is not None:
# write reduce result back
self.set_n_repr(reduce_accum, v)
return
# take out current node repr
curr_repr = self.get_n_repr(v)
if reduce_accum is not None:
# merge current node_repr with reduce output
curr_repr = utils.HybridDict(reduce_accum, curr_repr)
nb = NodeBatch(self, v, curr_repr)
new_repr = apply_node_func(nb)
if reduce_accum is not None:
# merge new node_repr with reduce output
reduce_accum.update(new_repr)
new_repr = reduce_accum
self.set_n_repr(new_repr, v)
def send(self, edges=ALL, message_func="default"):
"""Send messages along the given edges.
Parameters Parameters
---------- ----------
func : callable, optional
The UDF applied on the edge features.
edges : edges, optional edges : edges, optional
Edges can be a pair of endpoint nodes (u, v), or a Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges. tensor of edge ids. The default value is all the edges.
message_func : callable
The message function.
Notes Notes
----- -----
On multigraphs, if u and v are specified, then the messages will be sent On multigraphs, if u and v are specified, then all the edges
along all edges between u and v. between u and v will be updated.
""" """
if message_func == "default": if func == "default":
message_func = self._message_func func = self._apply_edge_func
assert message_func is not None assert func is not None
if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func)
if is_all(edges): if is_all(edges):
eid = ALL eid = ALL
...@@ -938,29 +907,29 @@ class DGLGraph(object): ...@@ -938,29 +907,29 @@ class DGLGraph(object):
dst_data = self.get_n_repr(v) dst_data = self.get_n_repr(v)
eb = EdgeBatch(self, (u, v, eid), eb = EdgeBatch(self, (u, v, eid),
src_data, edge_data, dst_data) src_data, edge_data, dst_data)
msgs = message_func(eb) self.set_e_repr(func(eb), eid)
self._msg_graph.add_edges(u, v)
self._msg_frame.append(msgs)
def update_edges(self, edges=ALL, edge_func="default"): def send(self, edges, message_func="default"):
"""Update features on the given edges. """Send messages along the given edges.
Parameters Parameters
---------- ----------
edges : edges, optional edges : edges, optional
Edges can be a pair of endpoint nodes (u, v), or a Edges can be a pair of endpoint nodes (u, v), or a
tensor of edge ids. The default value is all the edges. tensor of edge ids.
edge_func : callable message_func : callable
The update function. The message function.
Notes Notes
----- -----
On multigraphs, if u and v are specified, then all the edges On multigraphs, if u and v are specified, then the messages will be sent
between u and v will be updated. along all edges between u and v.
""" """
if edge_func == "default": if message_func == "default":
edge_func = self._edge_func message_func = self._message_func
assert edge_func is not None assert message_func is not None
if isinstance(message_func, (tuple, list)):
message_func = BundledMessageFunction(message_func)
if is_all(edges): if is_all(edges):
eid = ALL eid = ALL
...@@ -980,7 +949,9 @@ class DGLGraph(object): ...@@ -980,7 +949,9 @@ class DGLGraph(object):
dst_data = self.get_n_repr(v) dst_data = self.get_n_repr(v)
eb = EdgeBatch(self, (u, v, eid), eb = EdgeBatch(self, (u, v, eid),
src_data, edge_data, dst_data) src_data, edge_data, dst_data)
self.set_e_repr(edge_func(eb), eid) msgs = message_func(eb)
self._msg_graph.add_edges(u, v)
self._msg_frame.append(msgs)
def recv(self, def recv(self,
u, u,
...@@ -1008,7 +979,7 @@ class DGLGraph(object): ...@@ -1008,7 +979,7 @@ class DGLGraph(object):
reduce_func = BundledReduceFunction(reduce_func) reduce_func = BundledReduceFunction(reduce_func)
self._batch_recv(u, reduce_func) self._batch_recv(u, reduce_func)
# optional apply nodes # optional apply nodes
self.apply_nodes(u, apply_node_func) self.apply_nodes(apply_node_func, u)
def _batch_recv(self, v, reduce_func): def _batch_recv(self, v, reduce_func):
if self._msg_frame.num_rows == 0: if self._msg_frame.num_rows == 0:
...@@ -1153,7 +1124,7 @@ class DGLGraph(object): ...@@ -1153,7 +1124,7 @@ class DGLGraph(object):
accum = executor.run() accum = executor.run()
unique_v = executor.recv_nodes unique_v = executor.recv_nodes
self._apply_nodes(unique_v, apply_node_func, reduce_accum=accum) self._internal_apply_nodes(unique_v, apply_node_func, reduce_accum=accum)
def pull(self, def pull(self,
v, v,
...@@ -1179,7 +1150,7 @@ class DGLGraph(object): ...@@ -1179,7 +1150,7 @@ class DGLGraph(object):
uu, vv, _ = self._graph.in_edges(v) uu, vv, _ = self._graph.in_edges(v)
self.send_and_recv((uu, vv), message_func, reduce_func, apply_node_func=None) self.send_and_recv((uu, vv), message_func, reduce_func, apply_node_func=None)
unique_v = F.unique(v.tousertensor()) unique_v = F.unique(v.tousertensor())
self.apply_nodes(unique_v, apply_node_func) self.apply_nodes(apply_node_func, unique_v)
def push(self, def push(self,
u, u,
...@@ -1232,7 +1203,7 @@ class DGLGraph(object): ...@@ -1232,7 +1203,7 @@ class DGLGraph(object):
"update_all", self, message_func=message_func, reduce_func=reduce_func) "update_all", self, message_func=message_func, reduce_func=reduce_func)
if executor: if executor:
new_reprs = executor.run() new_reprs = executor.run()
self._apply_nodes(ALL, apply_node_func, reduce_accum=new_reprs) self._internal_apply_nodes(ALL, apply_node_func, reduce_accum=new_reprs)
else: else:
self.send(ALL, message_func) self.send(ALL, message_func)
self.recv(ALL, reduce_func, apply_node_func) self.recv(ALL, reduce_func, apply_node_func)
...@@ -1474,3 +1445,32 @@ class DGLGraph(object): ...@@ -1474,3 +1445,32 @@ class DGLGraph(object):
else: else:
edges = F.Tensor(edges) edges = F.Tensor(edges)
return edges[e_mask] return edges[e_mask]
def _internal_apply_nodes(self, v, apply_node_func="default", reduce_accum=None):
"""Internal apply nodes
Parameters
----------
reduce_accum: dict-like
The output of reduce func
"""
if apply_node_func == "default":
apply_node_func = self._apply_node_func
if not apply_node_func:
# Skip none function call.
if reduce_accum is not None:
# write reduce result back
self.set_n_repr(reduce_accum, v)
return
# take out current node repr
curr_repr = self.get_n_repr(v)
if reduce_accum is not None:
# merge current node_repr with reduce output
curr_repr = utils.HybridDict(reduce_accum, curr_repr)
nb = NodeBatch(self, v, curr_repr)
new_repr = apply_node_func(nb)
if reduce_accum is not None:
# merge new node_repr with reduce output
reduce_accum.update(new_repr)
new_repr = reduce_accum
self.set_n_repr(new_repr, v)
...@@ -168,17 +168,17 @@ def test_batch_recv(): ...@@ -168,17 +168,17 @@ def test_batch_recv():
assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)}) assert(reduce_msg_shapes == {(1, 3, D), (3, 1, D)})
reduce_msg_shapes.clear() reduce_msg_shapes.clear()
def test_update_edges(): def test_apply_edges():
def _upd(edges): def _upd(edges):
return {'w' : edges.data['w'] * 2} return {'w' : edges.data['w'] * 2}
g = generate_graph() g = generate_graph()
g.register_edge_func(_upd) g.register_apply_edge_func(_upd)
old = g.edata['w'] old = g.edata['w']
g.update_edges() g.apply_edges()
assert th.allclose(old * 2, g.edata['w']) assert th.allclose(old * 2, g.edata['w'])
u = th.tensor([0, 0, 0, 4, 5, 6]) u = th.tensor([0, 0, 0, 4, 5, 6])
v = th.tensor([1, 2, 3, 9, 9, 9]) v = th.tensor([1, 2, 3, 9, 9, 9])
g.update_edges((u, v), lambda edges : {'w' : edges.data['w'] * 0.}) g.apply_edges(lambda edges : {'w' : edges.data['w'] * 0.}, (u, v))
eid = g.edge_ids(u, v) eid = g.edge_ids(u, v)
assert th.allclose(g.edata['w'][eid], th.zeros((6, D))) assert th.allclose(g.edata['w'][eid], th.zeros((6, D)))
...@@ -392,7 +392,7 @@ if __name__ == '__main__': ...@@ -392,7 +392,7 @@ if __name__ == '__main__':
test_batch_setter_autograd() test_batch_setter_autograd()
test_batch_send() test_batch_send()
test_batch_recv() test_batch_recv()
test_update_edges() test_apply_edges()
test_update_routines() test_update_routines()
test_reduce_0deg() test_reduce_0deg()
test_pull_0deg() test_pull_0deg()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment