Commit 05547b37 authored by zzhang-cn's avatar zzhang-cn
Browse files

works now with tensor and nn

parent c2bdedc0
import networkx as nx
from networkx.classes.graph import Graph
from networkx.classes.digraph import DiGraph
# TODO: make representation numpy/tensor from pytorch
# TODO: make message/update functions pytorch functions
# TODO: loss functions and training
class mx_Graph(DiGraph):
def __init__(self, *args, **kargs):
super(mx_Graph, self).__init__(*args, **kargs)
self.set_msg_func()
self.set_update_func()
self.set_readout_func()
self.init_reprs()
def init_reprs(self, h_init=None):
for n in self.nodes:
self.set_repr(n, h_init)
def set_repr(self, u, h_u, name='h'):
assert u in self.nodes
kwarg = {name: h_u}
self.add_node(u, **kwarg)
def get_repr(self, u, name='h'):
assert u in self.nodes
return self.nodes[u][name]
def set_msg_func(self, func=None, u=None):
"""Function that gathers messages from neighbors"""
def _default_msg_func(u):
assert u in self.nodes
msg_gathered = 0
for v in self.adj[u]:
x = self.get_repr(v)
if x is not None:
msg_gathered += x
return msg_gathered
# TODO: per node message function
# TODO: 'sum' should be a separate function
if func == None:
self.msg_func = _default_msg_func
else:
self.msg_func = func
def set_update_func(self, func=None, u=None):
"""
Update function upon receiving an aggregate
message from a node's neighbor
"""
def _default_update_func(u, m):
if self.nodes[u]['h'] is None:
h_new = m
else:
h_new = self.nodes[u]['h'] + m
self.set_repr(u, h_new)
# TODO: per node update function
if func == None:
self.update_func = _default_update_func
else:
self.update_func = func
def set_readout_func(self, func=None):
"""Readout function of the whole graph"""
def _default_readout_func():
readout = 0
for n in self.nodes:
readout += self.nodes[n]['h']
return readout
if func == None:
self.readout_func = _default_readout_func
else:
self.readout_func = func
def readout(self):
return self.readout_func()
def update_to(self, u):
"""Pull messages from 1-step away neighbors of u"""
assert u in self.nodes
m = self.msg_func(u=u)
self.update_func(u, m)
def update_from(self, u):
"""Update u's 1-step away neighbors"""
assert u in self.nodes
# TODO: this asks v to pull from nodes other than
# TODO: u, is this a good thing?
for v in self.adj[u]:
self.update_to(v)
def print_all(self):
for n in self.nodes:
print(n, self.nodes[n])
print()
if __name__ == '__main__':
tg = nx.path_graph(10)
g = mx_Graph(tg)
g.print_all()
tr = nx.balanced_tree(2, 3)
m_tr = mx_Graph(tr)
m_tr.print_all()
g = mx_Graph(nx.path_graph(3))
for n in g:
g.set_repr(n, int(n) + 10)
g.print_all()
print(g.readout())
print("before update:\t", g.nodes[0])
g.update_to(0)
print('after update:\t', g.nodes[0])
g.print_all()
print(g.readout())
g = mx_Graph(nx.bfs_tree(nx.path_graph(3), 0))
g.set_repr(0, 10)
g.print_all()
g.update_from(0)
g.print_all()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment