Commit b9073209 authored by GaiYu0, committed by Minjie Wang

CDGNN example & minor fixes in dgl/graph.py (#11)

* cdgnn example & minor fixes in dgl/graph.py

* incomplete TreeLSTM

* add data preprocess of SST

* fix nx_SST.py

* TreeLSTM completed
parent c87564d5
"""
Supervised Community Detection with Hierarchical Graph Neural Networks
https://arxiv.org/abs/1705.08415
Deviations from paper:
- Addition of global aggregation operator.
- Message passing is equivalent to `A^j \cdot X`, instead of `\min(1, A^j) \cdot X`.
"""
import torch.utils as utils
# TODO self-loop?
# TODO in-place edit of node_reprs/edge_reprs in message_func/update_func?
# TODO batch-norm
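# Toy illustration of the deviation noted in the docstring above (added sketch,
# not used by the model below): A^j X weights every length-j walk, whereas the
# paper's min(1, A^j) X only records j-hop reachability.
def _multihop_demo():
    import numpy as np
    A = np.array([[0, 1, 1],
                  [1, 0, 1],
                  [1, 1, 0]])              # adjacency of a triangle
    X = np.eye(3)
    A2 = A @ A                             # diagonal is 2: two length-2 walks per node
    walk_weighted = A2 @ X                 # what this file computes for radius j = 2
    reachability = np.minimum(1, A2) @ X   # what the paper specifies
    return walk_weighted, reachability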
import copy
import itertools
import dgl.graph as G
import networkx as nx
import torch as th
import torch.nn as nn
class GLGModule(nn.Module):
__SHADOW__ = 'shadow'
def __init__(self, in_feats, out_feats, radius):
super().__init__()
self.radius = radius
new_linear = lambda: nn.Linear(in_feats, out_feats)
new_module_list = lambda: nn.ModuleList([new_linear() for i in range(radius)])
self.theta_x, self.theta_y, self.theta_deg, self.theta_global = \
new_linear(), new_linear(), new_linear(), new_linear()
self.theta_list = new_module_list()
self.gamma_x, self.gamma_y, self.gamma_deg, self.gamma_global = \
new_linear(), new_linear(), new_linear(), new_linear()
self.gamma_list = new_module_list()
@staticmethod
def copy(which):
if which == 'src':
return lambda src, trg, _: src.copy()
elif which == 'trg':
return lambda src, trg, _: trg.copy()
@staticmethod
def aggregate(msg_fld, trg_fld, normalize=False):
def a(node_reprs, edge_reprs):
node_reprs = node_reprs.copy()
node_reprs[trg_fld] = sum(msg[msg_fld] for msg in edge_reprs)
if normalize:
node_reprs[trg_fld] /= len(edge_reprs)
return node_reprs
return a
@staticmethod
def pull(msg_fld, trg_fld):
def p(node_reprs, edge_reprs):
node_reprs = node_reprs.copy()
node_reprs[trg_fld] = edge_reprs[0][msg_fld]
return node_reprs
return p
def local_aggregate(self, g):
def step():
g.register_message_func(self.copy('src'), g.edges)
g.register_update_func(self.aggregate('x', 'x'), g.nodes)
g.update_all()
step()
for reprs in g.nodes.values():
reprs[0] = reprs['x']
for i in range(1, self.radius):
for j in range(2 ** (i - 1)):
step()
for reprs in g.nodes.values():
reprs[i] = reprs['x']
@staticmethod
def global_aggregate(g):
shadow = GLGModule.__SHADOW__
copy, aggregate, pull = GLGModule.copy, GLGModule.aggregate, GLGModule.pull
node_list = list(g.nodes)
uv_list = [(node, shadow) for node in g.nodes]
vu_list = [(shadow, node) for node in g.nodes]
g.add_node(shadow) # TODO context manager
tuple(itertools.starmap(g.add_edge, uv_list))
g.register_message_func(copy('src'), uv_list)
g.register_update_func(aggregate('x', 'global', normalize=True), (shadow,))
g.update_to(shadow)
tuple(itertools.starmap(g.add_edge, vu_list))
g.register_message_func(copy('src'), vu_list)
g.register_update_func(pull('global', 'global'), node_list)
g.update_from(shadow)
g.remove_node(shadow)
@staticmethod
def multiply_by_degree(g):
g.register_message_func(lambda *args: None, g.edges)
def update_func(node_reprs, _):
node_reprs = node_reprs.copy()
node_reprs['deg'] = node_reprs['x'] * node_reprs['degree']
return node_reprs
g.register_update_func(update_func, g.nodes)
g.update_all()
@staticmethod
def message_func(src, trg, _):
return {'y' : src['x']}
def update_func(self, which):
if which == 'node':
linear_x, linear_y, linear_deg, linear_global = \
self.theta_x, self.theta_y, self.theta_deg, self.theta_global
linear_list = self.theta_list
elif which == 'edge':
linear_x, linear_y, linear_deg, linear_global = \
self.gamma_x, self.gamma_y, self.gamma_deg, self.gamma_global
linear_list = self.gamma_list
def u(node_reprs, edge_reprs):
edge_reprs = filter(lambda x: x is not None, edge_reprs)
y = sum(x['y'] for x in edge_reprs)
node_reprs = node_reprs.copy()
node_reprs['x'] = linear_x(node_reprs['x']) \
+ linear_y(y) \
+ linear_deg(node_reprs['deg']) \
+ linear_global(node_reprs['global']) \
+ sum(linear(node_reprs[i]) \
for i, linear in enumerate(linear_list))
return node_reprs
return u
def forward(self, g, lg, glg):
self.local_aggregate(g)
self.local_aggregate(lg)
self.global_aggregate(g)
self.global_aggregate(lg)
self.multiply_by_degree(g)
self.multiply_by_degree(lg)
# TODO efficiency
for node, reprs in g.nodes.items():
glg.nodes[node].update(reprs)
for node, reprs in lg.nodes.items():
glg.nodes[node].update(reprs)
glg.register_message_func(self.message_func, glg.edges)
glg.register_update_func(self.update_func('node'), g.nodes)
glg.register_update_func(self.update_func('edge'), lg.nodes)
glg.update_all()
# TODO efficiency
for node, reprs in g.nodes.items():
reprs.update(glg.nodes[node])
for node, reprs in lg.nodes.items():
reprs.update(glg.nodes[node])
class GNNModule(nn.Module):
def __init__(self, in_feats, out_feats, order, radius):
super().__init__()
self.module_list = nn.ModuleList([GLGModule(in_feats, out_feats, radius)
for i in range(order)])
def forward(self, pairs, fusions):
for module, (g, lg), glg in zip(self.module_list, pairs, fusions):
module(g, lg, glg)
for lhs, rhs in zip(pairs[:-1], pairs[1:]):
for node, reprs in lhs[1].nodes.items():
x_rhs = reprs['x']
reprs['x'] = x_rhs + rhs[0].nodes[node]['x']
rhs[0].nodes[node]['x'] += x_rhs
class GNN(nn.Module):
def __init__(self, feats, order, radius, n_classes):
super().__init__()
self.order = order
self.linear = nn.Linear(feats[-1], n_classes)
self.module_list = nn.ModuleList([GNNModule(in_feats, out_feats, order, radius)
for in_feats, out_feats in zip(feats[:-1], feats[1:])])
@staticmethod
def line_graph(g):
lg = nx.line_graph(g)
glg = nx.DiGraph()
glg.add_nodes_from(g.nodes)
glg.add_nodes_from(lg.nodes)
for u, v in g.edges:
glg.add_edge(u, (u, v))
glg.add_edge((u, v), u)
glg.add_edge(v, (u, v))
glg.add_edge((u, v), v)
return lg, glg
@staticmethod
def nx2dgl(g):
deg_dict = dict(nx.degree(g))
z = sum(deg_dict.values())
dgl_g = G.DGLGraph(g)
for node, reprs in dgl_g.nodes.items():
reprs['degree'] = deg_dict[node]
reprs['x'] = th.full((1, 1), reprs['degree'] / z)
reprs.update(g.nodes[node])
return dgl_g
def forward(self, g):
"""
Parameters
----------
g : networkx.DiGraph
"""
pair_list, glg_list = [], []
dgl_g = self.nx2dgl(g)
origin = dgl_g
for i in range(self.order):
lg, glg = self.line_graph(g)
dgl_lg = self.nx2dgl(lg)
pair_list.append((dgl_g, copy.deepcopy(dgl_lg)))
glg_list.append(G.DGLGraph(glg))
g = lg
dgl_g = dgl_lg
for module in self.module_list:
module(pair_list, glg_list)
return self.linear(th.cat([reprs['x'] for reprs in origin.nodes.values()], 0))
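# Sanity sketch (added, not part of the training path) of the graph / line-graph /
# fused-graph triple that GNN.line_graph builds, shown on a toy directed triangle:
def _line_graph_demo():
    import networkx as nx
    g = nx.DiGraph([(0, 1), (1, 2), (2, 0)])
    lg = nx.line_graph(g)                   # nodes of lg are the directed edges of g
    assert set(lg.nodes) == set(g.edges)
    glg = nx.DiGraph()
    glg.add_nodes_from(g.nodes)
    glg.add_nodes_from(lg.nodes)
    for u, v in g.edges:                    # connect each edge-node to both endpoints
        glg.add_edges_from([(u, (u, v)), ((u, v), u), (v, (u, v)), ((u, v), v)])
    return g, lg, glg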
"""
By Minjie
"""
from __future__ import division
import math
import numpy as np
import scipy.sparse as sp
import networkx as nx
import matplotlib.pyplot as plt
class SSBM:
def __init__(self, n, k, a=10.0, b=2.0, regime='constant', rng=None):
"""Symmetric Stochastic Block Model.
n - number of nodes
k - number of communities
a - probability scale for intra-community edge
b - probability scale for inter-community edge
regime - If "logarithm", this generates SSBM(n, k, a*log(n)/n, b*log(n)/n)
If "constant", this generates SSBM(n, k, a/n, b/n)
If "mixed", this generates SSBM(n, k, a*log(n)/n, b/n)
"""
self.n = n
self.k = k
if regime == 'logarithm':
if math.sqrt(a) - math.sqrt(b) >= math.sqrt(k):
print('SSBM model with possible exact recovery.')
else:
print('SSBM model with impossible exact recovery.')
self.a = a * math.log(n) / n
self.b = b * math.log(n) / n
elif regime == 'constant':
snr = (a - b) ** 2 / (k * (a + (k - 1) * b))
if snr > 1:
print('SSBM model with possible detection.')
else:
print('SSBM model that may not have detection (snr=%.5f).' % snr)
self.a = a / n
self.b = b / n
elif regime == 'mixed':
self.a = a * math.log(n) / n
self.b = b / n
else:
raise ValueError('Unknown regime: %s' % regime)
if rng is None:
self.rng = np.random.RandomState()
else:
self.rng = rng
self._graph = None
def generate(self):
self.generate_communities()
print('Finished generating communities.')
self.generate_edges()
print('Finished generating edges.')
def generate_communities(self):
nodes = list(range(self.n))
size = self.n // self.k
self.block_size = size
self.comm2node = [nodes[i*size:(i+1)*size] for i in range(self.k)]
self.node2comm = [nid // size for nid in range(self.n)]
def generate_edges(self):
# TODO: dedup edges
us = []
vs = []
# generate intra-comm edges
for i in range(self.k):
sp_mat = sp.random(self.block_size, self.block_size,
density=self.a,
random_state=self.rng,
data_rvs=lambda l: np.ones(l))
u = sp_mat.row + i * self.block_size
v = sp_mat.col + i * self.block_size
us.append(u)
vs.append(v)
# generate inter-comm edges
for i in range(self.k):
for j in range(self.k):
if i == j:
continue
sp_mat = sp.random(self.block_size, self.block_size,
density=self.b,
random_state=self.rng,
data_rvs=lambda l: np.ones(l))
u = sp_mat.row + i * self.block_size
v = sp_mat.col + j * self.block_size
us.append(u)
vs.append(v)
us = np.hstack(us)
vs = np.hstack(vs)
self.sp_mat = sp.coo_matrix((np.ones(us.shape[0]), (us, vs)), shape=(self.n, self.n))
@property
def graph(self):
if self._graph is None:
self._graph = nx.from_scipy_sparse_matrix(self.sp_mat, create_using=nx.DiGraph())
return self._graph
def plot(self):
x = self.sp_mat.row
y = self.sp_mat.col
plt.scatter(x, y, s=0.5, marker='.', c='k')
plt.savefig('ssbm-%d-%d.pdf' % (self.n, self.k))
plt.clf()
# plot out degree distribution
out_degree = [d for _, d in self.graph.out_degree()] # the degree view yields (node, degree) pairs
plt.hist(out_degree, 100, normed=True)
plt.savefig('ssbm-%d-%d_out_degree.pdf' % (self.n, self.k))
plt.clf()
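# Worked example (added) of the detection criterion checked in __init__ for the
# 'constant' regime: snr = (a - b)^2 / (k * (a + (k - 1) * b)).
def _snr_demo(a=10.0, b=2.0, k=10):
    snr = (a - b) ** 2 / (k * (a + (k - 1) * b))
    return snr  # 64 / 280 ~= 0.229 here, below the detection threshold of 1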
if __name__ == '__main__':
n = 1000
k = 10
ssbm = SSBM(n, k, regime='mixed', a=4, b=1)
ssbm.generate()
g = ssbm.graph
print('#nodes:', g.number_of_nodes())
print('#edges:', g.number_of_edges())
#ssbm.plot()
#lg = nx.line_graph(g)
# plot degree distribution
#degree = [d for _, d in lg.degree().items()]
#plt.hist(degree, 100, normed=True)
#plt.savefig('lg<ssbm-%d-%d>_degree.pdf' % (n, k))
#plt.clf()
"""
ipython3 test.py -- --features 1 16 16 --gpu -1 --n-classes 5 --n-iterations 10 --n-nodes 10 --order 3 --radius 3
"""
import argparse
import networkx as nx
import torch as th
import torch.nn as nn
import torch.optim as optim
import gnn
parser = argparse.ArgumentParser()
parser.add_argument('--features', nargs='+', type=int)
parser.add_argument('--gpu', type=int)
parser.add_argument('--n-classes', type=int)
parser.add_argument('--n-iterations', type=int)
parser.add_argument('--n-nodes', type=int)
parser.add_argument('--order', type=int)
parser.add_argument('--radius', type=int)
args = parser.parse_args()
if args.gpu < 0:
cuda = False
else:
cuda = True
th.cuda.set_device(args.gpu)
g = nx.barabasi_albert_graph(args.n_nodes, 1).to_directed() # TODO SBM
y = th.multinomial(th.ones(args.n_classes), args.n_nodes, replacement=True)
network = gnn.GNN(args.features, args.order, args.radius, args.n_classes)
if cuda:
network.cuda()
ce = nn.CrossEntropyLoss()
adam = optim.Adam(network.parameters())
for i in range(args.n_iterations):
y_bar = network(g)
loss = ce(y_bar, y)
adam.zero_grad()
loss.backward()
adam.step()
print('[iteration %d]loss %f' % (i, loss))
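# Optional post-training check (added sketch, assuming the CPU invocation shown in
# the header docstring): y_bar holds the logits from the last iteration and y the
# sampled targets, so a rough training accuracy can be read off directly.
accuracy = (y_bar.max(1)[1] == y).float().mean()
print('training accuracy %f' % accuracy)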
from nltk.tree import Tree
from nltk.corpus.reader import BracketParseCorpusReader as CorpusReader
import networkx as nx
import torch as th
class nx_BCT_Reader:
# Binary Constituency Tree constructor for networkx
def __init__(self, cuda=False,
fnames=['trees/train.txt', 'trees/dev.txt', 'trees/test.txt']):
# fnames must contain three file paths: the train, validation, and test sets, respectively.
self.corpus = CorpusReader('.', fnames)
self.train = self.corpus.parsed_sents(fnames[0])
self.dev = self.corpus.parsed_sents(fnames[1])
self.test = self.corpus.parsed_sents(fnames[2])
self.vocab = {}
def _rec(node):
for child in node:
if isinstance(child[0], str) and child[0] not in self.vocab:
self.vocab[child[0]] = len(self.vocab)
elif isinstance(child, Tree):
_rec(child)
for sent in self.train:
_rec(sent)
self.default = len(self.vocab) # index of the OOV slot; the embedding table holds len(vocab) + 1 rows
self.LongTensor = th.cuda.LongTensor if cuda else th.LongTensor
self.FloatTensor = th.cuda.FloatTensor if cuda else th.FloatTensor
def create_BCT(self, root):
self.node_cnt = 0
self.G = nx.DiGraph()
def _rec(node, nx_node, depth=0):
for child in node:
node_id = str(self.node_cnt) + '_' + str(depth+1)
self.node_cnt += 1
# if isinstance(child[0], str) or isinstance(child[0], unicode):
if isinstance(child[0], str):
word = self.LongTensor([self.vocab.get(child[0], self.default)])
self.G.add_node(node_id, x=word, y=None)
else:
label = self.FloatTensor([[0] * 5])
label[0, int(child.label())] = 1
self.G.add_node(node_id, x=None, y=label)
if isinstance(child, Tree): #check illegal trees
_rec(child, node_id, depth + 1)
self.G.add_edge(node_id, nx_node)
self.G.add_node('0_0', x=None, y=None) # add root into nx Graph
_rec(root, '0_0')
return self.G
def generator(self, mode='train'):
assert mode in ['train', 'dev', 'test']
for s in self.__dict__[mode]:
yield self.create_BCT(s)
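# Shape of the SST bracketed trees consumed by nx_BCT_Reader above (added example
# with a small hand-made sentence, not an actual corpus entry):
def _sst_tree_demo():
    from nltk.tree import Tree
    t = Tree.fromstring('(3 (2 (2 A) (2 movie)) (4 (3 worth) (2 seeing)))')
    assert t.label() == '3'                                 # sentiment class of the root constituent
    assert t.leaves() == ['A', 'movie', 'worth', 'seeing']  # tokens mapped to vocabulary indices
    return t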
import argparse
import torch as th
import torch.optim as optim
import nx_SST
import tree_lstm
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=25)
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--h-size', type=int, default=512)
parser.add_argument('--log-every', type=int, default=1)
parser.add_argument('--lr', type=float, default=0.05)
parser.add_argument('--n-ary', type=int, default=2)
parser.add_argument('--n-iterations', type=int, default=1000)
parser.add_argument('--weight-decay', type=float, default=1e-4)
parser.add_argument('--x-size', type=int, default=256)
args = parser.parse_args()
if args.gpu < 0:
cuda = False
else:
cuda = True
th.cuda.set_device(args.gpu)
reader = nx_SST.nx_BCT_Reader(cuda)
loader = reader.generator()
network = tree_lstm.NAryTreeLSTM(len(reader.vocab) + 1,
args.x_size, args.h_size, args.n_ary, 5)
if cuda:
network.cuda()
adagrad = optim.Adagrad(network.parameters(), args.lr)
for i in range(args.n_iterations):
nll = 0
for j in range(args.batch_size):
g = next(loader)
nll += network(g, train=True)
nll /= args.batch_size
adagrad.zero_grad()
nll.backward()
adagrad.step()
if (i + 1) % args.log_every == 0:
print('[iteration %d]cross-entropy loss: %f' % ((i + 1), nll))
"""
Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
https://arxiv.org/abs/1503.00075
"""
import itertools
import networkx as nx
import dgl.graph as G
import torch as th
import torch.nn as nn
import torch.nn.functional as F
class TreeLSTM(nn.Module):
def __init__(self, n_embeddings, x_size, h_size, n_classes):
super().__init__()
self.embedding = nn.Embedding(n_embeddings, x_size)
self.linear = nn.Linear(h_size, n_classes)
@staticmethod
def message_func(src, trg, _):
return {'h' : src.get('h'), 'c' : src.get('c')}
def leaf_update_func(self, node_reprs, edge_reprs):
x = node_reprs['x']
iou = th.mm(x, self.iou_w) + self.iou_b
i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
c = i * u
h = o * th.tanh(c)
return {'h' : h, 'c' : c}
def internal_update_func(self, node_reprs, edge_reprs):
raise NotImplementedError()
def readout_func(self, g, train):
if train:
h = th.cat([d['h'] for d in g.nodes.values() if d['y'] is not None], 0)
y = th.cat([d['y'] for d in g.nodes.values() if d['y'] is not None], 0)
log_p = F.log_softmax(self.linear(h), 1)
return -th.sum(y * log_p) / len(y)
else:
h = th.cat([reprs['h'] for reprs in g.nodes.values()], 0)
y_bar = th.max(self.linear(h), 1)[1]
for reprs, z in zip(g.nodes.values(), y_bar): # store one prediction per node
reprs['y_bar'] = z.item()
def forward(self, g, train=False):
"""
Parameters
----------
g : networkx.DiGraph
"""
assert any(d['y'] is not None for d in g.nodes.values()) # TODO
g = G.DGLGraph(g)
def update_func(node_reprs, edge_reprs):
node_reprs = node_reprs.copy()
if node_reprs['x'] is not None:
node_reprs['x'] = self.embedding(node_reprs['x'])
return node_reprs
g.register_message_func(self.message_func)
g.register_update_func(update_func, g.nodes)
g.update_all()
g.register_update_func(self.internal_update_func, g.nodes)
leaves = list(filter(lambda x: g.in_degree(x) == 0, g.nodes))
g.register_update_func(self.leaf_update_func, leaves)
iterator = []
frontier = [next(filter(lambda x: g.out_degree(x) == 0, g.nodes))]
while frontier:
src = sum([list(g.pred[x]) for x in frontier], [])
trg = sum([[x] * len(g.pred[x]) for x in frontier], [])
iterator.append((src, trg))
frontier = src
g.recvfrom(leaves)
g.propagate(reversed(iterator))
return self.readout_func(g, train)
class ChildSumTreeLSTM(TreeLSTM):
def __init__(self, n_embeddings, x_size, h_size, n_classes):
super().__init__(n_embeddings, x_size, h_size, n_classes)
self.iou_w = nn.Parameter(th.randn(x_size, 3 * h_size)) # TODO initializer
self.iou_u = nn.Parameter(th.randn(h_size, 3 * h_size)) # TODO initializer
self.iou_b = nn.Parameter(th.zeros(1, 3 * h_size))
self.f_x = nn.Parameter(th.randn(x_size, h_size)) # TODO initializer
self.f_h = nn.Parameter(th.randn(h_size, h_size)) # TODO initializer
self.f_b = nn.Parameter(th.zeros(1, h_size))
def internal_update_func(self, node_reprs, edge_reprs):
x = node_reprs['x']
h_bar = sum(msg['h'] for msg in edge_reprs)
if x is None:
iou = th.mm(h_bar, self.iou_u) + self.iou_b
else:
iou = th.mm(x, self.iou_w) + th.mm(h_bar, self.iou_u) + self.iou_b
i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
wx = th.mm(x, self.f_x).repeat(len(edge_reprs), 1) if x is not None else 0
uh = th.mm(th.cat([msg['h'] for msg in edge_reprs], 0), self.f_h)
f = th.sigmoid(wx + uh + self.f_b)
c = th.cat([msg['c'] for msg in edge_reprs], 0)
c = i * u + th.sum(f * c, 0)
h = o * th.tanh(c)
return {'h' : h, 'c' : c}
class NAryTreeLSTM(TreeLSTM):
def __init__(self, n_embeddings, x_size, h_size, n_ary, n_classes):
super().__init__(n_embeddings, x_size, h_size, n_classes)
# TODO initializer
self.iou_w = nn.Parameter(th.randn(x_size, 3 * h_size))
self.iou_u = nn.ParameterList([nn.Parameter(th.randn(1, h_size, 3 * h_size)) for i in range(n_ary)]) # register per-child weights so they are trained and moved by .cuda()
self.iou_b = nn.Parameter(th.zeros(1, 3 * h_size))
# TODO initializer
self.f_x = nn.Parameter(th.randn(x_size, h_size))
self.f_h = nn.ModuleList([nn.ParameterList([nn.Parameter(th.randn(1, h_size, h_size))
for i in range(n_ary)]) for j in range(n_ary)]) # register the n_ary x n_ary forget-gate weights
self.f_b = nn.Parameter(th.zeros(1, h_size))
def internal_update_func(self, node_reprs, edge_reprs):
assert len(edge_reprs) > 0
assert all(msg['h'] is not None and msg['c'] is not None for msg in edge_reprs)
x = node_reprs['x']
n_children = len(edge_reprs)
iou_wx = th.mm(x, self.iou_w) if x is not None else 0
iou_u = th.cat(list(self.iou_u)[:n_children], 0)
iou_h = th.cat([msg['h'] for msg in edge_reprs], 0).unsqueeze(1)
iou_uh = th.sum(th.bmm(iou_h, iou_u), 0)
i, o, u = th.chunk(iou_wx + iou_uh + self.iou_b, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
f_wx = th.mm(x, self.f_x).repeat(n_children, 1) if x is not None else 0
f_h = iou_h.repeat(n_children, 1, 1)
f_u = th.cat(sum([list(self.f_h[i])[:n_children] for i in range(n_children)], []), 0)
f_uh = th.sum(th.bmm(f_h, f_u).view(n_children, n_children, -1), 0)
f = th.sigmoid(f_wx + f_uh + self.f_b)
c = th.cat([msg['c'] for msg in edge_reprs], 0)
c = i * u + th.sum(f * c, 0, keepdim=True)
h = o * th.tanh(c)
return {'h' : h, 'c' : c}
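# Stand-alone sketch (added) of the N-ary TreeLSTM cell on plain tensors, to make
# the gating in NAryTreeLSTM.internal_update_func explicit; shapes are assumptions:
# x: (1, x_size), h_children/c_children: (k, h_size), W_iou: (x_size, 3h),
# U_iou: (k, h_size, 3h), b_iou: (1, 3h), W_f: (x_size, h), U_f: (k, k, h_size, h),
# b_f: (1, h). Reuses the `th` import at the top of this file.
def _nary_cell_sketch(x, h_children, c_children, W_iou, U_iou, b_iou, W_f, U_f, b_f):
    k, h_size = h_children.shape
    iou = x @ W_iou + b_iou
    for l in range(k):
        iou = iou + h_children[l:l + 1] @ U_iou[l]
    i, o, u = th.chunk(iou, 3, 1)
    i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
    c = i * u
    for l in range(k):                       # one forget gate per child
        f_l = x @ W_f + b_f
        for m in range(k):
            f_l = f_l + h_children[m:m + 1] @ U_f[l][m]
        c = c + th.sigmoid(f_l) * c_children[l:l + 1]
    h = o * th.tanh(c)
    return h, c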
@@ -143,7 +143,7 @@ class DGLGraph(DiGraph):
     The update function should be compatible with following signature:
-    (edge_reprs, node_reprs) -> node_reprs
+    (node_reprs, edge_reprs) -> node_reprs
     It computes the new node representations using the representations
     of the in-coming edges (the same concept as messages) and the node
@@ -167,7 +167,7 @@ class DGLGraph(DiGraph):
     >>> g.register_update_func(ufunc)
     Register for a specific node.
-    >>> g.register_update_func(ufunc, u)
+    >>> g.register_update_func(ufunc, u) # TODO Not implemented
     Register for multiple nodes.
     >>> u = [u1, u2, u3, ...]
...
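For reference, an update function matching the corrected (node_reprs, edge_reprs) signature looks like the helpers used in the CDGNN example above (a sketch; the 'x' field name is illustrative):

def ufunc(node_reprs, edge_reprs):
    # node_reprs: dict holding this node's representations
    # edge_reprs: list of message dicts from the in-coming edges
    out = node_reprs.copy()
    out['x'] = node_reprs['x'] + sum(msg['x'] for msg in edge_reprs)
    return out

g.register_update_func(ufunc, [u1, u2, u3])  # register for multiple nodes, as in the docstring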