Commit b7eb1659 authored by Gan Quan, committed by Minjie Wang

Fix 0deg update_to and Tree-LSTM model (#51)

* WIP

* WIP

* treelstm dataloader

* Main training loop.

* trainable treelstm script

* fix dependency

* cuda training

* Add tensorized topological traversal

* allowing update_to() with no incoming messages

* fixing partial cases
parent 5e75f5db
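
The headline fix is in `update_to()`: a frontier may now contain nodes with no incoming edges, in which case the reduce phase is skipped and the update function receives `accum=None` (see `ChildSumTreeLSTMCell.update_func` in the tree_lstm.py diff below). A minimal sketch of the new contract, assuming the commit-era API that appears in this diff (`register_*_func(..., batchable=True)`, `set_n_repr`, `update_to`); the driver code itself is hypothetical:

import torch as th
import dgl.graph as G

g = G.DGLGraph()
g.add_edge(0, 1)  # node 0 is a leaf (in-degree 0) feeding node 1
g.set_n_repr({'h' : th.ones((2, 4))})
g.register_message_func(lambda src, edge: src, batchable=True)
g.register_reduce_func(lambda node, msgs: {'a' : th.sum(msgs['h'], 1)},
                       batchable=True)

def update_func(node, accum):
    # accum is None when no message arrived: the 0-degree case this
    # commit makes legal instead of an error.
    return {'h' : node['h'] if accum is None else node['h'] + accum['a']}

g.register_update_func(update_func, batchable=True)
g.update_to([0])  # leaf frontier: update runs with accum=None
g.update_to([1])  # root frontier: the message from node 0 is reduced first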
from nltk.tree import Tree
from nltk.corpus.reader import BracketParseCorpusReader as CorpusReader
import networkx as nx
import torch as th
class nx_BCT_Reader:
    # Binary Constituency Tree constructor for networkx
    def __init__(self, cuda=False,
                 fnames=['trees/train.txt', 'trees/dev.txt', 'trees/test.txt']):
        # fnames must contain three file paths: the train, validation and
        # test sets, respectively.
        self.corpus = CorpusReader('.', fnames)
        self.train = self.corpus.parsed_sents(fnames[0])
        self.dev = self.corpus.parsed_sents(fnames[1])
        self.test = self.corpus.parsed_sents(fnames[2])
        self.vocab = {}

        def _rec(node):
            for child in node:
                if isinstance(child[0], str) and child[0] not in self.vocab:
                    self.vocab[child[0]] = len(self.vocab)
                elif isinstance(child, Tree):
                    _rec(child)

        for sent in self.train:
            _rec(sent)
        # index reserved for out-of-vocabulary words; the embedding table
        # must therefore have len(self.vocab) + 1 rows
        self.default = len(self.vocab)
        self.LongTensor = th.cuda.LongTensor if cuda else th.LongTensor
        self.FloatTensor = th.cuda.FloatTensor if cuda else th.FloatTensor

    def create_BCT(self, root):
        self.node_cnt = 0
        self.G = nx.DiGraph()

        def _rec(node, nx_node, depth=0):
            for child in node:
                node_id = str(self.node_cnt) + '_' + str(depth + 1)
                self.node_cnt += 1
                # if isinstance(child[0], str) or isinstance(child[0], unicode):
                if isinstance(child[0], str):
                    # leaf node: store the word index
                    word = self.LongTensor([self.vocab.get(child[0], self.default)])
                    self.G.add_node(node_id, x=word, y=None)
                else:
                    # internal node: store the one-hot sentiment label (5 classes)
                    label = self.FloatTensor([[0] * 5])
                    label[0, int(child.label())] = 1
                    self.G.add_node(node_id, x=None, y=label)
                if isinstance(child, Tree):  # check illegal trees
                    _rec(child, node_id, depth + 1)
                self.G.add_edge(node_id, nx_node)

        self.G.add_node('0_0', x=None, y=None)  # add root into nx Graph
        _rec(root, '0_0')
        return self.G

    def generator(self, mode='train'):
        assert mode in ['train', 'dev', 'test']
        for s in self.__dict__[mode]:
            yield self.create_BCT(s)
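
A usage sketch for the reader above (not part of the commit; it assumes the SST bracketed-tree files exist at the paths in `fnames`):

reader = nx_BCT_Reader(cuda=False)
for g in reader.generator('train'):
    # each tree is a networkx.DiGraph with edges pointing child -> parent
    print(g.number_of_nodes(), g.number_of_edges())
    break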
import argparse
import torch as th
import torch.optim as optim
import nx_SST
import tree_lstm

parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=25)
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--h-size', type=int, default=512)
parser.add_argument('--log-every', type=int, default=1)
parser.add_argument('--lr', type=float, default=0.05)
parser.add_argument('--n-ary', type=int, default=2)
parser.add_argument('--n-iterations', type=int, default=1000)
parser.add_argument('--weight-decay', type=float, default=1e-4)
parser.add_argument('--x-size', type=int, default=256)
args = parser.parse_args()

if args.gpu < 0:
    cuda = False
else:
    cuda = True
    th.cuda.set_device(args.gpu)

reader = nx_SST.nx_BCT_Reader(cuda)
loader = reader.generator()
network = tree_lstm.NAryTreeLSTM(len(reader.vocab) + 1,
                                 args.x_size, args.h_size, args.n_ary, 5)
if cuda:
    network.cuda()
adagrad = optim.Adagrad(network.parameters(), args.lr)

for i in range(args.n_iterations):
    nll = 0
    for j in range(args.batch_size):
        g = next(loader)
        nll += network(g, train=True)
    nll /= args.batch_size
    adagrad.zero_grad()
    nll.backward()
    adagrad.step()
    if (i + 1) % args.log_every == 0:
        print('[iteration %d] cross-entropy loss: %f' % (i + 1, nll))
import argparse
import time
import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import dgl
import dgl.context as ctx
import dgl.data as data
from tree_lstm import TreeLSTM

def _batch_to_cuda(batch):
    return data.SSTBatch(graph=batch.graph,
                         nid_with_word=batch.nid_with_word.cuda(),
                         wordid=batch.wordid.cuda(),
                         label=batch.label.cuda())

def tensor_topo_traverse(g, cuda, args):
    """Yield the frontiers of g in topological order, as index tensors."""
    n = g.number_of_nodes()
    if cuda:
        adjmat = g.cached_graph.adjmat(ctx.gpu(args.gpu))
        mask = th.ones((n, 1)).cuda()
    else:
        adjmat = g.cached_graph.adjmat(ctx.cpu())
        mask = th.ones((n, 1))
    degree = th.spmm(adjmat, mask)
    while th.sum(mask) != 0.:
        v = (degree == 0.).float()
        v = v * mask
        mask = mask - v
        frontier = th.squeeze(th.squeeze(v).nonzero(), 1)
        yield frontier
        degree -= th.spmm(adjmat, v)
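
# --- Illustrative sketch (not part of the commit) -------------------------
# tensor_topo_traverse above replaces the Python-level topological sort with
# sparse matrix-vector products: `degree` holds the in-degree of every
# not-yet-processed node, and each step emits the nodes whose degree hit 0.
# The same arithmetic on a dense 3-node toy tree (leaves 0, 1 -> root 2):
def _demo_tensor_traverse():
    # adjmat[v][u] == 1 iff there is an edge u -> v (child feeds parent)
    adjmat = th.tensor([[0., 0., 0.],
                        [0., 0., 0.],
                        [1., 1., 0.]])
    mask = th.ones((3, 1))
    degree = adjmat.mm(mask)               # in-degrees: [0, 0, 2]
    while th.sum(mask) != 0.:
        v = (degree == 0.).float() * mask  # unprocessed 0-degree nodes
        mask = mask - v
        yield th.squeeze(v).nonzero().view(-1)
        degree -= adjmat.mm(v)             # remove processed nodes' edges
# list(_demo_tensor_traverse()) yields tensors [0, 1] and then [2].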
def main(args):
    cuda = args.gpu >= 0
    if cuda:
        th.cuda.set_device(args.gpu)

    trainset = data.SST()
    train_loader = DataLoader(dataset=trainset,
                              batch_size=args.batch_size,
                              collate_fn=data.SST.batcher,
                              shuffle=False,
                              num_workers=0)
    #testset = data.SST(mode='test')
    #test_loader = DataLoader(dataset=testset,
    #                         batch_size=100,
    #                         collate_fn=data.SST.batcher,
    #                         shuffle=False,
    #                         num_workers=0)

    model = TreeLSTM(trainset.num_vocabs,
                     args.x_size,
                     args.h_size,
                     trainset.num_classes,
                     args.dropout)
    if cuda:
        model.cuda()
    print(model)

    optimizer = optim.Adagrad(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)

    dur = []
    for epoch in range(args.epochs):
        t_epoch = time.time()
        for step, batch in enumerate(train_loader):
            if cuda:
                batch = _batch_to_cuda(batch)
            g = batch.graph
            n = g.number_of_nodes()
            # initial node input and states; move to GPU only when requested
            x = th.zeros((n, args.x_size))
            h = th.zeros((n, args.h_size))
            c = th.zeros((n, args.h_size))
            if cuda:
                x, h, c = x.cuda(), h.cuda(), c.cuda()
            if step >= 3:
                t0 = time.time()
            # traverse graph (frontier indices are computed on CPU)
            giter = list(tensor_topo_traverse(g, False, args))
            logits = model(batch, x, h, c, iterator=giter, train=True)
            logp = F.log_softmax(logits, 1)
            loss = F.nll_loss(logp, batch.label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step >= 3:
                dur.append(time.time() - t0)
            if step > 0 and step % args.log_every == 0:
                pred = th.argmax(logits, 1)
                acc = th.sum(th.eq(batch.label, pred))
                mean_dur = np.mean(dur)
                print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
                      "Acc {:.4f} | Time(s) {:.4f} | Trees/s {:.4f}".format(
                      epoch, step, loss.item(), acc.item() / len(batch.label),
                      mean_dur, args.batch_size / mean_dur))
        print("Epoch time(s):", time.time() - t_epoch)

    # test
    #for step, batch in enumerate(test_loader):
    #    g = batch.graph
    #    n = g.number_of_nodes()
    #    x = th.zeros((n, args.x_size))
    #    h = th.zeros((n, args.h_size))
    #    c = th.zeros((n, args.h_size))
    #    logits = model(batch, x, h, c, train=True)
    #    pred = th.argmax(logits, 1)
    #    acc = th.sum(th.eq(batch.label, pred)) / len(batch.label)
    #    print(acc.item())

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--batch-size', type=int, default=25)
    parser.add_argument('--x-size', type=int, default=256)
    parser.add_argument('--h-size', type=int, default=256)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--log-every', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.05)
    parser.add_argument('--n-ary', type=int, default=2)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--dropout', type=float, default=0.5)
    args = parser.parse_args()
    main(args)
@@ -2,123 +2,136 @@
Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
https://arxiv.org/abs/1503.00075
"""
import time
import itertools
import networkx as nx
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

def topological_traverse(G):
    """Yield the frontiers of G (a DAG) in topological order."""
    indegree_map = {v: d for v, d in G.in_degree() if d > 0}
    # These nodes have zero in-degree and are ready to be returned.
    zero_indegree = [v for v, d in G.in_degree() if d == 0]
    while True:
        yield zero_indegree
        next_zero_indegree = []
        while zero_indegree:
            node = zero_indegree.pop()
            for _, child in G.edges(node):
                indegree_map[child] -= 1
                if indegree_map[child] == 0:
                    next_zero_indegree.append(child)
                    del indegree_map[child]
        if len(next_zero_indegree) == 0:
            break
        zero_indegree = next_zero_indegree

class ChildSumTreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(ChildSumTreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size)
        self.U_iou = nn.Linear(h_size, 3 * h_size)
        self.W_f = nn.Linear(x_size, h_size)
        self.U_f = nn.Linear(h_size, h_size)
        self.rt = 0.
        self.ut = 0.

    def message_func(self, src, edge):
        return src

    def reduce_func(self, node, msgs):
        # equation (2)
        h_tild = th.sum(msgs['h'], 1)
        # equation (4)
        wx = self.W_f(node['x']).unsqueeze(1)  # shape: (B, 1, H)
        uh = self.U_f(msgs['h'])               # shape: (B, deg, H)
        f = th.sigmoid(wx + uh)                # shape: (B, deg, H)
        # equation (7), second term
        c_tild = th.sum(f * msgs['c'], 1)
        return {'h_tild' : h_tild, 'c_tild' : c_tild}

    def update_func(self, node, accum):
        # equations (3), (5), (6); accum is None when the node received no
        # incoming messages (the 0-degree case this commit fixes)
        if accum is None:
            iou = self.W_iou(node['x'])
        else:
            iou = self.W_iou(node['x']) + self.U_iou(accum['h_tild'])
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        # equation (7)
        if accum is None:
            c = i * u
        else:
            c = i * u + accum['c_tild']
        # equation (8)
        h = o * th.tanh(c)
        return {'h' : h, 'c' : c}

class TreeLSTM(nn.Module):
    def __init__(self,
                 num_vocabs,
                 x_size,
                 h_size,
                 num_classes,
                 dropout,
                 cell_type='childsum'):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        # TODO(minjie): pre-trained embedding like GLoVe
        self.embedding = nn.Embedding(num_vocabs, x_size)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(h_size, num_classes)
        if cell_type == 'childsum':
            self.cell = ChildSumTreeLSTMCell(x_size, h_size)
        else:
            raise RuntimeError('Unknown cell type:', cell_type)

    def forward(self, batch, x, h, c, iterator=None, train=True):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch.
        x : Tensor
            Initial node input.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.
        iterator : graph iterator
            External iterator on graph.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        g = batch.graph
        g.register_message_func(self.cell.message_func, batchable=True)
        g.register_reduce_func(self.cell.reduce_func, batchable=True)
        g.register_update_func(self.cell.update_func, batchable=True)
        # feed embedding
        embeds = self.embedding(batch.wordid)
        x = x.index_copy(0, batch.nid_with_word, embeds)
        g.set_n_repr({'x' : x, 'h' : h, 'c' : c})
        # TODO(minjie): potential bottleneck
        if iterator is None:
            for frontier in topological_traverse(g):
                #print('frontier', frontier)
                g.update_to(frontier)
        else:
            for frontier in iterator:
                g.update_to(frontier)
        # compute logits
        h = g.pop_n_repr('h')
        h = self.dropout(h)
        logits = self.linear(h)
        return logits

'''
class NAryTreeLSTM(TreeLSTM):
    def __init__(self, n_embeddings, x_size, h_size, n_ary, n_classes):
        super().__init__(n_embeddings, x_size, h_size, n_classes)
@@ -159,3 +172,4 @@ class NAryTreeLSTM(TreeLSTM):
        h = o * th.tanh(c)
        return {'h' : h, 'c' : c}
'''
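
For reference, the numbered equations that `ChildSumTreeLSTMCell` cites are the Child-Sum Tree-LSTM of Tai et al. (2015), where $C(j)$ is the set of children of node $j$:

$\tilde{h}_j = \sum_{k \in C(j)} h_k$  (2)
$i_j = \sigma\big(W^{(i)} x_j + U^{(i)} \tilde{h}_j + b^{(i)}\big)$  (3)
$f_{jk} = \sigma\big(W^{(f)} x_j + U^{(f)} h_k + b^{(f)}\big)$  (4)
$o_j = \sigma\big(W^{(o)} x_j + U^{(o)} \tilde{h}_j + b^{(o)}\big)$  (5)
$u_j = \tanh\big(W^{(u)} x_j + U^{(u)} \tilde{h}_j + b^{(u)}\big)$  (6)
$c_j = i_j \odot u_j + \sum_{k \in C(j)} f_{jk} \odot c_k$  (7)
$h_j = o_j \odot \tanh(c_j)$  (8)

`W_iou`/`U_iou` stack the $i$, $o$, $u$ projections into one linear layer, while `reduce_func` computes $\tilde{h}_j$ and the $\sum_k f_{jk} \odot c_k$ term.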
This source diff could not be displayed because it is too large. You can view the blob instead.
"""Classes and functions for batching multiple graphs together."""
from __future__ import absolute_import
import numpy as np
from dgl.graph import DGLGraph
import dgl.backend as F
import dgl
import numpy as np
class BatchedDGLGraph(DGLGraph):
def __init__(self, graph_list, node_attrs=None, edge_attrs=None, **attr):
......
@@ -39,26 +39,68 @@ class CachedGraph:
        return utils.toindex(eids)

    def in_edges(self, v):
        """Get in-edges of the given vertices.

        Parameters
        ----------
        v : utils.Index
            The vertex ids.

        Returns
        -------
        src : utils.Index
            The src vertex ids.
        dst : utils.Index
            The dst vertex ids.
        orphan : utils.Index
            The vertices that have no in-edges.
        """
        src = []
        dst = []
        orphan = []
        for vv in utils.node_iter(v):
            uu = self._graph.predecessors(vv)
            if len(uu) == 0:
                orphan.append(vv)
            else:
                src += uu
                dst += [vv] * len(uu)
        src = utils.toindex(src)
        dst = utils.toindex(dst)
        orphan = utils.toindex(orphan)
        return src, dst, orphan
    def out_edges(self, u):
        """Get out-edges of the given vertices.

        Parameters
        ----------
        u : utils.Index
            The vertex ids.

        Returns
        -------
        src : utils.Index
            The src vertex ids.
        dst : utils.Index
            The dst vertex ids.
        orphan : utils.Index
            The vertices that have no out-edges.
        """
        src = []
        dst = []
        orphan = []
        for uu in utils.node_iter(u):
            vv = self._graph.successors(uu)
            if len(vv) == 0:
                orphan.append(uu)
            else:
                src += [uu] * len(vv)
                dst += vv
        src = utils.toindex(src)
        dst = utils.toindex(dst)
        orphan = utils.toindex(orphan)
        return src, dst, orphan
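
    # Caller-side sketch (illustrative, not part of this diff): with the new
    # three-tuple return, nodes without in-edges are reported instead of
    # raising, so schedulers can update them with accum=None, e.g.
    #   src, dst, orphan = g.cached_graph.in_edges(utils.toindex([0, 1, 2]))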
    def in_degrees(self, v):
        degs = self._graph.indegree(list(v))
......
@@ -2,6 +2,7 @@
from __future__ import absolute_import
from . import citation_graph as citegrh
from .tree import *
from .utils import *
def register_data_args(parser):
......