Commit 8e2e68df authored by Minjie Wang's avatar Minjie Wang
Browse files

fix bugs for topdown.py models

parent 6858ed54
...@@ -159,7 +159,7 @@ class DGLGraph(DiGraph): ...@@ -159,7 +159,7 @@ class DGLGraph(DiGraph):
""" """
self.readout_func = readout_func self.readout_func = readout_func
def readout(self, nodes='all', edges='all'): def readout(self, nodes='all', edges='all', **kwargs):
"""Trigger the readout function on the specified nodes/edges. """Trigger the readout function on the specified nodes/edges.
Parameters Parameters
...@@ -168,15 +168,17 @@ class DGLGraph(DiGraph): ...@@ -168,15 +168,17 @@ class DGLGraph(DiGraph):
The nodes to get reprs from. The nodes to get reprs from.
edges : str, pair of nodes, pair of containers or pair of tensors edges : str, pair of nodes, pair of containers or pair of tensors
The edges to get reprs from. The edges to get reprs from.
kwargs : keyword arguments, optional
Arguments for the readout function.
""" """
nodes = self._nodes_or_all(nodes) nodes = self._nodes_or_all(nodes)
edges = self._nodes_or_all(nodes) edges = self._edges_or_all(edges)
assert self.readout_func is not None, \ assert self.readout_func is not None, \
"Readout function is not registered." "Readout function is not registered."
# TODO(minjie): tensorize following loop. # TODO(minjie): tensorize following loop.
nstates = [self.nodes[n] for n in nodes] nstates = [self.nodes[n] for n in nodes]
estates = [self.edges[e] for e in edges] estates = [self.edges[e] for e in edges]
return self.readout_func(nstates, estates) return self.readout_func(nstates, estates, **kwargs)
def sendto(self, u, v): def sendto(self, u, v):
"""Trigger the message function on edge u->v """Trigger the message function on edge u->v
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment