Commit 2b092811 authored by GaiYu0

line graph

parent d772d390
@@ -9,231 +9,106 @@ Deviations from paper:
 # TODO self-loop?
 # TODO in-place edit of node_reprs/edge_reprs in message_func/update_func?
 # TODO batch-norm
-import copy
-import itertools
-import dgl.graph as G
+import dgl
+import dgl.function as fn
 import networkx as nx
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
-class GLGModule(nn.Module):
-    __SHADOW__ = 'shadow'
+class GNNModule(nn.Module):
     def __init__(self, in_feats, out_feats, radius):
         super().__init__()
+        self.out_feats = out_feats
         self.radius = radius
-        new_linear = lambda: nn.Linear(in_feats, out_feats)
+        new_linear = lambda: nn.Linear(in_feats, out_feats * 2)
         new_module_list = lambda: nn.ModuleList([new_linear() for i in range(radius)])
-        self.theta_x, self.theta_y, self.theta_deg, self.theta_global = \
-            new_linear(), new_linear(), new_linear(), new_linear()
+        self.theta_x, self.theta_deg, self.theta_y = \
+            new_linear(), new_linear(), new_linear()
         self.theta_list = new_module_list()
-        self.gamma_x, self.gamma_y, self.gamma_deg, self.gamma_global = \
-            new_linear(), new_linear(), new_linear(), new_linear()
+        self.gamma_y, self.gamma_deg, self.gamma_x = \
+            new_linear(), new_linear(), new_linear()
         self.gamma_list = new_module_list()
-    @staticmethod
-    def copy(which):
-        if which == 'src':
-            return lambda src, trg, _: src.copy()
-        elif which == 'trg':
-            return lambda src, trg, _: trg.copy()
-    @staticmethod
-    def aggregate(msg_fld, trg_fld, normalize=False):
-        def a(node_reprs, edge_reprs):
-            node_reprs = node_reprs.copy()
-            node_reprs[trg_fld] = sum(msg[msg_fld] for msg in edge_reprs)
-            if normalize:
-                node_reprs[trg_fld] /= len(edge_reprs)
-            return node_reprs
-        return a
-    @staticmethod
-    def pull(msg_fld, trg_fld):
-        def p(node_reprs, edge_reprs):
-            node_reprs = node_reprs.copy()
-            node_reprs[trg_fld] = edge_reprs[0][msg_fld]
-            return node_reprs
-        return p
-    def local_aggregate(self, g):
-        def step():
-            g.register_message_func(self.copy('src'), g.edges)
-            g.register_update_func(self.aggregate('x', 'x'), g.nodes)
-            g.update_all()
-        step()
-        for reprs in g.nodes.values():
-            reprs[0] = reprs['x']
-        for i in range(1, self.radius):
-            for j in range(2 ** (i - 1)):
-                step()
-            for reprs in g.nodes.values():
-                reprs[i] = reprs['x']
-    @staticmethod
-    def global_aggregate(g):
-        shadow = GLGModule.__SHADOW__
-        copy, aggregate, pull = GLGModule.copy, GLGModule.aggregate, GLGModule.pull
-        node_list = list(g.nodes)
-        uv_list = [(node, shadow) for node in g.nodes]
-        vu_list = [(shadow, node) for node in g.nodes]
-        g.add_node(shadow)  # TODO context manager
-        tuple(itertools.starmap(g.add_edge, uv_list))
-        g.register_message_func(copy('src'), uv_list)
-        g.register_update_func(aggregate('x', 'global', normalize=True), (shadow,))
-        g.update_to(shadow)
-        tuple(itertools.starmap(g.add_edge, vu_list))
-        g.register_message_func(copy('src'), vu_list)
-        g.register_update_func(pull('global', 'global'), node_list)
-        g.update_from(shadow)
+        self.bn_x = nn.BatchNorm1d(out_feats)
+        self.bn_y = nn.BatchNorm1d(out_feats)
-        g.remove_node(shadow)
+    def aggregate(self, g, z):
+        z_list = []
+        g.set_n_repr(z)
+        g.update_all(fn.copy_src(), fn.sum(), batchable=True)
+        z_list.append(g.get_n_repr())
+        for i in range(self.radius - 1):
+            for j in range(2 ** i):
+                g.update_all(fn.copy_src(), fn.sum(), batchable=True)
+            z_list.append(g.get_n_repr())
+        return z_list
-    @staticmethod
-    def multiply_by_degree(g):
-        g.register_message_func(lambda *args: None, g.edges)
-        def update_func(node_reprs, _):
-            node_reprs = node_reprs.copy()
-            node_reprs['deg'] = node_reprs['x'] * node_reprs['degree']
-            return node_reprs
-        g.register_update_func(update_func, g.nodes)
-        g.update_all()
+    def forward(self, g, lg, x, y, deg_g, deg_lg, eid2nid):
+        xy = F.embedding(eid2nid, x)
-    @staticmethod
-    def message_func(src, trg, _):
-        return {'y' : src['x']}
-    def update_func(self, which):
-        if which == 'node':
-            linear_x, linear_y, linear_deg, linear_global = \
-                self.theta_x, self.theta_y, self.theta_deg, self.theta_global
-            linear_list = self.theta_list
-        elif which == 'edge':
-            linear_x, linear_y, linear_deg, linear_global = \
-                self.gamma_x, self.gamma_y, self.gamma_deg, self.gamma_global
-            linear_list = self.gamma_list
-        def u(node_reprs, edge_reprs):
-            edge_reprs = filter(lambda x: x is not None, edge_reprs)
-            y = sum(x['y'] for x in edge_reprs)
-            node_reprs = node_reprs.copy()
-            node_reprs['x'] = linear_x(node_reprs['x']) \
-                + linear_y(y) \
-                + linear_deg(node_reprs['deg']) \
-                + linear_global(node_reprs['global']) \
-                + sum(linear(node_reprs[i]) \
-                      for i, linear in enumerate(linear_list))
-            return node_reprs
-        return u
-    def forward(self, g, lg, glg):
-        self.local_aggregate(g)
-        self.local_aggregate(lg)
-        self.global_aggregate(g)
-        self.global_aggregate(lg)
-        self.multiply_by_degree(g)
-        self.multiply_by_degree(lg)
-        # TODO efficiency
-        for node, reprs in g.nodes.items():
-            glg.nodes[node].update(reprs)
-        for node, reprs in lg.nodes.items():
-            glg.nodes[node].update(reprs)
-        glg.register_message_func(self.message_func, glg.edges)
-        glg.register_update_func(self.update_func('node'), g.nodes)
-        glg.register_update_func(self.update_func('edge'), lg.nodes)
-        glg.update_all()
-        # TODO efficiency
-        for node, reprs in g.nodes.items():
-            reprs.update(glg.nodes[node])
-        for node, reprs in lg.nodes.items():
-            reprs.update(glg.nodes[node])
+        x_list = [theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x))]
+        g.set_e_repr(y)
+        g.update_all(fn.copy_edge(), fn.sum(), batchable=True)
+        yx = g.get_n_repr()
+        x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum(x_list) + self.theta_y(yx)
+        x = self.bn_x(x[:, :self.out_feats] + F.relu(x[:, self.out_feats:]))
+        y_list = [gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y))]
+        lg.set_e_repr(xy)
+        lg.update_all(fn.copy_edge(), fn.sum(), batchable=True)
+        xy = lg.get_n_repr()
+        y = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum(y_list) + self.gamma_x(xy)
+        y = self.bn_y(y[:, :self.out_feats] + F.relu(y[:, self.out_feats:]))
-class GNNModule(nn.Module):
-    def __init__(self, in_feats, out_feats, order, radius):
-        super().__init__()
-        self.module_list = nn.ModuleList([GLGModule(in_feats, out_feats, radius)
-                                          for i in range(order)])
-    def forward(self, pairs, fusions):
-        for module, (g, lg), glg in zip(self.module_list, pairs, fusions):
-            module(g, lg, glg)
-        for lhs, rhs in zip(pairs[:-1], pairs[1:]):
-            for node, reprs in lhs[1].nodes.items():
-                x_rhs = reprs['x']
-                reprs['x'] = x_rhs + rhs[0].nodes[node]['x']
-                rhs[0].nodes[node]['x'] += x_rhs
+        return x, y
 class GNN(nn.Module):
-    def __init__(self, feats, order, radius, n_classes):
-        super().__init__()
-        self.order = order
-        self.linear = nn.Linear(feats[-1], n_classes)
-        self.module_list = nn.ModuleList([GNNModule(in_feats, out_feats, order, radius)
-                                          for in_feats, out_feats in zip(feats[:-1], feats[1:])])
-    @staticmethod
-    def line_graph(g):
-        lg = nx.line_graph(g)
-        glg = nx.DiGraph()
-        glg.add_nodes_from(g.nodes)
-        glg.add_nodes_from(lg.nodes)
-        for u, v in g.edges:
-            glg.add_edge(u, (u, v))
-            glg.add_edge((u, v), u)
-            glg.add_edge(v, (u, v))
-            glg.add_edge((u, v), v)
-        return lg, glg
-    @staticmethod
-    def nx2dgl(g):
-        deg_dict = dict(nx.degree(g))
-        z = sum(deg_dict.values())
-        dgl_g = G.DGLGraph(g)
-        for node, reprs in dgl_g.nodes.items():
-            reprs['degree'] = deg_dict[node]
-            reprs['x'] = th.full((1, 1), reprs['degree'] / z)
-            reprs.update(g.nodes[node])
-        return dgl_g
-    def forward(self, g):
+    def __init__(self, g, feats, radius, n_classes):
+        """
+        Parameters
+        ----------
+        g : networkx.DiGraph
+        """
-        pair_list, glg_list = [], []
-        dgl_g = self.nx2dgl(g)
-        origin = dgl_g
-        for i in range(self.order):
-            lg, glg = self.line_graph(g)
-            dgl_lg = self.nx2dgl(lg)
-            pair_list.append((dgl_g, copy.deepcopy(dgl_lg)))
-            glg_list.append(G.DGLGraph(glg))
-            g = lg
-            dgl_g = dgl_lg
+        super(GNN, self).__init__()
-        for module in self.module_list:
-            module(pair_list, glg_list)
+        lg = nx.line_graph(g)
+        x = list(zip(*g.degree))[1]
+        self.x = self.normalize(th.tensor(x, dtype=th.float).unsqueeze(1))
+        y = list(zip(*lg.degree))[1]
+        self.y = self.normalize(th.tensor(y, dtype=th.float).unsqueeze(1))
+        self.eid2nid = th.tensor([n for [[_, n], _] in lg.edges])
-        return self.linear(th.cat([reprs['x'] for reprs in origin.nodes.values()], 0))
+        self.g = dgl.DGLGraph(g)
+        self.lg = dgl.DGLGraph(nx.convert_node_labels_to_integers(lg))
+        self.linear = nn.Linear(feats[-1], n_classes)
+        self.module_list = nn.ModuleList([GNNModule(m, n, radius)
+                                          for m, n in zip(feats[:-1], feats[1:])])
+    @staticmethod
+    def normalize(x):
+        x = x - th.mean(x, 0)
+        x = x / th.sqrt(th.mean(x * x, 0))
+        return x
+    def cuda(self):
+        self.x = self.x.cuda()
+        self.y = self.y.cuda()
+        self.eid2nid = self.eid2nid.cuda()
+        super(GNN, self).cuda()
+    def forward(self):
+        x, y = self.x, self.y
+        for module in self.module_list:
+            x, y = module(self.g, self.lg, x, y, self.x, self.y, self.eid2nid)
+        return self.linear(x)
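Note on the line-graph bookkeeping above: `nx.line_graph` labels each node of `lg` with an edge tuple `(u, v)` of `g`, so every edge of `lg` has the form `((u, v), (v, w))`, and destructuring its source as `[_, n]` recovers the shared vertex `v`. That is exactly what the `eid2nid` comprehension collects, and what `F.embedding(eid2nid, x)` later uses to fetch that vertex's feature for each line-graph edge. A minimal sketch (the toy path graph is illustrative, not part of the commit):

import networkx as nx

# Directed path 0 -> 1 -> 2: its line graph has nodes (0, 1) and (1, 2)
# and a single edge ((0, 1), (1, 2)).
g = nx.DiGraph([(0, 1), (1, 2)])
lg = nx.line_graph(g)

print(sorted(lg.nodes))  # [(0, 1), (1, 2)]
print(list(lg.edges))    # [((0, 1), (1, 2))]

# Same destructuring as in gnn.py: the source edge (_, n) of each
# line-graph edge yields the vertex n shared by the two edges of g.
eid2nid = [n for [[_, n], _] in lg.edges]
print(eid2nid)           # [1]

The hunk that follows is the updated training script, test.py.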
"""
ipython3 test.py -- --features 1 16 16 --gpu -1 --n-classes 5 --n-iterations 10 --n-nodes 10 --order 3 --radius 3
ipython3 test.py -- --features 1 16 16 --gpu -1 --n-classes 5 --n-iterations 10 --n-nodes 10 --radius 3
"""
import argparse
import networkx as nx
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gnn
parser = argparse.ArgumentParser()
parser.add_argument('--features', nargs='+', type=int)
parser.add_argument('--gpu', type=int)
parser.add_argument('--n-classes', type=int)
parser.add_argument('--n-iterations', type=int)
parser.add_argument('--n-nodes', type=int)
parser.add_argument('--order', type=int)
parser.add_argument('--radius', type=int)
args = parser.parse_args()
if args.gpu < 0:
cuda = False
else:
cuda = True
th.cuda.set_device(args.gpu)
g = nx.barabasi_albert_graph(args.n_nodes, 1).to_directed() # TODO SBM
y = th.multinomial(th.ones(args.n_classes), args.n_nodes, replacement=True)
network = gnn.GNN(args.features, args.order, args.radius, args.n_classes)
model = gnn.GNN(g, args.features, args.radius, args.n_classes)
if cuda:
network.cuda()
ce = nn.CrossEntropyLoss()
adam = optim.Adam(network.parameters())
model.cuda()
opt = optim.Adam(model.parameters())
for i in range(args.n_iterations):
y_bar = network(g)
loss = ce(y_bar, y)
adam.zero_grad()
y_bar = model()
loss = F.cross_entropy(y_bar, y)
opt.zero_grad()
loss.backward()
adam.step()
opt.step()
print('[iteration %d]loss %f' % (i, loss))
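The `aggregate` method added in this commit snapshots neighborhood sums at hop distances 1, 2, 4, ..., 2**(radius - 1): one `update_all` gives the 1-hop sum, and the number of propagation steps doubles before each later snapshot. Below is a dense-matrix sketch of the same schedule, assuming an undirected graph with a symmetric adjacency matrix (so summing over in-neighbors equals `adj @ z`); `aggregate_dense`, `adj`, and `z` are illustrative names, not part of the commit:

import torch as th

def aggregate_dense(adj, z, radius):
    # adj: [N, N] symmetric adjacency, z: [N, D] node features.
    # Mirrors GNNModule.aggregate: snapshot the propagated features
    # after 1, 2, 4, ..., 2 ** (radius - 1) hops.
    z_list = []
    z = adj @ z                   # 1 hop
    z_list.append(z)
    for i in range(radius - 1):
        for _ in range(2 ** i):   # double the total hop count
            z = adj @ z
        z_list.append(z)
    return z_list

adj = th.tensor([[0., 1., 0.],
                 [1., 0., 1.],
                 [0., 1., 0.]])
z = th.eye(3)
print(len(aggregate_dense(adj, z, radius=3)))  # 3 snapshots: hops 1, 2, 4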
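One more detail worth spelling out: the commit doubles every linear layer's width (`out_feats * 2`) and recombines the two halves as identity plus ReLU before batch normalization, i.e. `bn(h[:, :d] + relu(h[:, d:]))`, applied to both `x` and `y` in `GNNModule.forward`. A self-contained sketch of that combination (layer sizes here are arbitrary):

import torch as th
import torch.nn as nn
import torch.nn.functional as F

d = 4
linear = nn.Linear(8, 2 * d)  # twice the output width, as in new_linear
bn = nn.BatchNorm1d(d)

h = linear(th.randn(16, 8))
# Identity half plus ReLU half, then batch norm.
out = bn(h[:, :d] + F.relu(h[:, d:]))
print(out.shape)              # torch.Size([16, 4])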