Commit 62a39e5b authored by zzhang-cn's avatar zzhang-cn
Browse files

subclass of networkx

parent 65590977
import networkx as nx
from networkx.classes.graph import Graph
# TODO: make representation numpy/tensor from pytorch
# TODO: make message/update functions pytorch functions
# TODO: loss functions and training
class mx_Graph(Graph):
def __init__(self, *args, **kargs):
super(mx_Graph, self).__init__(*args, **kargs)
self.set_msg_func()
self.set_update_func()
self.set_readout_func()
self.init_reprs()
def init_reprs(self, h_init=None):
for n in self.nodes:
self.set_repr(n, h_init)
def set_repr(self, u, h_u, name=None):
assert u in self.nodes
if name == None:
self.add_node(u, h=h_u)
else:
self.add_node(u, name=h_u)
def get_repr(self, u, name=None):
assert u in self.nodes
if name == None:
return self.nodes[u]['h']
else:
return self.nodes[u][name]
def set_msg_func(self, func=None, u=None):
"""Function that gathers messages from neighbors"""
def _default_msg_func(u):
assert u in self.nodes
msg_gathered = 0
for v in self.adj[u]:
x = self.get_repr(v)
if x is not None:
msg_gathered += x
return msg_gathered
# TODO: per node message function
# TODO: 'sum' should be a separate function
if func == None:
self.msg_func = _default_msg_func
else:
self.msg_func = func
def set_update_func(self, func=None, u=None):
"""
Update function upon receiving an aggregate
message from a node's neighbor
"""
def _default_update_func(u, m):
h_new = self.nodes[u]['h'] + m
self.set_repr(u, h_new)
# TODO: per node update function
if func == None:
self.update_func = _default_update_func
else:
self.update_func = func
def set_readout_func(self, func=None):
"""Readout function of the whole graph"""
def _default_readout_func():
readout = 0
for n in self.nodes:
readout += self.nodes[n]['h']
return readout
if func == None:
self.readout_func = _default_readout_func
else:
self.readout_func = func
def readout(self):
return self.readout_func()
def update_to(self, u):
"""Pull messages from 1-step away neighbors of u"""
assert u in self.nodes
m = self.msg_func(u=u)
self.update_func(u, m)
def update_from(self, u):
"""Update u's 1-step away neighbors"""
assert u in self.nodes
# TODO: this asks v to pull from nodes other than
# TODO: u, is this a good thing?
for v in self.adj[u]:
self.update_to(v)
def print_all(self):
for n in self.nodes:
print(n, self.nodes[n])
print()
if __name__ == '__main__':
tg = nx.path_graph(10)
g = mx_Graph(tg)
g.print_all()
tr = nx.balanced_tree(2, 3)
m_tr = mx_Graph(tr)
m_tr.print_all()
g = mx_Graph(nx.path_graph(3))
for n in g:
g.set_repr(n, int(n) + 10)
g.print_all()
print(g.readout())
print("before update:\t", g.nodes[0])
g.update_to(0)
print('after update:\t', g.nodes[0])
g.print_all()
print(g.readout())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment