Commit 174c1d55 authored by zzhang-cn's avatar zzhang-cn
Browse files

add edge repr func; modify mgcn.py

parent eb507e4f
......@@ -53,6 +53,7 @@ class EdgeUpdateModule(nn.Module):
new_he = self.net1(src['hv']) + self.net2(dst['hv']) + self.net3(edge['he'])
return {'he' : new_he}
# TODO: we don't need this one anymore
class EdgeModule(nn.Module):
def __init__(self, he_dims):
# use a flag to trigger either message module or edge update module.
......@@ -70,7 +71,8 @@ class EdgeModule(nn.Module):
def train(g):
# TODO(minjie): finish the complete training algorithm.
g = dgl.DGLGraph(g)
g.register_message_func(EdgeModule())
g.register_message_func(MessageModule())
g.register_edge_func(EdgeUpdateModule())
g.register_update_func(NodeUpdateModule())
# TODO(minjie): init hv and he
num_iter = 10
......@@ -78,4 +80,4 @@ def train(g):
# The first call triggers message function and update all the nodes.
g.update_all()
# The second sendall updates all the edge features.
g.send_all()
# g.send_all()
......@@ -10,8 +10,10 @@ from dgl.backend import Tensor
import dgl.utils as utils
__MSG__ = "__msg__"
__REPR__ = "__repr__"
__E_REPR__ = "__e_repr__"
__N_REPR__ = "__n_repr__"
__MFUNC__ = "__mfunc__"
__EFUNC__ = "__efunc__"
__UFUNC__ = "__ufunc__"
class DGLGraph(DiGraph):
......@@ -30,6 +32,7 @@ class DGLGraph(DiGraph):
super(DGLGraph, self).__init__(graph_data, **attr)
self.m_func = None
self.u_func = None
self.e_func = None
self.readout_func = None
def init_reprs(self, h_init=None):
......@@ -38,14 +41,14 @@ class DGLGraph(DiGraph):
for n in self.nodes:
self.set_repr(n, h_init)
def set_repr(self, u, h_u, name=__REPR__):
def set_repr(self, u, h_u, name=__N_REPR__):
print("[DEPRECATED]: please directly set node attrs "
"(e.g. g.nodes[node]['x'] = val).")
assert u in self.nodes
kwarg = {name: h_u}
self.add_node(u, **kwarg)
def get_repr(self, u, name=__REPR__):
def get_repr(self, u, name=__N_REPR__):
print("[DEPRECATED]: please directly get node attrs "
"(e.g. g.nodes[node]['x']).")
assert u in self.nodes
......@@ -58,7 +61,7 @@ class DGLGraph(DiGraph):
(node_reprs, node_reprs, edge_reprs) -> edge_reprs
It computes the new edge representations (the same concept as messages)
It computes the representation of a message
using the representations of the source node, target node and the edge
itself. All node_reprs and edge_reprs are dictionaries.
......@@ -93,6 +96,48 @@ class DGLGraph(DiGraph):
for e in edges:
self.edges[e][__MFUNC__] = message_func
def register_edge_func(self, edge_func, edges='all', batchable=False):
"""Register computation on edges.
The edge function should be compatible with following signature:
(node_reprs, node_reprs, edge_reprs) -> edge_reprs
It computes the new edge representations (the same concept as messages)
using the representations of the source node, target node and the edge
itself. All node_reprs and edge_reprs are dictionaries.
Parameters
----------
edge_func : callable
Message function on the edge.
edges : str, pair of nodes, pair of containers, pair of tensors
The edges for which the message function is registered. Default is
registering for all the edges. Registering for multiple edges is
supported.
batchable : bool
Whether the provided message function allows batch computing.
Examples
--------
Register for all edges.
>>> g.register_edge_func(efunc)
Register for a specific edge.
>>> g.register_edge_func(efunc, (u, v))
Register for multiple edges.
>>> u = [u1, u2, u3, ...]
>>> v = [v1, v2, v3, ...]
>>> g.register_edge_func(mfunc, (u, v))
"""
if edges == 'all':
self.e_func = edge_func
else:
for e in edges:
self.edges[e][__EFUNC__] = edge_func
def register_update_func(self, update_func, nodes='all', batchable=False):
"""Register computation on nodes.
......@@ -198,6 +243,24 @@ class DGLGraph(DiGraph):
m = f_msg(self.nodes[uu], self.nodes[vv], self.edges[uu, vv])
self.edges[uu, vv][__MSG__] = m
def update_edge(self, u, v):
"""Update representation on edge u->v
Parameters
----------
u : node, container or tensor
The source node(s).
v : node, container or tensor
The destination node(s).
"""
# TODO(minjie): tensorize the loop.
for uu, vv in utils.edge_iter(u, v):
f_edge = self.edges[uu, vv].get(__EFUNC__, self.m_func)
assert f_edge is not None, \
"edge function not registered for edge (%s->%s)" % (uu, vv)
m = f_edge(self.nodes[uu], self.nodes[vv], self.edges[uu, vv])
self.edges[uu, vv][__E_REPR__] = m
def recvfrom(self, u, preds=None):
"""Trigger the update function on node u.
......@@ -292,6 +355,9 @@ class DGLGraph(DiGraph):
v = [vv for _, vv in self.edges]
self.sendto(u, v)
self.recvfrom(list(self.nodes()))
# TODO(zz): this is a hack
if self.e_func:
self.update_edge(u, v)
def propagate(self, iterator='bfs', **kwargs):
"""Propagate messages and update nodes using iterator.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment