Commit 8e2e68df authored by Minjie Wang's avatar Minjie Wang
Browse files

fix bugs for topdown.py models

parent 6858ed54
......@@ -159,7 +159,7 @@ class DGLGraph(DiGraph):
"""
self.readout_func = readout_func
def readout(self, nodes='all', edges='all'):
def readout(self, nodes='all', edges='all', **kwargs):
"""Trigger the readout function on the specified nodes/edges.
Parameters
......@@ -168,15 +168,17 @@ class DGLGraph(DiGraph):
The nodes to get reprs from.
edges : str, pair of nodes, pair of containers or pair of tensors
The edges to get reprs from.
kwargs : keyword arguments, optional
Arguments for the readout function.
"""
nodes = self._nodes_or_all(nodes)
edges = self._nodes_or_all(nodes)
edges = self._edges_or_all(edges)
assert self.readout_func is not None, \
"Readout function is not registered."
# TODO(minjie): tensorize following loop.
nstates = [self.nodes[n] for n in nodes]
estates = [self.edges[e] for e in edges]
return self.readout_func(nstates, estates)
return self.readout_func(nstates, estates, **kwargs)
def sendto(self, u, v):
"""Trigger the message function on edge u->v
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment