"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "e589bdb956c9be33fc73e1d4614d8d1c1ad95544"
Commit 2ed3989c authored by zzhang-cn's avatar zzhang-cn
Browse files

git repo reorg

parent 75a43e4a
import networkx as nx
import torch as T
import torch.nn as NN
from util import *
from networkx.classes.digraph import DiGraph
# NOTE(review): this file is a corrupted web scrape — indentation is gone and
# lines from two revisions are interleaved. Code below is kept byte-identical;
# it is NOT runnable as-is and needs manual re-indentation/untangling.
# DiGraph: a torch Module wrapping a networkx graph; stores per-node 'state'
# vectors and per-node/edge 'tag' feature tensors, and keeps registries of
# message/update functions consumed by step().
class DiGraph(NN.Module):
'''
Reserved attributes:
* state: node state vectors during message passing iterations
edges does not have "state vectors"; the "state" field is reserved for storing messages
* tag: node-/edge-specific feature tensors or other data
'''
def __init__(self, graph):
NN.Module.__init__(self)
# graph: a networkx graph instance wrapped by this module
self.G = graph
# registries of (bunch, func, batched) triples used by step()
self.message_funcs = []
self.update_funcs = []
'''
Defult modules: this is Pytorch specific
- MessageModule: copy
- UpdateModule: vanilla RNN
- ReadoutModule: bag of words
- ReductionModule: bag of words
'''
# NOTE(review): stray module-level imports from a second revision,
# interleaved here by the scrape.
import torch as th
import torch.nn as nn
import torch.nn.functional as F
# Resolve the 'all' sentinel to the full node list of the wrapped graph.
def _nodes_or_all(self, nodes='all'):
return self.G.nodes() if nodes == 'all' else nodes
# Default message function: the identity (copies the source state).
class DefaultMessageModule(nn.Module):
"""
Default message module:
- copy
"""
def __init__(self, *args, **kwargs):
super(DefaultMessageModule, self).__init__(*args, **kwargs)
# NOTE(review): _edges_or_all / _node_tag_name below belong to the
# DiGraph class above — scrape interleaving artifact.
def _edges_or_all(self, edges='all'):
return self.G.edges() if edges == 'all' else edges
# Identity: the message is the unmodified source state.
def forward(self, x):
return x
# Parameter-registry name for a node tag, e.g. '(3)'.
def _node_tag_name(self, v):
return '(%s)' % v
# Default node-update function: GRU cell or Linear+ReLU over the reduced
# incoming messages concatenated/combined with the current state.
class DefaultUpdateModule(nn.Module):
"""
Default update module:
- a vanilla GRU with ReLU, or GRU
"""
def __init__(self, *args, **kwargs):
super(DefaultUpdateModule, self).__init__()
# h_dims: hidden size; net_type: 'gru' or 'fwd'; n_func: number of
# distinct update networks cycled through via f_idx.
h_dims = self.h_dims = kwargs.get('h_dims', 128)
net_type = self.net_type = kwargs.get('net_type', 'fwd')
n_func = self.n_func = kwargs.get('n_func', 1)
self.f_idx = 0
self.reduce_func = DefaultReductionModule()
# NOTE(review): plain lists of sub-modules are invisible to
# nn.Module parameter collection; should be nn.ModuleList.
if net_type == 'gru':
self.net = [nn.GRUCell(h_dims, h_dims) for i in range(n_func)]
else:
self.net = [nn.Linear(2 * h_dims, h_dims) for i in range(n_func)]
# NOTE(review): _edge_tag_name belongs to DiGraph — interleaving artifact.
# Canonical (order-independent) registry name for an edge tag.
def _edge_tag_name(self, u, v):
return '(%s, %s)' % (min(u, v), max(u, v))
def forward(self, x, msgs):
# Uninitialized node state: substitute zeros shaped like a message.
if not th.is_tensor(x):
x = th.zeros_like(msgs[0])
m = self.reduce_func(msgs)
assert(self.f_idx < self.n_func)
if self.net_type == 'gru':
out = self.net[self.f_idx](m, x)
else:
_in = th.cat((m, x), 1)
out = F.relu(self.net[self.f_idx](_in))
# advance to the next update network; caller must reset_f_idx()
self.f_idx += 1
return out
# NOTE(review): zero_node_state is a DiGraph method; reset_f_idx (this
# class) is spliced into its middle by the scrape.
def zero_node_state(self, state_dims, batch_size=None, nodes='all'):
shape = (
[batch_size] + list(state_dims)
if batch_size is not None
else state_dims
)
nodes = self._nodes_or_all(nodes)
def reset_f_idx(self):
self.f_idx = 0
for v in nodes:
self.G.node[v]['state'] = tovar(T.zeros(shape))
# Default reduction: sum a list of equal-shaped tensors ("bag of words").
class DefaultReductionModule(nn.Module):
"""
Default readout:
- bag of words
"""
def __init__(self, *args, **kwargs):
super(DefaultReductionModule, self).__init__(*args, **kwargs)
# NOTE(review): the tag-management methods below belong to DiGraph;
# this class's forward() is spliced into remove_node_tag's body.
# Create per-node 'tag' Parameters, initialized by init_func, and
# register them on the module so they are trainable.
def init_node_tag_with(self, shape, init_func, dtype=T.float32, nodes='all', args=()):
nodes = self._nodes_or_all(nodes)
for v in nodes:
self.G.node[v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
self.register_parameter(self._node_tag_name(v), self.G.node[v]['tag'])
# Same as init_node_tag_with, but for per-edge 'tag' Parameters.
def init_edge_tag_with(self, shape, init_func, dtype=T.float32, edges='all', args=()):
edges = self._edges_or_all(edges)
for u, v in edges:
self.G[u][v]['tag'] = init_func(NN.Parameter(tovar(T.zeros(shape, dtype=dtype))), *args)
self.register_parameter(self._edge_tag_name(u, v), self.G[u][v]['tag'])
def remove_node_tag(self, nodes='all'):
nodes = self._nodes_or_all(nodes)
# Stack the message list and sum over the stacking dimension.
def forward(self, x_s):
out = th.stack(x_s)
out = th.sum(out, dim=0)
return out
for v in nodes:
delattr(self, self._node_tag_name(v))
del self.G.node[v]['tag']
# Default graph readout: reduce (sum) all node states into one tensor.
class DefaultReadoutModule(nn.Module):
"""
Default readout:
- bag of words
"""
def __init__(self, *args, **kwargs):
super(DefaultReadoutModule, self).__init__(*args, **kwargs)
self.reduce_func = DefaultReductionModule()
# NOTE(review): remove_edge_tag and the node/edges properties below
# belong to DiGraph; forward() is spliced into remove_edge_tag's body.
def remove_edge_tag(self, edges='all'):
edges = self._edges_or_all(edges)
# Delegate to the reduction module (sum over the node-state list).
def forward(self, x_s):
return self.reduce_func(x_s)
for u, v in edges:
delattr(self, self._edge_tag_name(u, v))
del self.G[u][v]['tag']
# Pass-through views onto the wrapped networkx graph.
@property
def node(self):
return self.G.node
@property
def edges(self):
return self.G.edges
# Yield every edge's 'tag' tensor.
def edge_tags(self):
for u, v in self.G.edges():
yield self.G[u][v]['tag']
# mx_Graph (this revision): DiGraph subclass wiring in the Default* modules
# and exposing accessors for node states/tags. Node reprs live under 'state'.
class mx_Graph(DiGraph):
'''
Functions:
- m_func: per edge (u, v), default is u['state']
- u_func: per node u, default is RNN(m, u['state'])
'''
def __init__(self, *args, **kargs):
super(mx_Graph, self).__init__(*args, **kargs)
# graph-wide defaults; per-edge/per-node overrides are stored in
# edge/node attribute dicts under 'm_func' / 'u_func'
self.m_func = DefaultMessageModule()
self.u_func = DefaultUpdateModule()
self.readout_func = DefaultReadoutModule()
self.init_reprs()
# NOTE(review): the generator accessors below are interleaved with the
# repr get/set methods by the scrape; order is not meaningful.
def node_tags(self):
for v in self.G.nodes():
yield self.G.node[v]['tag']
# Reset every node representation to h_init (default None).
def init_reprs(self, h_init=None):
for n in self.nodes:
self.set_repr(n, h_init)
def states(self):
for v in self.G.nodes():
yield self.G.node[v]['state']
# Store h_u under node u's attribute `name` (default 'state').
def set_repr(self, u, h_u, name='state'):
assert u in self.nodes
kwarg = {name: h_u}
self.add_node(u, **kwarg)
def named_edge_tags(self):
for u, v in self.G.edges():
yield ((u, v), self.G[u][v]['tag'])
# Fetch node u's attribute `name` (default 'state').
def get_repr(self, u, name='state'):
assert u in self.nodes
return self.nodes[u][name]
def named_node_tags(self):
for v in self.G.nodes():
yield (v, self.G.node[v]['tag'])
# Resolve 'all' sentinels against this graph (not self.G here).
def _nodes_or_all(self, nodes='all'):
return self.nodes() if nodes == 'all' else nodes
def named_states(self):
for v in self.G.nodes():
yield (v, self.G.node[v]['state'])
def _edges_or_all(self, edges='all'):
return self.edges() if edges == 'all' else edges
# Registration API. NOTE(review): each method body below contains BOTH the
# old registry-based implementation and the new attribute-based one,
# interleaved by the scrape (the '......@@' lines are diff hunk headers).
def register_message_func(self, message_func, edges='all', batched=False):
'''
......@@ -106,7 +120,11 @@ class DiGraph(NN.Module):
message function: accepts source state tensor and edge tag tensor, and
returns a message tensor
'''
self.message_funcs.append((self._edges_or_all(edges), message_func, batched))
if edges == 'all':
self.m_func = message_func
else:
for e in self.edges:
self.edges[e]['m_func'] = message_func
def register_update_func(self, update_func, nodes='all', batched=False):
'''
......@@ -114,38 +132,132 @@ class DiGraph(NN.Module):
update function: accepts a node attribute dictionary (including state and tag),
and a list of tuples (source node, target node, edge attribute dictionary)
'''
self.update_funcs.append((self._nodes_or_all(nodes), update_func, batched))
if nodes == 'all':
self.u_func = update_func
else:
for n in nodes:
self.node[n]['u_func'] = update_func
# Replace the graph-wide readout function.
def register_readout_func(self, readout_func):
self.readout_func = readout_func
# Collect the reprs of `nodes` and apply the readout function.
def readout(self, nodes='all', **kwargs):
nodes_state = []
nodes = self._nodes_or_all(nodes)
for n in nodes:
nodes_state.append(self.get_repr(n))
return self.readout_func(nodes_state, **kwargs)
def sendto(self, u, v):
    """Compute and store the message for edge u -> v.

    Uses the edge-specific 'm_func' when one was registered on the
    edge, otherwise falls back to the graph-wide default.

    Args:
        u: source node
        v: destination node
    """
    edge_data = self.edges[(u, v)]
    message_fn = edge_data.get('m_func', self.m_func)
    edge_data['msg'] = message_fn(self.get_repr(u))
def sendto_ebunch(self, ebunch):
    """Compute and store messages for every edge in *ebunch*.

    Args:
        ebunch: iterable of (u, v) edge pairs
    """
    # TODO: simplify the logics
    for src, dst in ebunch:
        data = self.edges[(src, dst)]
        message_fn = data.get('m_func', self.m_func)
        data['msg'] = message_fn(self.get_repr(src))
def recvfrom(self, u, nodes):
    """Update node *u* from the pre-computed messages of *nodes*.

    Args:
        u: node to be updated
        nodes: source nodes whose edges to u already carry a 'msg'
    """
    incoming = [self.edges[(v, u)]['msg'] for v in nodes]
    # per-node 'u_func' wins over the graph-wide default
    update_fn = self.nodes[u].get('u_func', self.u_func)
    self.set_repr(u, update_fn(self.get_repr(u), incoming))
def update_by_edge(self, e):
    """Run one message/update round along the single edge e = (u, v)."""
    src, dst = e
    self.sendto(src, dst)
    self.recvfrom(dst, [src])
def update_to(self, u):
    """Pull messages from 1-step away neighbors of u"""
    assert u in self.nodes
    predecessors = list(self.pred[u])
    for p in predecessors:
        self.sendto(p, u)
    self.recvfrom(u, predecessors)
def update_from(self, u):
    """Update u's 1-step away neighbors"""
    assert u in self.nodes
    for neighbor in self.succ[u]:
        self.update_to(neighbor)
def update_all_step(self):
    """One synchronous message-passing step over the whole graph:
    compute every edge message, then update every node from its
    predecessors."""
    self.sendto_ebunch(self.edges)
    for node in self.nodes:
        self.recvfrom(node, list(self.pred[node]))
def draw(self):
# Visualize the wrapped graph: hierarchical 'dot' layout via pygraphviz
# (optional dependency), drawn with networkx/matplotlib.
from networkx.drawing.nx_agraph import graphviz_layout
pos = graphviz_layout(self.G, prog='dot')
nx.draw(self.G, pos, with_labels=True)
def step(self):
# One synchronous iteration of the older registry-based API: run every
# registered message function over its edge bunch, storing results in
# edge 'state', then every registered update function over its nodes.
# update message
for ebunch, f, batched in self.message_funcs:
if batched:
# FIXME: need to optimize since we are repeatedly stacking and
# unpacking
source = T.stack([self.G.node[u]['state'] for u, _ in ebunch])
edge_tags = [self.G[u][v].get('tag', None) for u, v in ebunch]
if all(t is None for t in edge_tags):
edge_tag = None
else:
edge_tag = T.stack([self.G[u][v]['tag'] for u, v in ebunch])
message = f(source, edge_tag)
# scatter the batched result back onto the individual edges
for i, (u, v) in enumerate(ebunch):
self.G[u][v]['state'] = message[i]
else:
for u, v in ebunch:
self.G[u][v]['state'] = f(
self.G.node[u]['state'],
self.G[u][v]['tag']
)
# update state
# TODO: does it make sense to batch update the nodes?
for vbunch, f, batched in self.update_funcs:
for v in vbunch:
self.G.node[v]['state'] = f(self.G.node[v], list(self.G.in_edges(v, data=True)))
# NOTE(review): the two lines below are a fragment of draw() from another
# revision, interleaved here by the scrape — they do not belong to step().
pos = graphviz_layout(self, prog='dot')
nx.draw(self, pos, with_labels=True)
def set_reduction_func(self):
    """Install the default reduction: elementwise sum of the stacked
    input tensors (reduces a list of equal-shaped tensors to one)."""
    def _sum_reduction(tensors):
        return th.sum(th.stack(tensors), dim=0)
    self._reduction_func = _sum_reduction
def set_gather_func(self, u=None):
    """No-op placeholder: gathering currently happens inside the
    message function, so there is nothing to install here."""
    return None
def print_all(self):
# Debug helper: print every node with its full attribute dict, then a
# blank separator line.
for n in self.nodes:
print(n, self.nodes[n])
print()
# Smoke-test script for this revision's mx_Graph: exercises sendto/recvfrom
# on a 2-node path, then GRU-based updates on a 3-node path.
if __name__ == '__main__':
from torch.autograd import Variable as Var
th.random.manual_seed(0)
print("testing vanilla RNN update")
g_path = mx_Graph(nx.path_graph(2))
# batch of 2, hidden size 128 (the DefaultUpdateModule default)
g_path.set_repr(0, th.rand(2, 128))
g_path.sendto(0, 1)
g_path.recvfrom(1, [0])
g_path.readout()
'''
# this makes a uni-edge tree
tr = nx.bfs_tree(nx.balanced_tree(2, 3), 0)
m_tr = mx_Graph(tr)
m_tr.print_all()
'''
print("testing GRU update")
g = mx_Graph(nx.path_graph(3))
update_net = DefaultUpdateModule(h_dims=4, net_type='gru')
g.register_update_func(update_net)
msg_net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
g.register_message_func(msg_net)
for n in g:
g.set_repr(n, th.rand(2, 4))
y_pre = g.readout()
g.update_from(0)
y_after = g.readout()
# two distinct GRU update nets, cycled via f_idx across calls
upd_nets = DefaultUpdateModule(h_dims=4, net_type='gru', n_func=2)
g.register_update_func(upd_nets)
g.update_from(0)
g.update_from(0)
import networkx as nx
#from networkx.classes.graph import Graph
from networkx.classes.digraph import DiGraph
import torch as th
#import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable as Var
# TODO: loss functions and training
class mx_Graph(DiGraph):
    """Message-passing graph prototype (tensor-based revision).

    Extends networkx's ``DiGraph`` with pluggable message, reduction,
    update and readout functions.  Node representations are stored in
    the node attribute ``'h'`` and default to ``None``.
    """

    def __init__(self, *args, **kargs):
        super(mx_Graph, self).__init__(*args, **kargs)
        # Install the default function set, then clear all node reprs.
        self.set_msg_func()
        self.set_gather_func()
        self.set_reduction_func()
        self.set_update_func()
        self.set_readout_func()
        self.init_reprs()

    def init_reprs(self, h_init=None):
        """Reset every node's representation to ``h_init``."""
        for n in self.nodes:
            self.set_repr(n, h_init)

    def set_repr(self, u, h_u, name='h'):
        """Store ``h_u`` under node ``u``'s attribute ``name``."""
        assert u in self.nodes
        kwarg = {name: h_u}
        self.add_node(u, **kwarg)

    def get_repr(self, u, name='h'):
        """Return node ``u``'s attribute ``name``."""
        assert u in self.nodes
        return self.nodes[u][name]

    def set_reduction_func(self):
        """Install the default reduction: sum of the stacked tensors."""
        def _default_reduction_func(x_s):
            out = th.stack(x_s)
            out = th.sum(out, dim=0)
            return out
        self._reduction_func = _default_reduction_func

    def set_gather_func(self, u=None):
        # Placeholder: gathering currently happens in the message func.
        pass

    def set_msg_func(self, func=None, u=None):
        """Function that gathers messages from neighbors"""
        def _default_msg_func(u):
            assert u in self.nodes
            # Collect non-None predecessor reprs and reduce them.
            msg_gathered = []
            for v in self.pred[u]:
                x = self.get_repr(v)
                if x is not None:
                    msg_gathered.append(x)
            return self._reduction_func(msg_gathered)
        # TODO: per node message function
        # TODO: 'sum' should be a separate function
        if func is None:  # fixed: identity comparison with None (was ==)
            self._msg_func = _default_msg_func
        else:
            self._msg_func = func

    def set_update_func(self, func=None, u=None):
        """
        Update function upon receiving an aggregate
        message from a node's neighbor
        """
        def _default_update_func(x, m):
            return x + m
        # TODO: per node update function
        if func is None:  # fixed: identity comparison with None (was ==)
            self._update_func = _default_update_func
        else:
            self._update_func = func

    def set_readout_func(self, func=None):
        """Readout function of the whole graph"""
        def _default_readout_func():
            # Reduce over all initialized (non-None) node reprs.
            valid_hs = []
            for x in self.nodes:
                h = self.get_repr(x)
                if h is not None:
                    valid_hs.append(h)
            return self._reduction_func(valid_hs)
        if func is None:  # fixed: identity comparison with None (was ==)
            self.readout_func = _default_readout_func
        else:
            self.readout_func = func

    def readout(self):
        """Apply the registered readout function to the whole graph."""
        return self.readout_func()

    def update_to(self, u):
        """Pull messages from 1-step away neighbors of u"""
        assert u in self.nodes
        m = self._msg_func(u=u)
        x = self.get_repr(u)
        # TODO: ugly hack — uninitialized nodes invoke the update func
        # with the message alone; note the default update func does not
        # support this one-argument form.
        if x is None:
            y = self._update_func(m)
        else:
            y = self._update_func(x, m)
        self.set_repr(u, y)

    def update_from(self, u):
        """Update u's 1-step away neighbors"""
        assert u in self.nodes
        # TODO: this asks v to pull from nodes other than
        # TODO: u, is this a good thing?
        for v in self.succ[u]:
            self.update_to(v)

    def print_all(self):
        """Debug helper: print every node and its attribute dict."""
        for n in self.nodes:
            print(n, self.nodes[n])
        print()
# Smoke-test script for the tensor-based mx_Graph: GRU updates on a 3-node
# path, then a feed-forward update net on a rebuilt path.
if __name__ == '__main__':
th.random.manual_seed(0)
''': this makes a digraph with double edges
tg = nx.path_graph(10)
g = mx_Graph(tg)
g.print_all()
# this makes a uni-edge tree
tr = nx.bfs_tree(nx.balanced_tree(2, 3), 0)
m_tr = mx_Graph(tr)
m_tr.print_all()
'''
print("testing GRU update")
g = mx_Graph(nx.path_graph(3))
g.set_update_func(nn.GRUCell(4, 4))
for n in g:
g.set_repr(n, Var(th.rand(2, 4)))
print("\t**before:"); g.print_all()
g.update_from(0)
g.update_from(1)
print("\t**after:"); g.print_all()
print("\ntesting fwd update")
# rebuild the same path and start from empty (None) reprs
g.clear()
g.add_path([0, 1, 2])
g.init_reprs()
fwd_net = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
g.set_update_func(fwd_net)
g.set_repr(0, Var(th.rand(2, 4)))
print("\t**before:"); g.print_all()
g.update_from(0)
g.update_from(1)
print("\t**after:"); g.print_all()
import networkx as nx
from networkx.classes.graph import Graph
from networkx.classes.digraph import DiGraph
# TODO: make representation numpy/tensor from pytorch
# TODO: make message/update functions pytorch functions
# TODO: loss functions and training
class mx_Graph(DiGraph):
    """Message-passing graph prototype (earliest, scalar-friendly revision).

    Extends networkx's ``DiGraph`` with pluggable message, update and
    readout functions.  Node representations live in the node attribute
    ``'h'`` (default ``None``); messages are summed with ``+``, so any
    type supporting addition works (ints, tensors, ...).
    """

    def __init__(self, *args, **kargs):
        super(mx_Graph, self).__init__(*args, **kargs)
        # Install the default function set, then clear all node reprs.
        self.set_msg_func()
        self.set_update_func()
        self.set_readout_func()
        self.init_reprs()

    def init_reprs(self, h_init=None):
        """Reset every node's representation to ``h_init``."""
        for n in self.nodes:
            self.set_repr(n, h_init)

    def set_repr(self, u, h_u, name='h'):
        """Store ``h_u`` under node ``u``'s attribute ``name``."""
        assert u in self.nodes
        kwarg = {name: h_u}
        self.add_node(u, **kwarg)

    def get_repr(self, u, name='h'):
        """Return node ``u``'s attribute ``name``."""
        assert u in self.nodes
        return self.nodes[u][name]

    def set_msg_func(self, func=None, u=None):
        """Function that gathers messages from neighbors"""
        def _default_msg_func(u):
            assert u in self.nodes
            # Sum the non-None reprs of u's neighbors.
            msg_gathered = 0
            for v in self.adj[u]:
                x = self.get_repr(v)
                if x is not None:
                    msg_gathered += x
            return msg_gathered
        # TODO: per node message function
        # TODO: 'sum' should be a separate function
        if func is None:  # fixed: identity comparison with None (was ==)
            self.msg_func = _default_msg_func
        else:
            self.msg_func = func

    def set_update_func(self, func=None, u=None):
        """
        Update function upon receiving an aggregate
        message from a node's neighbor
        """
        def _default_update_func(u, m):
            # Add the aggregate message onto the current repr in place.
            if self.nodes[u]['h'] is None:
                h_new = m
            else:
                h_new = self.nodes[u]['h'] + m
            self.set_repr(u, h_new)
        # TODO: per node update function
        if func is None:  # fixed: identity comparison with None (was ==)
            self.update_func = _default_update_func
        else:
            self.update_func = func

    def set_readout_func(self, func=None):
        """Readout function of the whole graph"""
        def _default_readout_func():
            readout = 0
            for n in self.nodes:
                readout += self.nodes[n]['h']
            return readout
        if func is None:  # fixed: identity comparison with None (was ==)
            self.readout_func = _default_readout_func
        else:
            self.readout_func = func

    def readout(self):
        """Apply the registered readout function to the whole graph."""
        return self.readout_func()

    def update_to(self, u):
        """Pull messages from 1-step away neighbors of u"""
        assert u in self.nodes
        m = self.msg_func(u=u)
        self.update_func(u, m)

    def update_from(self, u):
        """Update u's 1-step away neighbors"""
        assert u in self.nodes
        # TODO: this asks v to pull from nodes other than
        # TODO: u, is this a good thing?
        for v in self.adj[u]:
            self.update_to(v)

    def print_all(self):
        """Debug helper: print every node and its attribute dict."""
        for n in self.nodes:
            print(n, self.nodes[n])
        print()
# Smoke-test script for the scalar revision: integer node reprs on small
# paths/trees, exercising readout, update_to and update_from.
if __name__ == '__main__':
tg = nx.path_graph(10)
g = mx_Graph(tg)
g.print_all()
tr = nx.balanced_tree(2, 3)
m_tr = mx_Graph(tr)
m_tr.print_all()
g = mx_Graph(nx.path_graph(3))
for n in g:
g.set_repr(n, int(n) + 10)
g.print_all()
print(g.readout())
print("before update:\t", g.nodes[0])
g.update_to(0)
print('after update:\t', g.nodes[0])
g.print_all()
print(g.readout())
# directed (uni-edge) version: propagate from the root only
g = mx_Graph(nx.bfs_tree(nx.path_graph(3), 0))
g.set_repr(0, 10)
g.print_all()
g.update_from(0)
g.print_all()
# Standalone usage snippet (likely a separate example file in the original
# repo): build a balanced tree and wrap it in the mx module's graph type.
import networkx as nx
import mx
tr = nx.balanced_tree(3, 4)
tr = mx.mx_Graph(tr)
# now we have made a skeleton tree, we can register
# representations to a subset of the nodes, and then
# update them
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment