"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "84e5cc596c67afd2863131f5fa309ebaa994624d"
Commit 2b092811 authored by GaiYu0's avatar GaiYu0
Browse files

line graph

parent d772d390
...@@ -9,231 +9,106 @@ Deviations from paper: ...@@ -9,231 +9,106 @@ Deviations from paper:
# TODO self-loop? # TODO self-loop?
# TODO in-place edit of node_reprs/edge_reprs in message_func/update_func?
# TODO batch-norm
import copy import copy
import itertools import itertools
import dgl.graph as G import dgl
import dgl.function as fn
import networkx as nx import networkx as nx
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
class GLGModule(nn.Module): class GNNModule(nn.Module):
__SHADOW__ = 'shadow'
def __init__(self, in_feats, out_feats, radius): def __init__(self, in_feats, out_feats, radius):
super().__init__() super().__init__()
self.out_feats = out_feats
self.radius = radius self.radius = radius
new_linear = lambda: nn.Linear(in_feats, out_feats) new_linear = lambda: nn.Linear(in_feats, out_feats * 2)
new_module_list = lambda: nn.ModuleList([new_linear() for i in range(radius)]) new_module_list = lambda: nn.ModuleList([new_linear() for i in range(radius)])
self.theta_x, self.theta_y, self.theta_deg, self.theta_global = \ self.theta_x, self.theta_deg, self.theta_y = \
new_linear(), new_linear(), new_linear(), new_linear() new_linear(), new_linear(), new_linear()
self.theta_list = new_module_list() self.theta_list = new_module_list()
self.gamma_x, self.gamma_y, self.gamma_deg, self.gamma_global = \ self.gamma_y, self.gamma_deg, self.gamma_x = \
new_linear(), new_linear(), new_linear(), new_linear() new_linear(), new_linear(), new_linear()
self.gamma_list = new_module_list() self.gamma_list = new_module_list()
@staticmethod self.bn_x = nn.BatchNorm1d(out_feats)
def copy(which): self.bn_y = nn.BatchNorm1d(out_feats)
if which == 'src':
return lambda src, trg, _: src.copy()
elif which == 'trg':
return lambda src, trg, _: trg.copy()
@staticmethod
def aggregate(msg_fld, trg_fld, normalize=False):
def a(node_reprs, edge_reprs):
node_reprs = node_reprs.copy()
node_reprs[trg_fld] = sum(msg[msg_fld] for msg in edge_reprs)
if normalize:
node_reprs[trg_fld] /= len(edge_reprs)
return node_reprs
return a
@staticmethod
def pull(msg_fld, trg_fld):
def p(node_reprs, edge_reprs):
node_reprs = node_reprs.copy()
node_reprs[trg_fld] = edge_reprs[0][msg_fld]
return node_reprs
return p
def local_aggregate(self, g):
def step():
g.register_message_func(self.copy('src'), g.edges)
g.register_update_func(self.aggregate('x', 'x'), g.nodes)
g.update_all()
step()
for reprs in g.nodes.values():
reprs[0] = reprs['x']
for i in range(1, self.radius):
for j in range(2 ** (i - 1)):
step()
for reprs in g.nodes.values():
reprs[i] = reprs['x']
@staticmethod
def global_aggregate(g):
shadow = GLGModule.__SHADOW__
copy, aggregate, pull = GLGModule.copy, GLGModule.aggregate, GLGModule.pull
node_list = list(g.nodes)
uv_list = [(node, shadow) for node in g.nodes]
vu_list = [(shadow, node) for node in g.nodes]
g.add_node(shadow) # TODO context manager
tuple(itertools.starmap(g.add_edge, uv_list))
g.register_message_func(copy('src'), uv_list)
g.register_update_func(aggregate('x', 'global', normalize=True), (shadow,))
g.update_to(shadow)
tuple(itertools.starmap(g.add_edge, vu_list))
g.register_message_func(copy('src'), vu_list)
g.register_update_func(pull('global', 'global'), node_list)
g.update_from(shadow)
g.remove_node(shadow) def aggregate(self, g, z):
z_list = []
g.set_n_repr(z)
g.update_all(fn.copy_src(), fn.sum(), batchable=True)
z_list.append(g.get_n_repr())
for i in range(self.radius - 1):
for j in range(2 ** i):
g.update_all(fn.copy_src(), fn.sum(), batchable=True)
z_list.append(g.get_n_repr())
return z_list
@staticmethod def forward(self, g, lg, x, y, deg_g, deg_lg, eid2nid):
def multiply_by_degree(g): xy = F.embedding(eid2nid, x)
g.register_message_func(lambda *args: None, g.edges)
def update_func(node_reprs, _):
node_reprs = node_reprs.copy()
node_reprs['deg'] = node_reprs['x'] * node_reprs['degree']
return node_reprs
g.register_update_func(update_func, g.nodes)
g.update_all()
@staticmethod x_list = [theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x))]
def message_func(src, trg, _): g.set_e_repr(y)
return {'y' : src['x']} g.update_all(fn.copy_edge(), fn.sum(), batchable=True)
yx = g.get_n_repr()
def update_func(self, which): x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum(x_list) + self.theta_y(yx)
if which == 'node': x = self.bn_x(x[:, :self.out_feats] + F.relu(x[:, self.out_feats:]))
linear_x, linear_y, linear_deg, linear_global = \
self.theta_x, self.theta_y, self.theta_deg, self.theta_global
linear_list = self.theta_list
elif which == 'edge':
linear_x, linear_y, linear_deg, linear_global = \
self.gamma_x, self.gamma_y, self.gamma_deg, self.gamma_global
linear_list = self.gamma_list
def u(node_reprs, edge_reprs):
edge_reprs = filter(lambda x: x is not None, edge_reprs)
y = sum(x['y'] for x in edge_reprs)
node_reprs = node_reprs.copy()
node_reprs['x'] = linear_x(node_reprs['x']) \
+ linear_y(y) \
+ linear_deg(node_reprs['deg']) \
+ linear_global(node_reprs['global']) \
+ sum(linear(node_reprs[i]) \
for i, linear in enumerate(linear_list))
return node_reprs
return u
def forward(self, g, lg, glg):
self.local_aggregate(g)
self.local_aggregate(lg)
self.global_aggregate(g)
self.global_aggregate(lg)
self.multiply_by_degree(g)
self.multiply_by_degree(lg)
# TODO efficiency
for node, reprs in g.nodes.items():
glg.nodes[node].update(reprs)
for node, reprs in lg.nodes.items():
glg.nodes[node].update(reprs)
glg.register_message_func(self.message_func, glg.edges)
glg.register_update_func(self.update_func('node'), g.nodes)
glg.register_update_func(self.update_func('edge'), lg.nodes)
glg.update_all()
# TODO efficiency
for node, reprs in g.nodes.items():
reprs.update(glg.nodes[node])
for node, reprs in lg.nodes.items():
reprs.update(glg.nodes[node])
y_list = [gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y))]
lg.set_e_repr(xy)
lg.update_all(fn.copy_edge(), fn.sum(), batchable=True)
xy = lg.get_n_repr()
y = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum(y_list) + self.gamma_x(xy)
y = self.bn_y(y[:, :self.out_feats] + F.relu(y[:, self.out_feats:]))
class GNNModule(nn.Module): return x, y
def __init__(self, in_feats, out_feats, order, radius):
super().__init__()
self.module_list = nn.ModuleList([GLGModule(in_feats, out_feats, radius)
for i in range(order)])
def forward(self, pairs, fusions):
for module, (g, lg), glg in zip(self.module_list, pairs, fusions):
module(g, lg, glg)
for lhs, rhs in zip(pairs[:-1], pairs[1:]):
for node, reprs in lhs[1].nodes.items():
x_rhs = reprs['x']
reprs['x'] = x_rhs + rhs[0].nodes[node]['x']
rhs[0].nodes[node]['x'] += x_rhs
class GNN(nn.Module): class GNN(nn.Module):
def __init__(self, feats, order, radius, n_classes): def __init__(self, g, feats, radius, n_classes):
super().__init__()
self.order = order
self.linear = nn.Linear(feats[-1], n_classes)
self.module_list = nn.ModuleList([GNNModule(in_feats, out_feats, order, radius)
for in_feats, out_feats in zip(feats[:-1], feats[1:])])
@staticmethod
def line_graph(g):
lg = nx.line_graph(g)
glg = nx.DiGraph()
glg.add_nodes_from(g.nodes)
glg.add_nodes_from(lg.nodes)
for u, v in g.edges:
glg.add_edge(u, (u, v))
glg.add_edge((u, v), u)
glg.add_edge(v, (u, v))
glg.add_edge((u, v), v)
return lg, glg
@staticmethod
def nx2dgl(g):
deg_dict = dict(nx.degree(g))
z = sum(deg_dict.values())
dgl_g = G.DGLGraph(g)
for node, reprs in dgl_g.nodes.items():
reprs['degree'] = deg_dict[node]
reprs['x'] = th.full((1, 1), reprs['degree'] / z)
reprs.update(g.nodes[node])
return dgl_g
def forward(self, g):
""" """
Parameters Parameters
---------- ----------
g : networkx.DiGraph g : networkx.DiGraph
""" """
pair_list, glg_list = [], [] super(GNN, self).__init__()
dgl_g = self.nx2dgl(g)
origin = dgl_g
for i in range(self.order):
lg, glg = self.line_graph(g)
dgl_lg = self.nx2dgl(lg)
pair_list.append((dgl_g, copy.deepcopy(dgl_lg)))
glg_list.append(G.DGLGraph(glg))
g = lg
dgl_g = dgl_lg
for module in self.module_list: lg = nx.line_graph(g)
module(pair_list, glg_list) x = list(zip(*g.degree))[1]
self.x = self.normalize(th.tensor(x, dtype=th.float).unsqueeze(1))
y = list(zip(*lg.degree))[1]
self.y = self.normalize(th.tensor(y, dtype=th.float).unsqueeze(1))
self.eid2nid = th.tensor([n for [[_, n], _] in lg.edges])
return self.linear(th.cat([reprs['x'] for reprs in origin.nodes.values()], 0)) self.g = dgl.DGLGraph(g)
self.lg = dgl.DGLGraph(nx.convert_node_labels_to_integers(lg))
self.linear = nn.Linear(feats[-1], n_classes)
self.module_list = nn.ModuleList([GNNModule(m, n, radius)
for m, n in zip(feats[:-1], feats[1:])])
@staticmethod
def normalize(x):
x = x - th.mean(x, 0)
x = x / th.sqrt(th.mean(x * x, 0))
return x
def cuda(self):
self.x = self.x.cuda()
self.y = self.y.cuda()
self.eid2nid = self.eid2nid.cuda()
super(GNN, self).cuda()
def forward(self):
x, y = self.x, self.y
for module in self.module_list:
x, y = module(self.g, self.lg, x, y, self.x, self.y, self.eid2nid)
return self.linear(x)
""" """
ipython3 test.py -- --features 1 16 16 --gpu -1 --n-classes 5 --n-iterations 10 --n-nodes 10 --order 3 --radius 3 ipython3 test.py -- --features 1 16 16 --gpu -1 --n-classes 5 --n-iterations 10 --n-nodes 10 --radius 3
""" """
import argparse import argparse
import networkx as nx import networkx as nx
import torch as th import torch as th
import torch.nn as nn import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import gnn import gnn
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--features', nargs='+', type=int) parser.add_argument('--features', nargs='+', type=int)
parser.add_argument('--gpu', type=int) parser.add_argument('--gpu', type=int)
parser.add_argument('--n-classes', type=int) parser.add_argument('--n-classes', type=int)
parser.add_argument('--n-iterations', type=int) parser.add_argument('--n-iterations', type=int)
parser.add_argument('--n-nodes', type=int) parser.add_argument('--n-nodes', type=int)
parser.add_argument('--order', type=int)
parser.add_argument('--radius', type=int) parser.add_argument('--radius', type=int)
args = parser.parse_args() args = parser.parse_args()
if args.gpu < 0: if args.gpu < 0:
cuda = False cuda = False
else: else:
cuda = True cuda = True
th.cuda.set_device(args.gpu) th.cuda.set_device(args.gpu)
g = nx.barabasi_albert_graph(args.n_nodes, 1).to_directed() # TODO SBM g = nx.barabasi_albert_graph(args.n_nodes, 1).to_directed() # TODO SBM
y = th.multinomial(th.ones(args.n_classes), args.n_nodes, replacement=True) y = th.multinomial(th.ones(args.n_classes), args.n_nodes, replacement=True)
model = gnn.GNN(g, args.features, args.radius, args.n_classes)
network = gnn.GNN(args.features, args.order, args.radius, args.n_classes)
if cuda: if cuda:
network.cuda() model.cuda()
ce = nn.CrossEntropyLoss() opt = optim.Adam(model.parameters())
adam = optim.Adam(network.parameters())
for i in range(args.n_iterations): for i in range(args.n_iterations):
y_bar = network(g) y_bar = model()
loss = ce(y_bar, y) loss = F.cross_entropy(y_bar, y)
adam.zero_grad() opt.zero_grad()
loss.backward() loss.backward()
adam.step() opt.step()
print('[iteration %d]loss %f' % (i, loss)) print('[iteration %d]loss %f' % (i, loss))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment