Commit b7eb1659 authored by Gan Quan, committed by Minjie Wang

Fix 0deg update_to and Tree-LSTM model (#51)

* WIP

* WIP

* treelstm dataloader

* Main training loop.

* trainable treelstm script

* fix dependency

* cuda training

* Add tensorized topological traversal

* allowing update_to() with no incoming messages

* fixing partial cases
parent 5e75f5db
from nltk.tree import Tree
from nltk.corpus.reader import BracketParseCorpusReader as CorpusReader
import networkx as nx
import torch as th


class nx_BCT_Reader:
    # Binary Constituency Tree reader that builds networkx graphs.
    def __init__(self, cuda=False,
                 fnames=['trees/train.txt', 'trees/dev.txt', 'trees/test.txt']):
        # fnames must contain three items: the file paths of the train,
        # validation, and test sets, respectively.
        self.corpus = CorpusReader('.', fnames)
        self.train = self.corpus.parsed_sents(fnames[0])
        self.dev = self.corpus.parsed_sents(fnames[1])
        self.test = self.corpus.parsed_sents(fnames[2])
        self.vocab = {}

        def _rec(node):
            for child in node:
                if isinstance(child[0], str):
                    # Preterminal: record the word; do not recurse further,
                    # otherwise we would iterate over the characters of the word.
                    if child[0] not in self.vocab:
                        self.vocab[child[0]] = len(self.vocab)
                elif isinstance(child, Tree):
                    _rec(child)

        for sent in self.train:
            _rec(sent)
        # Index for out-of-vocabulary words; it must stay within the embedding
        # table, whose size is len(self.vocab) + 1, so the highest valid index
        # is len(self.vocab).
        self.default = len(self.vocab)
        self.LongTensor = th.cuda.LongTensor if cuda else th.LongTensor
        self.FloatTensor = th.cuda.FloatTensor if cuda else th.FloatTensor

    def create_BCT(self, root):
        self.node_cnt = 0
        self.G = nx.DiGraph()

        def _rec(node, nx_node, depth=0):
            for child in node:
                node_id = str(self.node_cnt) + '_' + str(depth + 1)
                self.node_cnt += 1
                if isinstance(child[0], str):
                    # Leaf: store the word id; no sentiment label.
                    word = self.LongTensor([self.vocab.get(child[0], self.default)])
                    self.G.add_node(node_id, x=word, y=None)
                else:
                    # Internal node: one-hot sentiment label over 5 classes.
                    label = self.FloatTensor([[0] * 5])
                    label[0, int(child.label())] = 1
                    self.G.add_node(node_id, x=None, y=label)
                    if isinstance(child, Tree):  # guard against illegal trees
                        _rec(child, node_id, depth + 1)
                self.G.add_edge(node_id, nx_node)

        self.G.add_node('0_0', x=None, y=None)  # add root into the nx graph
        _rec(root, '0_0')
        return self.G

    def generator(self, mode='train'):
        assert mode in ['train', 'dev', 'test']
        for s in self.__dict__[mode]:
            yield self.create_BCT(s)
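
For reference, a minimal way to exercise the reader above (a hedged sketch: it assumes the SST bracket files sit under ./trees/ as in the default fnames; the smoke test itself is not part of the commit):

if __name__ == '__main__':
    # Build the vocabulary from the training split, then materialize one tree.
    reader = nx_BCT_Reader(cuda=False)
    g = next(reader.generator('train'))  # one networkx.DiGraph per sentence
    print(g.number_of_nodes(), g.number_of_edges())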
import argparse

import torch as th
import torch.optim as optim

import nx_SST
import tree_lstm

parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, default=25)
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--h-size', type=int, default=512)
parser.add_argument('--log-every', type=int, default=1)
parser.add_argument('--lr', type=float, default=0.05)
parser.add_argument('--n-ary', type=int, default=2)
parser.add_argument('--n-iterations', type=int, default=1000)
parser.add_argument('--weight-decay', type=float, default=1e-4)
parser.add_argument('--x-size', type=int, default=256)
args = parser.parse_args()

if args.gpu < 0:
    cuda = False
else:
    cuda = True
    th.cuda.set_device(args.gpu)

reader = nx_SST.nx_BCT_Reader(cuda)
loader = reader.generator()
network = tree_lstm.NAryTreeLSTM(len(reader.vocab) + 1,
                                 args.x_size, args.h_size, args.n_ary, 5)
if cuda:
    network.cuda()
adagrad = optim.Adagrad(network.parameters(), args.lr,
                        weight_decay=args.weight_decay)

for i in range(args.n_iterations):
    nll = 0
    for j in range(args.batch_size):
        try:
            g = next(loader)
        except StopIteration:
            # The generator is exhausted after one pass; restart it.
            loader = reader.generator()
            g = next(loader)
        nll += network(g, train=True)
    nll /= args.batch_size
    adagrad.zero_grad()
    nll.backward()
    adagrad.step()
    if (i + 1) % args.log_every == 0:
        print('[iteration %d] cross-entropy loss: %f' % (i + 1, nll.item()))
import argparse
import time

import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import dgl
import dgl.context as ctx
import dgl.data as data
from tree_lstm import TreeLSTM


def _batch_to_cuda(batch):
    return data.SSTBatch(graph=batch.graph,
                         nid_with_word=batch.nid_with_word.cuda(),
                         wordid=batch.wordid.cuda(),
                         label=batch.label.cuda())


def tensor_topo_traverse(g, cuda, args):
    """Yield bottom-up frontiers using sparse matrix multiplication."""
    n = g.number_of_nodes()
    if cuda:
        adjmat = g.cached_graph.adjmat(ctx.gpu(args.gpu))
        mask = th.ones((n, 1)).cuda()
    else:
        adjmat = g.cached_graph.adjmat(ctx.cpu())
        mask = th.ones((n, 1))
    # Each node's count of unresolved incoming edges.
    degree = th.spmm(adjmat, mask)
    while th.sum(mask) != 0.:
        # Nodes that are ready (no unresolved edges) and not yet visited.
        v = (degree == 0.).float()
        v = v * mask
        mask = mask - v
        frontier = th.squeeze(th.squeeze(v).nonzero(), 1)
        yield frontier
        # Resolve the edges contributed by the new frontier.
        degree -= th.spmm(adjmat, v)
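
To make the traversal above concrete, here is a self-contained sketch of the same frontier computation on a toy three-node chain 0 -> 1 -> 2 (edges point child to parent, matching the trees here). It swaps DGL's cached sparse adjacency for a dense torch tensor; the orientation adjmat[dst, src] = 1 is an assumption inferred from how degree is computed above:

import torch as th

# Assumed orientation: adjmat[dst, src] = 1 means dst still waits on src.
adjmat = th.tensor([[0., 0., 0.],
                    [1., 0., 0.],
                    [0., 1., 0.]])
mask = th.ones((3, 1))            # 1 = node not yet visited
degree = adjmat @ mask            # unresolved in-edges per node: [0, 1, 1]
while th.sum(mask) != 0.:
    v = (degree == 0.).float() * mask   # ready and unvisited
    mask = mask - v
    print(th.squeeze(v, 1).nonzero().flatten().tolist())  # prints [0], [1], [2]
    degree -= adjmat @ v          # resolve edges out of the new frontier

Each printed list is one frontier, exactly what tensor_topo_traverse yields as an index tensor.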
def main(args):
    cuda = args.gpu >= 0
    if cuda:
        th.cuda.set_device(args.gpu)

    trainset = data.SST()
    train_loader = DataLoader(dataset=trainset,
                              batch_size=args.batch_size,
                              collate_fn=data.SST.batcher,
                              shuffle=False,
                              num_workers=0)
    #testset = data.SST(mode='test')
    #test_loader = DataLoader(dataset=testset,
    #                         batch_size=100,
    #                         collate_fn=data.SST.batcher,
    #                         shuffle=False,
    #                         num_workers=0)

    model = TreeLSTM(trainset.num_vocabs,
                     args.x_size,
                     args.h_size,
                     trainset.num_classes,
                     args.dropout)
    if cuda:
        model.cuda()
    print(model)

    optimizer = optim.Adagrad(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)

    dur = []
    for epoch in range(args.epochs):
        t_epoch = time.time()
        for step, batch in enumerate(train_loader):
            if cuda:
                batch = _batch_to_cuda(batch)
            g = batch.graph
            n = g.number_of_nodes()
            x = th.zeros((n, args.x_size))
            h = th.zeros((n, args.h_size))
            c = th.zeros((n, args.h_size))
            if cuda:
                x, h, c = x.cuda(), h.cuda(), c.cuda()
            if step >= 3:
                t0 = time.time()  # skip the first few steps when timing
            # traverse the graph bottom-up (on CPU)
            giter = list(tensor_topo_traverse(g, False, args))
            logits = model(batch, x, h, c, iterator=giter, train=True)
            logp = F.log_softmax(logits, 1)
            loss = F.nll_loss(logp, batch.label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step >= 3:
                dur.append(time.time() - t0)
            if step > 0 and step % args.log_every == 0:
                pred = th.argmax(logits, 1)
                acc = th.sum(th.eq(batch.label, pred))
                mean_dur = np.mean(dur)
                print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
                      "Acc {:.4f} | Time(s) {:.4f} | Trees/s {:.4f}".format(
                          epoch, step, loss.item(), acc.item() / len(batch.label),
                          mean_dur, args.batch_size / mean_dur))
        print("Epoch time(s):", time.time() - t_epoch)

        # test
        #for step, batch in enumerate(test_loader):
        #    g = batch.graph
        #    n = g.number_of_nodes()
        #    x = th.zeros((n, args.x_size))
        #    h = th.zeros((n, args.h_size))
        #    c = th.zeros((n, args.h_size))
        #    logits = model(batch, x, h, c, train=True)
        #    pred = th.argmax(logits, 1)
        #    acc = th.sum(th.eq(batch.label, pred)) / len(batch.label)
        #    print(acc.item())


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--batch-size', type=int, default=25)
    parser.add_argument('--x-size', type=int, default=256)
    parser.add_argument('--h-size', type=int, default=256)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--log-every', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.05)
    parser.add_argument('--n-ary', type=int, default=2)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--dropout', type=float, default=0.5)
    args = parser.parse_args()
    main(args)
@@ -2,123 +2,136 @@
 Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
 https://arxiv.org/abs/1503.00075
 """
+import time
 import itertools
 import networkx as nx
-import dgl.graph as G
+import numpy as np
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F

-class TreeLSTM(nn.Module):
-    def __init__(self, n_embeddings, x_size, h_size, n_classes):
-        super().__init__()
-        self.embedding = nn.Embedding(n_embeddings, x_size)
-        self.linear = nn.Linear(h_size, n_classes)
-
-    @staticmethod
-    def message_func(src, trg, _):
-        return {'h' : src.get('h'), 'c' : src.get('c')}
-
-    def leaf_update_func(self, node_reprs, edge_reprs):
-        x = node_reprs['x']
-        iou = th.mm(x, self.iou_w) + self.iou_b
+def topological_traverse(G):
+    indegree_map = {v: d for v, d in G.in_degree() if d > 0}
+    # These nodes have zero indegree and ready to be returned.
+    zero_indegree = [v for v, d in G.in_degree() if d == 0]
+    while True:
+        yield zero_indegree
+        next_zero_indegree = []
+        while zero_indegree:
+            node = zero_indegree.pop()
+            for _, child in G.edges(node):
+                indegree_map[child] -= 1
+                if indegree_map[child] == 0:
+                    next_zero_indegree.append(child)
+                    del indegree_map[child]
+        if len(next_zero_indegree) == 0:
+            break
+        zero_indegree = next_zero_indegree
+
+class ChildSumTreeLSTMCell(nn.Module):
+    def __init__(self, x_size, h_size):
+        super(ChildSumTreeLSTMCell, self).__init__()
+        self.W_iou = nn.Linear(x_size, 3 * h_size)
+        self.U_iou = nn.Linear(h_size, 3 * h_size)
+        self.W_f = nn.Linear(x_size, h_size)
+        self.U_f = nn.Linear(h_size, h_size)
+        self.rt = 0.
+        self.ut = 0.
+
+    def message_func(self, src, edge):
+        return src
+
+    def reduce_func(self, node, msgs):
+        # equation (2)
+        h_tild = th.sum(msgs['h'], 1)
+        # equation (4)
+        wx = self.W_f(node['x']).unsqueeze(1)  # shape: (B, 1, H)
+        uh = self.U_f(msgs['h'])               # shape: (B, deg, H)
+        f = th.sigmoid(wx + uh)                # shape: (B, deg, H)
+        # equation (7) second term
+        c_tild = th.sum(f * msgs['c'], 1)
+        return {'h_tild' : h_tild, 'c_tild' : c_tild}
+
+    def update_func(self, node, accum):
+        # equation (3), (5), (6)
+        if accum is None:
+            iou = self.W_iou(node['x'])
+        else:
+            iou = self.W_iou(node['x']) + self.U_iou(accum['h_tild'])
         i, o, u = th.chunk(iou, 3, 1)
         i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
-        c = i * u
+        # equation (7)
+        if accum is None:
+            c = i * u
+        else:
+            c = i * u + accum['c_tild']
+        # equation (8)
         h = o * th.tanh(c)
         return {'h' : h, 'c' : c}

-    def internal_update_func(self, node_reprs, edge_reprs):
-        raise NotImplementedError()
-
-    def readout_func(self, g, train):
-        if train:
-            h = th.cat([d['h'] for d in g.nodes.values() if d['y'] is not None], 0)
-            y = th.cat([d['y'] for d in g.nodes.values() if d['y'] is not None], 0)
-            log_p = F.log_softmax(self.linear(h), 1)
-            return -th.sum(y * log_p) / len(y)
-        else:
-            h = th.cat([reprs['h'] for reprs in g.nodes.values()], 0)
-            y_bar = th.max(self.linear(h), 1)[1]
-            # TODO
-            for reprs, z in zip(labelled_reprs, y_bar):
-                reprs['y_bar'] = z.item()
-
-    def forward(self, g, train=False):
-        """
+class TreeLSTM(nn.Module):
+    def __init__(self,
+                 num_vocabs,
+                 x_size,
+                 h_size,
+                 num_classes,
+                 dropout,
+                 cell_type='childsum'):
+        super(TreeLSTM, self).__init__()
+        self.x_size = x_size
+        # TODO(minjie): pre-trained embedding like GLoVe
+        self.embedding = nn.Embedding(num_vocabs, x_size)
+        self.dropout = nn.Dropout(dropout)
+        self.linear = nn.Linear(h_size, num_classes)
+        if cell_type == 'childsum':
+            self.cell = ChildSumTreeLSTMCell(x_size, h_size)
+        else:
+            raise RuntimeError('Unknown cell type:', cell_type)
+
+    def forward(self, batch, x, h, c, iterator=None, train=True):
+        """Compute tree-lstm prediction given a batch.
+
         Parameters
         ----------
-        g : networkx.DiGraph
+        batch : dgl.data.SSTBatch
+            The data batch.
+        x : Tensor
+            Initial node input.
+        h : Tensor
+            Initial hidden state.
+        c : Tensor
+            Initial cell state.
+        iterator : graph iterator
+            External iterator on graph.
+
+        Returns
+        -------
+        logits : Tensor
+            The prediction of each node.
         """
-        assert any(d['y'] is not None for d in g.nodes.values()) # TODO
-        g = G.DGLGraph(g)
-
-        def update_func(node_reprs, edge_reprs):
-            node_reprs = node_reprs.copy()
-            if node_reprs['x'] is not None:
-                node_reprs['x'] = self.embedding(node_reprs['x'])
-            return node_reprs
-
-        g.register_message_func(self.message_func)
-        g.register_update_func(update_func, g.nodes)
-        g.update_all()
-        g.register_update_func(self.internal_update_func, g.nodes)
-        leaves = list(filter(lambda x: g.in_degree(x) == 0, g.nodes))
-        g.register_update_func(self.leaf_update_func, leaves)
-        iterator = []
-        frontier = [next(filter(lambda x: g.out_degree(x) == 0, g.nodes))]
-        while frontier:
-            src = sum([list(g.pred[x]) for x in frontier], [])
-            trg = sum([[x] * len(g.pred[x]) for x in frontier], [])
-            iterator.append((src, trg))
-            frontier = src
-        g.recv(leaves)
-        g.propagate(reversed(iterator))
-        return self.readout_func(g, train)
-
-class ChildSumTreeLSTM(TreeLSTM):
-    def __init__(self, n_embeddings, x_size, h_size, n_classes):
-        super().__init__(n_embeddings, x_size, h_size, n_classes)
-        self.iou_w = nn.Parameter(th.randn(x_size, 3 * h_size)) # TODO initializer
-        self.iou_u = nn.Parameter(th.randn(h_size, 3 * h_size)) # TODO initializer
-        self.iou_b = nn.Parameter(th.zeros(1, 3 * h_size))
-        self.f_x = nn.Parameter(th.randn(x_size, h_size)) # TODO initializer
-        self.f_h = nn.Parameter(th.randn(h_size, h_size)) # TODO initializer
-        self.f_b = nn.Parameter(th.zeros(1, h_size))
-
-    def internal_update_func(self, node_reprs, edge_reprs):
-        x = node_reprs['x']
-        h_bar = sum(msg['h'] for msg in edge_reprs)
-        if x is None:
-            iou = th.mm(h_bar, self.iou_u) + self.iou_b
-        else:
-            iou = th.mm(x, self.iou_w) + th.mm(h_bar, self.iou_u) + self.iou_b
-        i, o, u = th.chunk(iou, 3, 1)
-        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
-        wx = th.mm(x, self.f_x).repeat(len(edge_reprs), 1) if x is not None else 0
-        uh = th.mm(th.cat([msg['h'] for msg in edge_reprs], 0), self.f_h)
-        f = th.sigmoid(wx + uh + f_b)
-        c = th.cat([msg['c'] for msg in edge_reprs], 0)
-        c = i * u + th.sum(f * c, 0)
-        h = o * th.tanh(c)
-        return {'h' : h, 'c' : c}
+        g = batch.graph
+        g.register_message_func(self.cell.message_func, batchable=True)
+        g.register_reduce_func(self.cell.reduce_func, batchable=True)
+        g.register_update_func(self.cell.update_func, batchable=True)
+        # feed embedding
+        embeds = self.embedding(batch.wordid)
+        x = x.index_copy(0, batch.nid_with_word, embeds)
+        g.set_n_repr({'x' : x, 'h' : h, 'c' : c})
+        # TODO(minjie): potential bottleneck
+        if iterator is None:
+            for frontier in topological_traverse(g):
+                #print('frontier', frontier)
+                g.update_to(frontier)
+        else:
+            for frontier in iterator:
+                g.update_to(frontier)
+        # compute logits
+        h = g.pop_n_repr('h')
+        h = self.dropout(h)
+        logits = self.linear(h)
+        return logits

+'''
 class NAryTreeLSTM(TreeLSTM):
     def __init__(self, n_embeddings, x_size, h_size, n_ary, n_classes):
         super().__init__(n_embeddings, x_size, h_size, n_classes)
@@ -159,3 +172,4 @@ class NAryTreeLSTM(TreeLSTM):
         h = o * th.tanh(c)
         return {'h' : h, 'c' : c}
+'''
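
For readers checking the new cell against the paper, here is a standalone sketch of the child-sum equations (2)-(8) on dummy tensors. The shapes and layer names mirror ChildSumTreeLSTMCell above (B parents, deg children each, input size X, hidden size H), but this is an illustrative re-derivation, not DGL code:

import torch as th
import torch.nn as nn

B, deg, X, H = 4, 2, 5, 3
W_iou, U_iou = nn.Linear(X, 3 * H), nn.Linear(H, 3 * H)
W_f, U_f = nn.Linear(X, H), nn.Linear(H, H)

x = th.randn(B, X)          # parent inputs
h_k = th.randn(B, deg, H)   # children hidden states
c_k = th.randn(B, deg, H)   # children cell states

h_tild = h_k.sum(1)                                 # eq (2): child-sum
f = th.sigmoid(W_f(x).unsqueeze(1) + U_f(h_k))      # eq (4): per-child forget gate
i, o, u = th.chunk(W_iou(x) + U_iou(h_tild), 3, 1)  # eq (3), (5), (6)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
c = i * u + (f * c_k).sum(1)                        # eq (7)
h = o * th.tanh(c)                                  # eq (8)
print(h.shape, c.shape)                             # torch.Size([4, 3]) twice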
"""Classes and functions for batching multiple graphs together."""
from __future__ import absolute_import
import numpy as np
from dgl.graph import DGLGraph from dgl.graph import DGLGraph
import dgl.backend as F import dgl.backend as F
import dgl import dgl
import numpy as np
class BatchedDGLGraph(DGLGraph): class BatchedDGLGraph(DGLGraph):
def __init__(self, graph_list, node_attrs=None, edge_attrs=None, **attr): def __init__(self, graph_list, node_attrs=None, edge_attrs=None, **attr):
......
@@ -39,26 +39,68 @@ class CachedGraph:
         return utils.toindex(eids)

     def in_edges(self, v):
+        """Get in-edges of the vertices.
+
+        Parameters
+        ----------
+        v : utils.Index
+            The vertex ids.
+
+        Returns
+        -------
+        src : utils.Index
+            The src vertex ids.
+        dst : utils.Index
+            The dst vertex ids.
+        orphan : utils.Index
+            The vertices that have no in-edges.
+        """
         src = []
         dst = []
+        orphan = []
         for vv in utils.node_iter(v):
             uu = self._graph.predecessors(vv)
-            src += uu
-            dst += [vv] * len(uu)
+            if len(uu) == 0:
+                orphan.append(vv)
+            else:
+                src += uu
+                dst += [vv] * len(uu)
         src = utils.toindex(src)
         dst = utils.toindex(dst)
-        return src, dst
+        orphan = utils.toindex(orphan)
+        return src, dst, orphan

     def out_edges(self, u):
+        """Get out-edges of the vertices.
+
+        Parameters
+        ----------
+        u : utils.Index
+            The vertex ids.
+
+        Returns
+        -------
+        src : utils.Index
+            The src vertex ids.
+        dst : utils.Index
+            The dst vertex ids.
+        orphan : utils.Index
+            The vertices that have no out-edges.
+        """
         src = []
         dst = []
+        orphan = []
         for uu in utils.node_iter(u):
             vv = self._graph.successors(uu)
-            src += [uu] * len(vv)
-            dst += vv
+            if len(vv) == 0:
+                orphan.append(uu)
+            else:
+                src += [uu] * len(vv)
+                dst += vv
         src = utils.toindex(src)
         dst = utils.toindex(dst)
-        return src, dst
+        orphan = utils.toindex(orphan)
+        return src, dst, orphan

     def in_degrees(self, v):
         degs = self._graph.indegree(list(v))
......
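
This is the supporting change for the commit's "allowing update_to() with no incoming messages": frontiers may now contain 0-in-degree nodes, and in_edges/out_edges report them in a separate orphan index instead of silently dropping them, so the caller can still run the node update. A toy sketch of the new contract in plain Python (the predecessor map is hypothetical, not DGL's internal representation):

pred = {0: [], 1: [0], 2: [0, 1]}  # node -> predecessor list

def in_edges(vs):
    src, dst, orphan = [], [], []
    for v in vs:
        if not pred[v]:
            orphan.append(v)       # no incoming messages: apply-only update
        else:
            src += pred[v]
            dst += [v] * len(pred[v])
    return src, dst, orphan

print(in_edges([0, 1, 2]))  # ([0, 0, 1], [1, 2, 2], [0])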
@@ -2,6 +2,7 @@
 from __future__ import absolute_import

 from . import citation_graph as citegrh
+from .tree import *
 from .utils import *

 def register_data_args(parser):
......