Commit 174c1d55 authored by zzhang-cn's avatar zzhang-cn
Browse files

add edge repr func; modify mgcn.py

parent eb507e4f
...@@ -53,6 +53,7 @@ class EdgeUpdateModule(nn.Module): ...@@ -53,6 +53,7 @@ class EdgeUpdateModule(nn.Module):
new_he = self.net1(src['hv']) + self.net2(dst['hv']) + self.net3(edge['he']) new_he = self.net1(src['hv']) + self.net2(dst['hv']) + self.net3(edge['he'])
return {'he' : new_he} return {'he' : new_he}
# TODO: we don't need this one anymore
class EdgeModule(nn.Module): class EdgeModule(nn.Module):
def __init__(self, he_dims): def __init__(self, he_dims):
# use a flag to trigger either message module or edge update module. # use a flag to trigger either message module or edge update module.
...@@ -70,7 +71,8 @@ class EdgeModule(nn.Module): ...@@ -70,7 +71,8 @@ class EdgeModule(nn.Module):
def train(g): def train(g):
# TODO(minjie): finish the complete training algorithm. # TODO(minjie): finish the complete training algorithm.
g = dgl.DGLGraph(g) g = dgl.DGLGraph(g)
g.register_message_func(EdgeModule()) g.register_message_func(MessageModule())
g.register_edge_func(EdgeUpdateModule())
g.register_update_func(NodeUpdateModule()) g.register_update_func(NodeUpdateModule())
# TODO(minjie): init hv and he # TODO(minjie): init hv and he
num_iter = 10 num_iter = 10
...@@ -78,4 +80,4 @@ def train(g): ...@@ -78,4 +80,4 @@ def train(g):
# The first call triggers message function and update all the nodes. # The first call triggers message function and update all the nodes.
g.update_all() g.update_all()
# The second sendall updates all the edge features. # The second sendall updates all the edge features.
g.send_all() # g.send_all()
...@@ -10,8 +10,10 @@ from dgl.backend import Tensor ...@@ -10,8 +10,10 @@ from dgl.backend import Tensor
import dgl.utils as utils import dgl.utils as utils
__MSG__ = "__msg__" __MSG__ = "__msg__"
__REPR__ = "__repr__" __E_REPR__ = "__e_repr__"
__N_REPR__ = "__n_repr__"
__MFUNC__ = "__mfunc__" __MFUNC__ = "__mfunc__"
__EFUNC__ = "__efunc__"
__UFUNC__ = "__ufunc__" __UFUNC__ = "__ufunc__"
class DGLGraph(DiGraph): class DGLGraph(DiGraph):
...@@ -30,6 +32,7 @@ class DGLGraph(DiGraph): ...@@ -30,6 +32,7 @@ class DGLGraph(DiGraph):
super(DGLGraph, self).__init__(graph_data, **attr) super(DGLGraph, self).__init__(graph_data, **attr)
self.m_func = None self.m_func = None
self.u_func = None self.u_func = None
self.e_func = None
self.readout_func = None self.readout_func = None
def init_reprs(self, h_init=None): def init_reprs(self, h_init=None):
...@@ -38,14 +41,14 @@ class DGLGraph(DiGraph): ...@@ -38,14 +41,14 @@ class DGLGraph(DiGraph):
for n in self.nodes: for n in self.nodes:
self.set_repr(n, h_init) self.set_repr(n, h_init)
def set_repr(self, u, h_u, name=__REPR__): def set_repr(self, u, h_u, name=__N_REPR__):
print("[DEPRECATED]: please directly set node attrs " print("[DEPRECATED]: please directly set node attrs "
"(e.g. g.nodes[node]['x'] = val).") "(e.g. g.nodes[node]['x'] = val).")
assert u in self.nodes assert u in self.nodes
kwarg = {name: h_u} kwarg = {name: h_u}
self.add_node(u, **kwarg) self.add_node(u, **kwarg)
def get_repr(self, u, name=__REPR__): def get_repr(self, u, name=__N_REPR__):
print("[DEPRECATED]: please directly get node attrs " print("[DEPRECATED]: please directly get node attrs "
"(e.g. g.nodes[node]['x']).") "(e.g. g.nodes[node]['x']).")
assert u in self.nodes assert u in self.nodes
...@@ -58,7 +61,7 @@ class DGLGraph(DiGraph): ...@@ -58,7 +61,7 @@ class DGLGraph(DiGraph):
(node_reprs, node_reprs, edge_reprs) -> edge_reprs (node_reprs, node_reprs, edge_reprs) -> edge_reprs
It computes the new edge representations (the same concept as messages) It computes the representation of a message
using the representations of the source node, target node and the edge using the representations of the source node, target node and the edge
itself. All node_reprs and edge_reprs are dictionaries. itself. All node_reprs and edge_reprs are dictionaries.
...@@ -93,6 +96,48 @@ class DGLGraph(DiGraph): ...@@ -93,6 +96,48 @@ class DGLGraph(DiGraph):
for e in edges: for e in edges:
self.edges[e][__MFUNC__] = message_func self.edges[e][__MFUNC__] = message_func
def register_edge_func(self, edge_func, edges='all', batchable=False):
"""Register computation on edges.
The edge function should be compatible with following signature:
(node_reprs, node_reprs, edge_reprs) -> edge_reprs
It computes the new edge representations (the same concept as messages)
using the representations of the source node, target node and the edge
itself. All node_reprs and edge_reprs are dictionaries.
Parameters
----------
edge_func : callable
Message function on the edge.
edges : str, pair of nodes, pair of containers, pair of tensors
The edges for which the message function is registered. Default is
registering for all the edges. Registering for multiple edges is
supported.
batchable : bool
Whether the provided message function allows batch computing.
Examples
--------
Register for all edges.
>>> g.register_edge_func(efunc)
Register for a specific edge.
>>> g.register_edge_func(efunc, (u, v))
Register for multiple edges.
>>> u = [u1, u2, u3, ...]
>>> v = [v1, v2, v3, ...]
>>> g.register_edge_func(mfunc, (u, v))
"""
if edges == 'all':
self.e_func = edge_func
else:
for e in edges:
self.edges[e][__EFUNC__] = edge_func
def register_update_func(self, update_func, nodes='all', batchable=False): def register_update_func(self, update_func, nodes='all', batchable=False):
"""Register computation on nodes. """Register computation on nodes.
...@@ -198,6 +243,24 @@ class DGLGraph(DiGraph): ...@@ -198,6 +243,24 @@ class DGLGraph(DiGraph):
m = f_msg(self.nodes[uu], self.nodes[vv], self.edges[uu, vv]) m = f_msg(self.nodes[uu], self.nodes[vv], self.edges[uu, vv])
self.edges[uu, vv][__MSG__] = m self.edges[uu, vv][__MSG__] = m
def update_edge(self, u, v):
"""Update representation on edge u->v
Parameters
----------
u : node, container or tensor
The source node(s).
v : node, container or tensor
The destination node(s).
"""
# TODO(minjie): tensorize the loop.
for uu, vv in utils.edge_iter(u, v):
f_edge = self.edges[uu, vv].get(__EFUNC__, self.m_func)
assert f_edge is not None, \
"edge function not registered for edge (%s->%s)" % (uu, vv)
m = f_edge(self.nodes[uu], self.nodes[vv], self.edges[uu, vv])
self.edges[uu, vv][__E_REPR__] = m
def recvfrom(self, u, preds=None): def recvfrom(self, u, preds=None):
"""Trigger the update function on node u. """Trigger the update function on node u.
...@@ -292,6 +355,9 @@ class DGLGraph(DiGraph): ...@@ -292,6 +355,9 @@ class DGLGraph(DiGraph):
v = [vv for _, vv in self.edges] v = [vv for _, vv in self.edges]
self.sendto(u, v) self.sendto(u, v)
self.recvfrom(list(self.nodes())) self.recvfrom(list(self.nodes()))
# TODO(zz): this is a hack
if self.e_func:
self.update_edge(u, v)
def propagate(self, iterator='bfs', **kwargs): def propagate(self, iterator='bfs', **kwargs):
"""Propagate messages and update nodes using iterator. """Propagate messages and update nodes using iterator.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment