"src/array/vscode:/vscode.git/clone" did not exist on "f7b4c93d7d4abbf447c8375ffd90485332f97850"
Commit b9073209 authored by GaiYu0's avatar GaiYu0 Committed by Minjie Wang
Browse files

CDGNN example & minor fixes in dgl/graph.py (#11)

* cdgnn example & minor fixes in dgl/graph.py

* incomplete TreeLSTM

* add data preprocess of SST

* fix nx_SST.py

* TreeLSTM completed
parent c87564d5
import torch.utils as utils
"""
Supervised Community Detection with Hierarchical Graph Neural Networks
https://arxiv.org/abs/1705.08415
Deviations from paper:
- Addition of global aggregation operator.
- Message passing is equivalent to `A^j \cdot X`, instead of `\min(1, A^j) \cdot X`.
"""
# TODO self-loop?
# TODO in-place edit of node_reprs/edge_reprs in message_func/update_func?
# TODO batch-norm
import copy
import itertools
import dgl.graph as G
import networkx as nx
import torch as th
import torch.nn as nn
class GLGModule(nn.Module):
    """One graph / line-graph coupling layer of the community-detection GNN.

    Operates on a graph ``g``, its line graph ``lg``, and a fusion graph
    ``glg`` whose node set is the union of ``g``'s nodes and ``lg``'s nodes
    (i.e. ``g``'s edges), so features can be exchanged between the two.
    """

    # Name of the temporary node inserted for graph-wide aggregation.
    __SHADOW__ = 'shadow'

    def __init__(self, in_feats, out_feats, radius):
        """
        in_feats / out_feats : input / output feature sizes of every linear map.
        radius : number of local-aggregation scales kept per node.
        """
        super().__init__()
        self.radius = radius
        new_linear = lambda: nn.Linear(in_feats, out_feats)
        new_module_list = lambda: nn.ModuleList([new_linear() for i in range(radius)])
        # theta_* act on graph nodes ('node' update); gamma_* act on
        # line-graph nodes ('edge' update).
        self.theta_x, self.theta_y, self.theta_deg, self.theta_global = \
            new_linear(), new_linear(), new_linear(), new_linear()
        self.theta_list = new_module_list()
        self.gamma_x, self.gamma_y, self.gamma_deg, self.gamma_global = \
            new_linear(), new_linear(), new_linear(), new_linear()
        self.gamma_list = new_module_list()

    @staticmethod
    def copy(which):
        # Message-function factory: forward a copy of the source ('src') or
        # target ('trg') node representation along each edge.
        # NOTE(review): silently returns None for any other `which`.
        if which == 'src':
            return lambda src, trg, _: src.copy()
        elif which == 'trg':
            return lambda src, trg, _: trg.copy()

    @staticmethod
    def aggregate(msg_fld, trg_fld, normalize=False):
        # Update-function factory: sum field `msg_fld` over incoming messages
        # into node field `trg_fld`, optionally averaging over message count.
        def a(node_reprs, edge_reprs):
            node_reprs = node_reprs.copy()
            node_reprs[trg_fld] = sum(msg[msg_fld] for msg in edge_reprs)
            if normalize:
                node_reprs[trg_fld] /= len(edge_reprs)
            return node_reprs
        return a

    @staticmethod
    def pull(msg_fld, trg_fld):
        # Update-function factory: copy `msg_fld` from the single incoming
        # message into `trg_fld` (assumes exactly one incoming message).
        def p(node_reprs, edge_reprs):
            node_reprs = node_reprs.copy()
            node_reprs[trg_fld] = edge_reprs[0][msg_fld]
            return node_reprs
        return p

    def local_aggregate(self, g):
        # Multi-scale neighborhood sums: after round i, integer node field i
        # holds 'x' propagated 2^(i-1) additional times (A^j * x, not
        # min(1, A^j) * x -- see the module docstring's deviation note).
        def step():
            g.register_message_func(self.copy('src'), g.edges)
            g.register_update_func(self.aggregate('x', 'x'), g.nodes)
            g.update_all()
        step()
        for reprs in g.nodes.values():
            reprs[0] = reprs['x']  # integer keys hold the per-scale snapshots
        for i in range(1, self.radius):
            for j in range(2 ** (i - 1)):
                step()
            for reprs in g.nodes.values():
                reprs[i] = reprs['x']

    @staticmethod
    def global_aggregate(g):
        # Average 'x' over all nodes into every node's 'global' field, using
        # a temporary "shadow" node connected to and from every node.
        shadow = GLGModule.__SHADOW__
        copy, aggregate, pull = GLGModule.copy, GLGModule.aggregate, GLGModule.pull
        node_list = list(g.nodes)
        uv_list = [(node, shadow) for node in g.nodes]
        vu_list = [(shadow, node) for node in g.nodes]
        g.add_node(shadow)  # TODO context manager
        # Fan-in: every node -> shadow, averaged into shadow['global'].
        tuple(itertools.starmap(g.add_edge, uv_list))
        g.register_message_func(copy('src'), uv_list)
        g.register_update_func(aggregate('x', 'global', normalize=True), (shadow,))
        g.update_to(shadow)
        # Fan-out: shadow -> every node, each node pulls the global average.
        tuple(itertools.starmap(g.add_edge, vu_list))
        g.register_message_func(copy('src'), vu_list)
        g.register_update_func(pull('global', 'global'), node_list)
        g.update_from(shadow)
        g.remove_node(shadow)

    @staticmethod
    def multiply_by_degree(g):
        # Store x * degree in every node's 'deg' field; messages are unused.
        g.register_message_func(lambda *args: None, g.edges)
        def update_func(node_reprs, _):
            node_reprs = node_reprs.copy()
            node_reprs['deg'] = node_reprs['x'] * node_reprs['degree']
            return node_reprs
        g.register_update_func(update_func, g.nodes)
        g.update_all()

    @staticmethod
    def message_func(src, trg, _):
        # Fusion-graph message: expose the source's feature under field 'y'.
        return {'y' : src['x']}

    def update_func(self, which):
        # Build the fused update: x' = L_x(x) + L_y(sum of messages) +
        # L_deg(deg) + L_global(global) + sum_i L_i(x at scale i), where the
        # L_* are theta_* for graph nodes and gamma_* for line-graph nodes.
        if which == 'node':
            linear_x, linear_y, linear_deg, linear_global = \
                self.theta_x, self.theta_y, self.theta_deg, self.theta_global
            linear_list = self.theta_list
        elif which == 'edge':
            linear_x, linear_y, linear_deg, linear_global = \
                self.gamma_x, self.gamma_y, self.gamma_deg, self.gamma_global
            linear_list = self.gamma_list
        def u(node_reprs, edge_reprs):
            # Some fusion-graph edges carry no message (None); drop them.
            edge_reprs = filter(lambda x: x is not None, edge_reprs)
            y = sum(x['y'] for x in edge_reprs)
            node_reprs = node_reprs.copy()
            node_reprs['x'] = linear_x(node_reprs['x']) \
                + linear_y(y) \
                + linear_deg(node_reprs['deg']) \
                + linear_global(node_reprs['global']) \
                + sum(linear(node_reprs[i]) \
                      for i, linear in enumerate(linear_list))
            return node_reprs
        return u

    def forward(self, g, lg, glg):
        """Run one coupled update of g and lg through the fusion graph glg."""
        self.local_aggregate(g)
        self.local_aggregate(lg)
        self.global_aggregate(g)
        self.global_aggregate(lg)
        self.multiply_by_degree(g)
        self.multiply_by_degree(lg)
        # TODO efficiency
        # Copy node / line-graph features into the fusion graph.
        for node, reprs in g.nodes.items():
            glg.nodes[node].update(reprs)
        for node, reprs in lg.nodes.items():
            glg.nodes[node].update(reprs)
        glg.register_message_func(self.message_func, glg.edges)
        glg.register_update_func(self.update_func('node'), g.nodes)
        glg.register_update_func(self.update_func('edge'), lg.nodes)
        glg.update_all()
        # TODO efficiency
        # Copy the fused features back out of the fusion graph.
        for node, reprs in g.nodes.items():
            reprs.update(glg.nodes[node])
        for node, reprs in lg.nodes.items():
            reprs.update(glg.nodes[node])
class GNNModule(nn.Module):
    """A stack of `order` GLGModules, one per (graph, line-graph) pair."""

    def __init__(self, in_feats, out_feats, order, radius):
        super().__init__()
        layers = [GLGModule(in_feats, out_feats, radius) for _ in range(order)]
        self.module_list = nn.ModuleList(layers)

    def forward(self, pairs, fusions):
        """Update every (g, lg) pair, then sum features between adjacent levels.

        pairs   : sequence of (graph, line_graph) objects.
        fusions : sequence of fusion graphs, one per pair.
        """
        for layer, pair, fusion in zip(self.module_list, pairs, fusions):
            layer(pair[0], pair[1], fusion)
        # One pair's line graph is indexed by the same node keys as the next
        # pair's graph: exchange their features symmetrically.
        for (_, lower_lg), (upper_g, _) in zip(pairs[:-1], pairs[1:]):
            for node, reprs in lower_lg.nodes.items():
                lower_x = reprs['x']
                reprs['x'] = lower_x + upper_g.nodes[node]['x']
                upper_g.nodes[node]['x'] += lower_x
class GNN(nn.Module):
    """Full model: a chain of GNNModules followed by a linear node classifier.

    `feats` lists per-layer feature sizes; one GNNModule is built for each
    consecutive (in_feats, out_feats) pair.
    """

    def __init__(self, feats, order, radius, n_classes):
        super().__init__()
        self.order = order
        # Classifier applied to the final node features of the input graph.
        self.linear = nn.Linear(feats[-1], n_classes)
        self.module_list = nn.ModuleList([GNNModule(in_feats, out_feats, order, radius)
                                          for in_feats, out_feats in zip(feats[:-1], feats[1:])])

    @staticmethod
    def line_graph(g):
        # Build g's line graph plus the fusion graph glg that connects each
        # edge (u, v) of g with both of its endpoints, in both directions.
        lg = nx.line_graph(g)
        glg = nx.DiGraph()
        glg.add_nodes_from(g.nodes)
        glg.add_nodes_from(lg.nodes)
        for u, v in g.edges:
            glg.add_edge(u, (u, v))
            glg.add_edge((u, v), u)
            glg.add_edge(v, (u, v))
            glg.add_edge((u, v), v)
        return lg, glg

    @staticmethod
    def nx2dgl(g):
        # Wrap a networkx graph in a DGLGraph, initializing every node with
        # its degree and x = degree / (sum of all degrees) as a 1x1 tensor.
        deg_dict = dict(nx.degree(g))
        z = sum(deg_dict.values())
        dgl_g = G.DGLGraph(g)
        for node, reprs in dgl_g.nodes.items():
            reprs['degree'] = deg_dict[node]
            reprs['x'] = th.full((1, 1), reprs['degree'] / z)
            reprs.update(g.nodes[node])
        return dgl_g

    def forward(self, g):
        """
        Parameters
        ----------
        g : networkx.DiGraph

        Returns
        -------
        Tensor of class scores, one row per node of the input graph.
        """
        pair_list, glg_list = [], []
        dgl_g = self.nx2dgl(g)
        origin = dgl_g  # keep a handle on the input-level graph for readout
        # Build the tower of iterated line graphs and their fusion graphs.
        for i in range(self.order):
            lg, glg = self.line_graph(g)
            dgl_lg = self.nx2dgl(lg)
            # deepcopy: dgl_lg is both this level's line graph and the next
            # level's graph; the two roles must not share representations.
            pair_list.append((dgl_g, copy.deepcopy(dgl_lg)))
            glg_list.append(G.DGLGraph(glg))
            g = lg
            dgl_g = dgl_lg
        for module in self.module_list:
            module(pair_list, glg_list)
        return self.linear(th.cat([reprs['x'] for reprs in origin.nodes.values()], 0))
"""
By Minjie
"""
from __future__ import division
import math
import numpy as np
import scipy.sparse as sp
import networkx as nx
import matplotlib.pyplot as plt
class SSBM:
    """Symmetric Stochastic Block Model graph generator."""

    def __init__(self, n, k, a=10.0, b=2.0, regime='constant', rng=None):
        """Symmetric Stochastic Block Model.

        n - number of nodes
        k - number of communities
        a - probability scale for intra-community edge
        b - probability scale for inter-community edge
        regime - If "logarithm", this generates SSBM(n, k, a*log(n)/n, b*log(n)/n)
                 If "constant", this generates SSBM(n, k, a/n, b/n)
                 If "mixed", this generates SSBM(n, k, a*log(n)/n, b/n)
        rng - optional numpy.random.RandomState for reproducible sampling

        Raises
        ------
        ValueError - for an unrecognized regime.
        """
        self.n = n
        self.k = k
        if regime == 'logarithm':
            # Exact recovery is possible iff sqrt(a) - sqrt(b) >= sqrt(k).
            if math.sqrt(a) - math.sqrt(b) >= math.sqrt(k):
                print('SSBM model with possible exact recovery.')
            else:
                print('SSBM model with impossible exact recovery.')
            self.a = a * math.log(n) / n
            self.b = b * math.log(n) / n
        elif regime == 'constant':
            # Detection is possible iff the signal-to-noise ratio exceeds 1.
            snr = (a - b) ** 2 / (k * (a + (k - 1) * b))
            if snr > 1:
                print('SSBM model with possible detection.')
            else:
                print('SSBM model that may not have detection (snr=%.5f).' % snr)
            self.a = a / n
            self.b = b / n
        elif regime == 'mixed':
            self.a = a * math.log(n) / n
            self.b = b / n
        else:
            raise ValueError('Unknown regime: %s' % regime)
        self.rng = np.random.RandomState() if rng is None else rng
        self._graph = None

    def generate(self):
        """Sample the communities and then the edge set."""
        self.generate_communities()
        print('Finished generating communities.')
        self.generate_edges()
        print('Finished generating edges.')

    def generate_communities(self):
        # Deterministic equal-sized blocks of consecutive node ids.
        nodes = list(range(self.n))
        size = self.n // self.k
        self.block_size = size
        self.comm2node = [nodes[i*size:(i+1)*size] for i in range(self.k)]
        self.node2comm = [nid // size for nid in range(self.n)]

    def generate_edges(self):
        """Sample the adjacency into self.sp_mat (scipy COO, ones as data)."""
        # TODO: dedup edges
        us = []
        vs = []
        # generate intra-comm edges (per-pair probability self.a)
        for i in range(self.k):
            block = sp.random(self.block_size, self.block_size,
                              density=self.a,
                              random_state=self.rng,
                              data_rvs=lambda size: np.ones(size))
            us.append(block.row + i * self.block_size)
            vs.append(block.col + i * self.block_size)
        # generate inter-comm edges (per-pair probability self.b)
        for i in range(self.k):
            for j in range(self.k):
                if i == j:
                    continue
                block = sp.random(self.block_size, self.block_size,
                                  density=self.b,
                                  random_state=self.rng,
                                  data_rvs=lambda size: np.ones(size))
                us.append(block.row + i * self.block_size)
                vs.append(block.col + j * self.block_size)
        us = np.hstack(us)
        vs = np.hstack(vs)
        self.sp_mat = sp.coo_matrix((np.ones(us.shape[0]), (us, vs)), shape=(self.n, self.n))

    @property
    def graph(self):
        """Lazily build the directed networkx graph of the sampled adjacency."""
        if self._graph is None:
            # NOTE(review): from_scipy_sparse_matrix was renamed to
            # from_scipy_sparse_array in networkx 3.x -- confirm the pinned
            # networkx version before upgrading.
            self._graph = nx.from_scipy_sparse_matrix(self.sp_mat, create_using=nx.DiGraph())
        return self._graph

    def plot(self):
        """Save an adjacency scatter plot and the out-degree histogram."""
        x = self.sp_mat.row
        y = self.sp_mat.col
        plt.scatter(x, y, s=0.5, marker='.', c='k')
        plt.savefig('ssbm-%d-%d.pdf' % (self.n, self.k))
        plt.clf()
        # plot out degree distribution
        # dict(...) accepts both the nx 1.x dict and the nx 2.x DegreeView
        # (the old .items() call failed on the view).
        out_degree = list(dict(self.graph.out_degree()).values())
        # `normed` was removed from matplotlib's hist; `density` replaces it.
        plt.hist(out_degree, 100, density=True)
        plt.savefig('ssbm-%d-%d_out_degree.pdf' % (self.n, self.k))
        plt.clf()
if __name__ == '__main__':
    # Smoke test: sample a mixed-regime SSBM and report the graph size.
    n, k = 1000, 10
    ssbm = SSBM(n, k, a=4, b=1, regime='mixed')
    ssbm.generate()
    g = ssbm.graph
    print('#nodes:', g.number_of_nodes())
    print('#edges:', g.number_of_edges())
    #ssbm.plot()
    #lg = nx.line_graph(g)
    # plot degree distribution
    #degree = [d for _, d in lg.degree().items()]
    #plt.hist(degree, 100, normed=True)
    #plt.savefig('lg<ssbm-%d-%d>_degree.pdf' % (n, k))
    #plt.clf()
"""
ipython3 test.py -- --features 1 16 16 --gpu -1 --n-classes 5 --n-iterations 10 --n-nodes 10 --order 3 --radius 3
"""
import argparse
import networkx as nx
import torch as th
import torch.nn as nn
import torch.optim as optim
import gnn
# Command-line interface; see the usage string at the top of this file.
parser = argparse.ArgumentParser()
parser.add_argument('--features', nargs='+', type=int)
parser.add_argument('--gpu', type=int)
parser.add_argument('--n-classes', type=int)
parser.add_argument('--n-iterations', type=int)
parser.add_argument('--n-nodes', type=int)
parser.add_argument('--order', type=int)
parser.add_argument('--radius', type=int)
args = parser.parse_args()

# --gpu has no default, so args.gpu is None when the flag is omitted; the
# old `args.gpu < 0` then crashed with a TypeError. Treat "missing" as CPU.
if args.gpu is None or args.gpu < 0:
    cuda = False
else:
    cuda = True
    th.cuda.set_device(args.gpu)

g = nx.barabasi_albert_graph(args.n_nodes, 1).to_directed()  # TODO SBM
# Random labels: this script only checks that training runs, not accuracy.
y = th.multinomial(th.ones(args.n_classes), args.n_nodes, replacement=True)

network = gnn.GNN(args.features, args.order, args.radius, args.n_classes)
if cuda:
    network.cuda()
ce = nn.CrossEntropyLoss()
adam = optim.Adam(network.parameters())

for i in range(args.n_iterations):
    y_bar = network(g)
    loss = ce(y_bar, y)
    adam.zero_grad()
    loss.backward()
    adam.step()
    # .item() extracts the Python scalar instead of formatting a tensor.
    print('[iteration %d]loss %f' % (i, loss.item()))
from nltk.tree import Tree
from nltk.corpus.reader import BracketParseCorpusReader as CorpusReader
import networkx as nx
import torch as th
class nx_BCT_Reader:
    # Binary Constituency Tree constructor for networkx
    def __init__(self, cuda=False,
                 fnames=['trees/train.txt', 'trees/dev.txt', 'trees/test.txt']):
        # fnames must be three items which means the file path of train, validation, test set, respectively.
        self.corpus = CorpusReader('.', fnames)
        self.train = self.corpus.parsed_sents(fnames[0])
        self.dev = self.corpus.parsed_sents(fnames[1])
        self.test = self.corpus.parsed_sents(fnames[2])
        # Word -> id mapping, built from the training split only.
        self.vocab = {}
        def _rec(node):
            for child in node:
                if isinstance(child[0], str) and child[0] not in self.vocab:
                    self.vocab[child[0]] = len(self.vocab)
                elif isinstance(child, Tree):
                    _rec(child)
        for sent in self.train:
            _rec(sent)
        # Fallback id for out-of-vocabulary words. Vocabulary ids run from 0
        # to len(vocab) - 1 and the embedding table downstream is built with
        # len(vocab) + 1 rows, so len(vocab) is the only free in-range index.
        # (The previous value, len(vocab) + 1, was out of range.)
        self.default = len(self.vocab)
        self.LongTensor = th.cuda.LongTensor if cuda else th.LongTensor
        self.FloatTensor = th.cuda.FloatTensor if cuda else th.FloatTensor

    def create_BCT(self, root):
        """Convert an nltk Tree into a networkx DiGraph with child -> parent edges.

        Leaf nodes carry x = word-id LongTensor (y=None); internal nodes carry
        a one-hot 5-class label y (x=None); the artificial root '0_0' carries
        neither. Node ids are '<counter>_<depth>'.
        """
        self.node_cnt = 0
        self.G = nx.DiGraph()
        def _rec(node, nx_node, depth=0):
            for child in node:
                node_id = str(self.node_cnt) + '_' + str(depth + 1)
                self.node_cnt += 1
                # if isinstance(child[0], str) or isinstance(child[0], unicode):
                if isinstance(child[0], str):
                    word = self.LongTensor([self.vocab.get(child[0], self.default)])
                    self.G.add_node(node_id, x=word, y=None)
                else:
                    label = self.FloatTensor([[0] * 5])
                    label[0, int(child.label())] = 1
                    self.G.add_node(node_id, x=None, y=label)
                if isinstance(child, Tree):  # check illegal trees
                    # Pass depth + 1 so node ids encode the true depth; the
                    # old call reused the default depth=0 and labelled every
                    # node as depth 1.
                    _rec(child, node_id, depth + 1)
                self.G.add_edge(node_id, nx_node)
        self.G.add_node('0_0', x=None, y=None)  # add root into nx Graph
        _rec(root, '0_0')
        return self.G

    def generator(self, mode='train'):
        """Yield one constituency-tree DiGraph per sentence of the split."""
        assert mode in ['train', 'dev', 'test']
        for s in self.__dict__[mode]:
            yield self.create_BCT(s)
import argparse
import torch as th
import torch.optim as optim
import nx_SST
import tree_lstm
# Command-line interface for Tree-LSTM sentiment training on SST.
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=25)
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--h-size', type=int, default=512)
parser.add_argument('--log-every', type=int, default=1)
parser.add_argument('--lr', type=float, default=0.05)
parser.add_argument('--n-ary', type=int, default=2)
parser.add_argument('--n-iterations', type=int, default=1000)
parser.add_argument('--weight-decay', type=float, default=1e-4)
parser.add_argument('--x-size', type=int, default=256)
args = parser.parse_args()

if args.gpu < 0:
    cuda = False
else:
    cuda = True
    th.cuda.set_device(args.gpu)

reader = nx_SST.nx_BCT_Reader(cuda)
loader = reader.generator()

# +1 embedding row for the out-of-vocabulary id.
network = tree_lstm.NAryTreeLSTM(len(reader.vocab) + 1,
                                 args.x_size, args.h_size, args.n_ary, 5)
if cuda:
    network.cuda()
# Fix: --weight-decay was parsed but silently ignored; pass it through.
adagrad = optim.Adagrad(network.parameters(), args.lr,
                        weight_decay=args.weight_decay)

for i in range(args.n_iterations):
    nll = 0
    for j in range(args.batch_size):
        try:
            g = next(loader)
        except StopIteration:
            # The generator walks the training split exactly once; restart it
            # so long runs no longer crash after the final sentence.
            loader = reader.generator()
            g = next(loader)
        nll += network(g, train=True)
    nll /= args.batch_size
    adagrad.zero_grad()
    nll.backward()
    adagrad.step()
    if (i + 1) % args.log_every == 0:
        print('[iteration %d]cross-entropy loss: %f' % ((i + 1), nll))
"""
Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
https://arxiv.org/abs/1503.00075
"""
import itertools
import networkx as nx
import dgl.graph as G
import torch as th
import torch.nn as nn
import torch.nn.functional as F
class TreeLSTM(nn.Module):
    """Base class for Tree-LSTMs (Tai et al., 2015) over constituency trees.

    Subclasses provide the cell parameters and `internal_update_func`.
    """

    def __init__(self, n_embeddings, x_size, h_size, n_classes):
        super().__init__()
        self.embedding = nn.Embedding(n_embeddings, x_size)
        self.linear = nn.Linear(h_size, n_classes)

    @staticmethod
    def message_func(src, trg, _):
        # Children send hidden and cell states to their parent; .get() yields
        # None for nodes that have not been updated yet.
        return {'h' : src.get('h'), 'c' : src.get('c')}

    def leaf_update_func(self, node_reprs, edge_reprs):
        # Leaf cell: no children, so i/o/u depend on the word embedding only.
        x = node_reprs['x']
        iou = th.mm(x, self.iou_w) + self.iou_b
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        c = i * u
        h = o * th.tanh(c)
        return {'h' : h, 'c' : c}

    def internal_update_func(self, node_reprs, edge_reprs):
        # Cell update for internal nodes; implemented by subclasses.
        raise NotImplementedError()

    def readout_func(self, g, train):
        """Training: mean cross-entropy over labelled nodes. Eval: write each
        node's predicted class into reprs['y_bar'] (returns None)."""
        if train:
            h = th.cat([d['h'] for d in g.nodes.values() if d['y'] is not None], 0)
            y = th.cat([d['y'] for d in g.nodes.values() if d['y'] is not None], 0)
            log_p = F.log_softmax(self.linear(h), 1)
            # y is one-hot, so this is the standard cross-entropy.
            return -th.sum(y * log_p) / len(y)
        else:
            h = th.cat([reprs['h'] for reprs in g.nodes.values()], 0)
            y_bar = th.max(self.linear(h), 1)[1]
            # Fix: the original referenced an undefined name `labelled_reprs`
            # here (NameError). Predictions are stored on every node, in the
            # same order h was assembled from g.nodes.values().
            for reprs, z in zip(g.nodes.values(), y_bar):
                reprs['y_bar'] = z.item()

    def forward(self, g, train=False):
        """
        Parameters
        ----------
        g : networkx.DiGraph
            Constituency tree with edges directed child -> parent.
        train : bool
            If True, return the loss; otherwise annotate nodes with
            predictions and return None.
        """
        assert any(d['y'] is not None for d in g.nodes.values())  # TODO
        g = G.DGLGraph(g)
        def update_func(node_reprs, edge_reprs):
            # Replace word ids with their embeddings (leaf nodes only).
            node_reprs = node_reprs.copy()
            if node_reprs['x'] is not None:
                node_reprs['x'] = self.embedding(node_reprs['x'])
            return node_reprs
        g.register_message_func(self.message_func)
        g.register_update_func(update_func, g.nodes)
        g.update_all()
        g.register_update_func(self.internal_update_func, g.nodes)
        leaves = list(filter(lambda x: g.in_degree(x) == 0, g.nodes))
        g.register_update_func(self.leaf_update_func, leaves)
        # Build a level-by-level schedule from the root downward, then
        # propagate it reversed so children always precede their parents.
        iterator = []
        frontier = [next(filter(lambda x: g.out_degree(x) == 0, g.nodes))]
        while frontier:
            src = sum([list(g.pred[x]) for x in frontier], [])
            trg = sum([[x] * len(g.pred[x]) for x in frontier], [])
            iterator.append((src, trg))
            frontier = src
        g.recvfrom(leaves)
        g.propagate(reversed(iterator))
        return self.readout_func(g, train)
class ChildSumTreeLSTM(TreeLSTM):
    """Child-Sum Tree-LSTM: children's hidden states are summed, with one
    forget gate per child (eqs. 2-8 of the paper)."""

    def __init__(self, n_embeddings, x_size, h_size, n_classes):
        super().__init__(n_embeddings, x_size, h_size, n_classes)
        self.iou_w = nn.Parameter(th.randn(x_size, 3 * h_size))  # TODO initializer
        self.iou_u = nn.Parameter(th.randn(h_size, 3 * h_size))  # TODO initializer
        self.iou_b = nn.Parameter(th.zeros(1, 3 * h_size))
        self.f_x = nn.Parameter(th.randn(x_size, h_size))  # TODO initializer
        self.f_h = nn.Parameter(th.randn(h_size, h_size))  # TODO initializer
        self.f_b = nn.Parameter(th.zeros(1, h_size))

    def internal_update_func(self, node_reprs, edge_reprs):
        # node_reprs['x'] is None for internal nodes without a word.
        x = node_reprs['x']
        h_bar = sum(msg['h'] for msg in edge_reprs)
        if x is None:
            iou = th.mm(h_bar, self.iou_u) + self.iou_b
        else:
            iou = th.mm(x, self.iou_w) + th.mm(h_bar, self.iou_u) + self.iou_b
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        # One forget gate per child; all share the same input contribution.
        wx = th.mm(x, self.f_x).repeat(len(edge_reprs), 1) if x is not None else 0
        uh = th.mm(th.cat([msg['h'] for msg in edge_reprs], 0), self.f_h)
        # Fix: `f_b` was referenced without `self.`, raising a NameError.
        f = th.sigmoid(wx + uh + self.f_b)
        c = th.cat([msg['c'] for msg in edge_reprs], 0)
        c = i * u + th.sum(f * c, 0)
        h = o * th.tanh(c)
        return {'h' : h, 'c' : c}
class NAryTreeLSTM(TreeLSTM):
    """N-ary Tree-LSTM: separate parameters per child position (eqs. 9-14)."""

    def __init__(self, n_embeddings, x_size, h_size, n_ary, n_classes):
        super().__init__(n_embeddings, x_size, h_size, n_classes)
        # TODO initializer
        self.iou_w = nn.Parameter(th.randn(x_size, 3 * h_size))
        # Per-position U matrices. Fix: plain Python lists of Parameters are
        # invisible to nn.Module -- they were skipped by .parameters() (so the
        # optimizer never updated them) and by .cuda(). Register each one
        # explicitly; plain-list indexing/slicing used by the forward pass
        # is preserved.
        self.iou_u = [nn.Parameter(th.randn(1, h_size, 3 * h_size)) for i in range(n_ary)]
        for idx, p in enumerate(self.iou_u):
            self.register_parameter('iou_u_%d' % idx, p)
        self.iou_b = nn.Parameter(th.zeros(1, 3 * h_size))
        # TODO initializer
        self.f_x = nn.Parameter(th.randn(x_size, h_size))
        # n_ary x n_ary forget-gate matrices, registered like iou_u above
        # (the inner loop variable is renamed j -- it previously shadowed i).
        self.f_h = [[nn.Parameter(th.randn(1, h_size, h_size))
                     for j in range(n_ary)] for i in range(n_ary)]
        for i, row in enumerate(self.f_h):
            for j, p in enumerate(row):
                self.register_parameter('f_h_%d_%d' % (i, j), p)
        self.f_b = nn.Parameter(th.zeros(1, h_size))

    def internal_update_func(self, node_reprs, edge_reprs):
        assert len(edge_reprs) > 0
        assert all(msg['h'] is not None and msg['c'] is not None for msg in edge_reprs)
        # node_reprs['x'] is None for internal nodes without a word.
        x = node_reprs['x']
        n_children = len(edge_reprs)
        iou_wx = th.mm(x, self.iou_w) if x is not None else 0
        # Batch the per-position U matrices for the children present.
        iou_u = th.cat(self.iou_u[:n_children], 0)
        iou_h = th.cat([msg['h'] for msg in edge_reprs], 0).unsqueeze(1)
        iou_uh = th.sum(th.bmm(iou_h, iou_u), 0)
        i, o, u = th.chunk(iou_wx + iou_uh + self.iou_b, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        f_wx = th.mm(x, self.f_x).repeat(n_children, 1) if x is not None else 0
        # Pair every child's h with every (position, position) forget matrix.
        f_h = iou_h.repeat(n_children, 1, 1)
        f_u = th.cat(sum([self.f_h[i][:n_children] for i in range(n_children)], []), 0)
        f_uh = th.sum(th.bmm(f_h, f_u).view(n_children, n_children, -1), 0)
        f = th.sigmoid(f_wx + f_uh + self.f_b)
        c = th.cat([msg['c'] for msg in edge_reprs], 0)
        c = i * u + th.sum(f * c, 0, keepdim=True)
        h = o * th.tanh(c)
        return {'h' : h, 'c' : c}
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
This diff is collapsed.
......@@ -143,7 +143,7 @@ class DGLGraph(DiGraph):
The update function should be compatible with following signature:
(edge_reprs, node_reprs) -> node_reprs
(node_reprs, edge_reprs) -> node_reprs
It computes the new node representations using the representations
of the in-coming edges (the same concept as messages) and the node
......@@ -167,7 +167,7 @@ class DGLGraph(DiGraph):
>>> g.register_update_func(ufunc)
Register for a specific node.
>>> g.register_update_func(ufunc, u)
>>> g.register_update_func(ufunc, u) # TODO Not implemented
Register for multiple nodes.
>>> u = [u1, u2, u3, ...]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment