Commit a9ffb59e authored by Zihao Ye's avatar Zihao Ye Committed by Minjie Wang

[DOC][Model] Tree-LSTM example and tutorial (#146)

parent 0452cc3c
# Tree-LSTM
This is a re-implementation of the following paper:
> [**Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks**](http://arxiv.org/abs/1503.00075)
> *Kai Sheng Tai, Richard Socher, and Christopher Manning*.

The provided implementation achieves a test accuracy of 50.59, which is comparable to the 51.0 reported in the paper.
## Data
The script will download the [SST dataset](http://nlp.stanford.edu/sentiment/index.html) automatically. You need to download the GloVe 840B.300d word vectors yourself and place the extracted `glove.840B.300d.txt` in the directory from which you run the training script, for example:
```
wget http://nlp.stanford.edu/data/glove.840B.300d.zip
unzip glove.840B.300d.zip
```
## Usage
```
python train.py --gpu 0
```
## Speed Test
To enable a fair comparison with the [DyNet Tree-LSTM implementation](https://github.com/clab/dynet/tree/master/examples/treelstm), we set the batch size to 100:
```
python train.py --gpu 0 --batch-size 100
```
| Device              | Framework | Time per batch |
|---------------------|-----------|----------------|
| GeForce GTX TITAN X | DGL       | 7.23 (±0.66) s |
train.py, as updated by this commit (the diff hunk begins at line 3, so the file's opening imports, e.g. `argparse`, are not shown):

import time
import numpy as np
import torch as th
import torch.nn.functional as F
import torch.nn.init as INIT
import torch.optim as optim
from torch.utils.data import DataLoader

import dgl
import dgl.data as data

from tree_lstm import TreeLSTM

def main(args):
    np.random.seed(args.seed)
    th.manual_seed(args.seed)
    th.cuda.manual_seed(args.seed)

    cuda = args.gpu >= 0
    device = th.device('cuda:{}'.format(args.gpu)) if cuda else th.device('cpu')
    if cuda:
        th.cuda.set_device(args.gpu)

    trainset = data.SST()
    train_loader = DataLoader(dataset=trainset,
                              batch_size=args.batch_size,
                              collate_fn=data.SST.batcher(device),
                              shuffle=True,
                              num_workers=0)
    devset = data.SST(mode='dev')
    dev_loader = DataLoader(dataset=devset,
                            batch_size=100,
                            collate_fn=data.SST.batcher(device),
                            shuffle=False,
                            num_workers=0)
    testset = data.SST(mode='test')
    test_loader = DataLoader(dataset=testset,
                             batch_size=100,
                             collate_fn=data.SST.batcher(device),
                             shuffle=False,
                             num_workers=0)

    model = TreeLSTM(trainset.num_vocabs,
                     args.x_size,
                     args.h_size,
                     trainset.num_classes,
                     args.dropout,
                     pretrained_emb=trainset.pretrained_emb).to(device)
    print(model)

    # use the full learning rate for all parameters except the embedding,
    # which gets a smaller one
    params_ex_emb = [x for x in list(model.parameters())
                     if x.requires_grad and x.size(0) != trainset.num_vocabs]
    params_emb = list(model.embedding.parameters())

    optimizer = optim.Adagrad([
        {'params': params_ex_emb, 'lr': args.lr, 'weight_decay': args.weight_decay},
        {'params': params_emb, 'lr': 0.1 * args.lr}])

    dur = []
    for epoch in range(args.epochs):
        t_epoch = time.time()
        model.train()
        for step, batch in enumerate(train_loader):
            g = batch.graph
            n = g.number_of_nodes()
            h = th.zeros((n, args.h_size)).to(device)
            c = th.zeros((n, args.h_size)).to(device)
            if step >= 3:
                t0 = time.time()  # tik

            logits = model(batch, h, c)
            logp = F.log_softmax(logits, 1)
            loss = F.nll_loss(logp, batch.label, reduction='elementwise_mean')
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step >= 3:
                dur.append(time.time() - t0)  # tok

            if step > 0 and step % args.log_every == 0:
                pred = th.argmax(logits, 1)
                acc = th.sum(th.eq(batch.label, pred))
                root_ids = [i for i in range(batch.graph.number_of_nodes())
                            if batch.graph.out_degree(i) == 0]
                root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] ==
                                  pred.cpu().data.numpy()[root_ids])
                print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}".format(
                    epoch, step, loss.item(), 1.0 * acc.item() / len(batch.label),
                    1.0 * root_acc / len(root_ids), np.mean(dur)))
        print('Epoch {:05d} training time {:.4f}s'.format(epoch, time.time() - t_epoch))

        # test on dev set
        accs = []
        root_accs = []
        model.eval()
        for step, batch in enumerate(dev_loader):
            g = batch.graph
            n = g.number_of_nodes()
            with th.no_grad():
                h = th.zeros((n, args.h_size)).to(device)
                c = th.zeros((n, args.h_size)).to(device)
                logits = model(batch, h, c)
            pred = th.argmax(logits, 1)
            acc = th.sum(th.eq(batch.label, pred)).item()
            accs.append([acc, len(batch.label)])
            root_ids = [i for i in range(batch.graph.number_of_nodes())
                        if batch.graph.out_degree(i) == 0]
            root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] ==
                              pred.cpu().data.numpy()[root_ids])
            root_accs.append([root_acc, len(root_ids)])

        for param_group in optimizer.param_groups:
            param_group['lr'] = max(1e-5, param_group['lr'] * 0.99)
        dev_acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
        dev_root_acc = 1.0 * np.sum([x[0] for x in root_accs]) / np.sum([x[1] for x in root_accs])
        print("Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
            epoch, dev_acc, dev_root_acc))

        # test
        accs = []
        root_accs = []
        model.eval()
        for step, batch in enumerate(test_loader):
            g = batch.graph
            n = g.number_of_nodes()
            with th.no_grad():
                h = th.zeros((n, args.h_size)).to(device)
                c = th.zeros((n, args.h_size)).to(device)
                logits = model(batch, h, c)
            pred = th.argmax(logits, 1)
            acc = th.sum(th.eq(batch.label, pred)).item()
            accs.append([acc, len(batch.label)])
            root_ids = [i for i in range(batch.graph.number_of_nodes())
                        if batch.graph.out_degree(i) == 0]
            root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] ==
                              pred.cpu().data.numpy()[root_ids])
            root_accs.append([root_acc, len(root_ids)])

        # lr decay
        for param_group in optimizer.param_groups:
            param_group['lr'] = max(1e-5, param_group['lr'] * 0.99)
        test_acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
        test_root_acc = 1.0 * np.sum([x[0] for x in root_accs]) / np.sum([x[1] for x in root_accs])
        print("Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
            epoch, test_acc, test_root_acc))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--seed', type=int, default=12110)
    parser.add_argument('--batch-size', type=int, default=25)
    parser.add_argument('--x-size', type=int, default=300)
    parser.add_argument('--h-size', type=int, default=150)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--log-every', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.05)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--dropout', type=float, default=0.3)
    args = parser.parse_args()
    print(args)
    main(args)
tree_lstm.py, as updated by this commit (the diff shows two hunks; the module's opening lines are not shown):

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

import dgl

class TreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size)
        self.U_iou = nn.Linear(2 * h_size, 3 * h_size)
        self.U_f = nn.Linear(2 * h_size, 2 * h_size)

    def message_func(self, edges):
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
        f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
        c = th.sum(f * nodes.mailbox['c'], 1)
        return {'iou': self.U_iou(h_cat), 'c': c}

    def apply_node_func(self, nodes):
        iou = nodes.data['iou']
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        c = i * u + nodes.data['c']
        h = o * th.tanh(c)
        return {'h' : h, 'c' : c}

class TreeLSTM(nn.Module):
    def __init__(self,
                 num_vocabs,
                 x_size,
                 h_size,
                 num_classes,
                 dropout,
                 pretrained_emb=None):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.embedding = nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print('Using glove')
            self.embedding.weight.data.copy_(pretrained_emb)
            self.embedding.weight.requires_grad = True
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(h_size, num_classes)
        self.cell = TreeLSTMCell(x_size, h_size)

    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        g = batch.graph
        g.register_message_func(self.cell.message_func)
        g.register_reduce_func(self.cell.reduce_func)
        g.register_apply_node_func(self.cell.apply_node_func)
        # feed embedding
        embeds = self.embedding(batch.wordid * batch.mask)
        g.ndata['iou'] = self.cell.W_iou(embeds) * batch.mask.float().unsqueeze(-1)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # propagate
        dgl.prop_nodes_topo(g)
        # compute logits
        h = self.dropout(g.ndata.pop('h'))
        logits = self.linear(h)
        return logits
The SST dataset module in ``dgl.data``, as updated by this commit (lines outside the diff hunks are marked with ellipsis comments):

from __future__ import absolute_import

from collections import namedtuple, OrderedDict
from nltk.tree import Tree
from nltk.corpus.reader import BracketParseCorpusReader
import networkx as nx
import numpy as np
import os

import dgl
import dgl.backend as F
from dgl.data.utils import download, extract_archive, get_download_dir

_urls = {
    'sst' : 'https://www.dropbox.com/s/6qa8rm43r2nmbyw/sst.zip?dl=1',
}

SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])

class SST(object):
    """Stanford Sentiment Treebank dataset.

    ... (docstring shortened; the diff shows only its beginning and end) ...

        Optional vocabulary file.
    """
    PAD_WORD = -1  # special pad word id
    UNK_WORD = -1  # out-of-vocabulary word id

    def __init__(self, mode='train', vocab_file=None):
        self.mode = mode
        self.dir = get_download_dir()
        self.zip_file_path = '{}/sst.zip'.format(self.dir)
        self.pretrained_file = 'glove.840B.300d.txt' if mode == 'train' else ''
        self.pretrained_emb = None
        self.vocab_file = '{}/sst/vocab.txt'.format(self.dir) if vocab_file is None else vocab_file
        download(_urls['sst'], path=self.zip_file_path)
        extract_archive(self.zip_file_path, '{}/sst'.format(self.dir))
        # ... (unchanged lines not shown in the diff) ...
        print('Dataset creation finished. #Trees:', len(self.trees))

    def _load(self):
        # load vocab file
        self.vocab = OrderedDict()
        with open(self.vocab_file, encoding='utf-8') as vf:
            for line in vf.readlines():
                line = line.strip()
                self.vocab[line] = len(self.vocab)
        # filter glove
        if self.pretrained_file != '' and os.path.exists(self.pretrained_file):
            glove_emb = {}
            with open(self.pretrained_file, 'r', encoding='utf-8') as pf:
                for line in pf.readlines():
                    sp = line.split(' ')
                    if sp[0].lower() in self.vocab:
                        glove_emb[sp[0].lower()] = np.array([float(x) for x in sp[1:]])
        files = ['{}.txt'.format(self.mode)]
        corpus = BracketParseCorpusReader('{}/sst'.format(self.dir), files)
        sents = corpus.parsed_sents(files[0])
        # initialize with glove
        pretrained_emb = []
        fail_cnt = 0
        for line in self.vocab.keys():
            if self.pretrained_file != '' and os.path.exists(self.pretrained_file):
                if not line.lower() in glove_emb:
                    fail_cnt += 1
                pretrained_emb.append(glove_emb.get(line.lower(), np.random.uniform(-0.05, 0.05, 300)))
        if self.pretrained_file != '' and os.path.exists(self.pretrained_file):
            self.pretrained_emb = F.tensor(np.stack(pretrained_emb, 0))
            print('Miss word in GloVe {0:.4f}'.format(1.0 * fail_cnt / len(self.pretrained_emb)))
        # build trees
        for sent in sents:
            self.trees.append(self._build_tree(sent))

    def _build_tree(self, root):
        g = nx.DiGraph()
        def _rec_build(nid, node):
            for child in node:
                cid = g.number_of_nodes()
                if isinstance(child[0], str) or isinstance(child[0], bytes):
                    # leaf node
                    word = self.vocab.get(child[0].lower(), self.UNK_WORD)
                    g.add_node(cid, x=word, y=int(child.label()), mask=(word != self.UNK_WORD))
                else:
                    g.add_node(cid, x=SST.PAD_WORD, y=int(child.label()), mask=0)
                    _rec_build(cid, child)
                g.add_edge(cid, nid)
        # add root
        g.add_node(0, x=SST.PAD_WORD, y=int(root.label()), mask=0)
        _rec_build(0, root)
        ret = dgl.DGLGraph()
        ret.from_networkx(g, node_attrs=['x', 'y', 'mask'])
        return ret

    def __getitem__(self, idx):
        ...  # body unchanged and not shown in the diff

    def __len__(self):
        return len(self.trees)

    @property
    def num_vocabs(self):
        return len(self.vocab)

    @staticmethod
    def batcher(device):
        def batcher_dev(batch):
            batch_trees = dgl.batch(batch)
            return SSTBatch(graph=batch_trees,
                            mask=batch_trees.ndata['mask'].to(device),
                            wordid=batch_trees.ndata['x'].to(device),
                            label=batch_trees.ndata['y'].to(device))
        return batcher_dev
"""
.. _model-tree-lstm:
Tree LSTM DGL Tutorial
=========================
**Author**: `Zihao Ye`, `Qipeng Guo`, `Minjie Wang`, `Zheng Zhang`
"""
##############################################################################
#
# The Tree-LSTM structure was first introduced by Tai et al. in their ACL 2015
# paper: `Improved Semantic Representations From Tree-Structured Long
# Short-Term Memory Networks <https://arxiv.org/pdf/1503.00075.pdf>`__. It
# introduces syntactic information into the network by generalizing the
# chain-structured LSTM to tree-structured LSTMs, using dependency or
# constituency trees as the latent structure.
#
# Training Tree-LSTMs is difficult to parallelize because every tree has a
# different shape. DGL offers a neat alternative: pool all the trees into one
# batched graph and then induce message passing over it.
#
# The task and the dataset
# ------------------------
#
# We will use the Tree-LSTM for a sentiment analysis task. We have wrapped the
# `Stanford Sentiment Treebank <https://nlp.stanford.edu/sentiment/>`__ in
# ``dgl.data``. The dataset provides fine-grained, tree-level sentiment
# annotation: five classes (very negative, negative, neutral, positive, and
# very positive) indicating the sentiment of the corresponding subtree.
# Non-leaf nodes of a constituency tree do not contain words, so we mark them
# with a special ``PAD_WORD`` token; during training and inference, their
# embeddings are masked to all zeros.
#
# .. figure:: https://i.loli.net/2018/11/08/5be3d4bfe031b.png
# :alt:
#
# The figure displays one sample of the SST dataset: a constituency parse
# tree whose nodes are labeled with sentiment. To speed things up, let's
# build a tiny set with five sentences and take a look at the first one:
#
import dgl
import dgl.data as data
# Each sample in the dataset is a constituency tree. The leaf nodes
# represent words; each word is an integer value stored in the "x" field.
# The non-leaf nodes carry the special word PAD_WORD. The sentiment
# label is stored in the "y" feature field.
trainset = data.SST(mode='tiny') # the "tiny" set has only 5 trees
tiny_sst = trainset.trees
num_vocabs = trainset.num_vocabs
num_classes = trainset.num_classes
vocab = trainset.vocab # vocabulary dict: key -> id
inv_vocab = {v: k for k, v in vocab.items()} # inverted vocabulary dict: id -> word
a_tree = tiny_sst[0]
for token in a_tree.ndata['x'].tolist():
    if token != trainset.PAD_WORD:
        print(inv_vocab[token], end=" ")
##############################################################################
# Step 1: batching
# ----------------
#
# The first step is to throw all the trees into one graph, using
# the :func:`~dgl.batched_graph.batch` API.
#
import networkx as nx
import matplotlib.pyplot as plt
graph = dgl.batch(tiny_sst)
def plot_tree(g):
    # this plot requires the pygraphviz package
    pos = nx.nx_agraph.graphviz_layout(g, prog='dot')
    nx.draw(g, pos, with_labels=False, node_size=10,
            node_color=[[.5, .5, .5]], arrowsize=4)
    plt.show()

plot_tree(graph.to_networkx())
##############################################################################
# You can read more about the definition of :func:`~dgl.batched_graph.batch`
# (by clicking the API name), or skip ahead to the next step:
#
# .. note::
#
# **Definition**: a :class:`~dgl.batched_graph.BatchedDGLGraph` is a
# :class:`~dgl.DGLGraph` that unions a list of :class:`~dgl.DGLGraph`\ s.
#
# - The union includes all the nodes,
# edges, and their features. The order of nodes, edges, and features is
# preserved.
#
# - Given that we have :math:`V_i` nodes for graph
# :math:`\mathcal{G}_i`, the node ID :math:`j` in graph
# :math:`\mathcal{G}_i` corresponds to node ID
# :math:`j + \sum_{k=1}^{i-1} V_k` in the batched graph (see the quick check
# after this note).
#
# - Therefore, performing feature transformation and message passing on
# ``BatchedDGLGraph`` is equivalent to doing those
# on all ``DGLGraph`` constituents in parallel.
#
# - Duplicate references to the same graph are
# treated as deep copies; the nodes, edges, and features are duplicated,
# and mutation on one reference does not affect the other.
# - Currently, ``BatchedDGLGraph`` is immutable in
# graph structure (i.e. one can't add
# nodes and edges to it). We need to support mutable batched graphs in
# the (far) future.
# - The ``BatchedDGLGraph`` keeps track of the meta
# information of the constituents so it can be
# :func:`~dgl.batched_graph.unbatch`\ ed to list of ``DGLGraph``\ s.
#
# For more details about the :class:`~dgl.batched_graph.BatchedDGLGraph`
# module in DGL, you can click the class name.
#
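# The note above says that node :math:`j` of the :math:`i`-th input graph
# becomes node :math:`j + \sum_{k=1}^{i-1} V_k` in the batched graph. The
# following quick check is a minimal sketch (not part of the original
# example) that assumes only the basic ``DGLGraph``/``dgl.batch`` API:

g1 = dgl.DGLGraph()
g1.add_nodes(3)              # nodes 0, 1, 2
g2 = dgl.DGLGraph()
g2.add_nodes(2)              # nodes 0, 1
bg = dgl.batch([g1, g2])
# 5 nodes in total; node 0 of g2 now has ID 3 = 0 + 3 in the batched graph
print(bg.number_of_nodes())

##############################################################################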
# Step 2: Tree-LSTM Cell with message-passing APIs
# ------------------------------------------------
#
# .. note::
# The paper proposes two types of Tree-LSTM: the Child-Sum
# Tree-LSTM and the :math:`N`-ary Tree-LSTM. In this tutorial we focus on
# the latter. We use PyTorch as our backend framework to set up the
# network.
#
# In a Tree-LSTM, each unit at node :math:`j` maintains a hidden
# representation :math:`h_j` and a memory cell :math:`c_j`. Unit
# :math:`j` takes the input vector :math:`x_j` and the hidden
# representations of its child units, :math:`h_k, k\in C(j)`, as
# input, then computes its new hidden representation :math:`h_j` and memory
# cell :math:`c_j` as follows:
#
# .. math::
#
# i_j = \sigma\left(W^{(i)}x_j + \sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\right), \\
# f_{jk} = \sigma\left(W^{(f)}x_j + \sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \right), \\
# o_j = \sigma\left(W^{(o)}x_j + \sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \right), \\
# u_j = \textrm{tanh}\left(W^{(u)}x_j + \sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \right) , \\
# c_j = i_j \odot u_j + \sum_{l=1}^{N} f_{jl} \odot c_{jl}, \\
# h_j = o_j \odot \textrm{tanh}(c_j)
#
# The process can be decomposed into three phases: ``message_func``,
# ``reduce_func`` and ``apply_node_func``.
#
# ``apply_node_func`` is a new node UDF we have not introduced before. In
# ``apply_node_func``, the user specifies what to do with node features,
# without considering edge features and messages. In the Tree-LSTM case,
# ``apply_node_func`` is a must, since leaf nodes have
# :math:`0` incoming edges and would therefore not be updated via
# ``reduce_func``.
#
import torch as th
import torch.nn as nn
class TreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size)
        self.U_iou = nn.Linear(2 * h_size, 3 * h_size)
        self.U_f = nn.Linear(2 * h_size, 2 * h_size)

    def message_func(self, edges):
        # each child sends its hidden state and memory cell to its parent
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        # concatenate the children's hidden states along the feature axis
        h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
        # one forget gate per child, then the gated sum of the children's cells
        f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
        c = th.sum(f * nodes.mailbox['c'], 1)
        return {'iou': self.U_iou(h_cat), 'c': c}

    def apply_node_func(self, nodes):
        # compute the i, o, u gates and the new cell/hidden states
        iou = nodes.data['iou']
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        c = i * u + nodes.data['c']
        h = o * th.tanh(c)
        return {'h' : h, 'c' : c}
##############################################################################
# Step 3: define traversal
# ------------------------
#
# After defining the message passing functions, we need to induce the
# right order in which to trigger them. This is a significant departure from
# models such as GCN, where all nodes pull messages from upstream ones
# *simultaneously*.
#
# In the case of Tree-LSTM, messages start from the leaves of the tree and
# are propagated/processed upwards until they reach the roots. A
# visualization is as follows:
#
# .. figure:: https://i.loli.net/2018/11/09/5be4b5d2df54d.gif
# :alt:
#
# DGL defines a generator to perform the topological sort; each item is a
# tensor recording the nodes from the bottom level up to the roots. One can
# appreciate the degree of parallelism by comparing the following two
# outputs:
#
print('Traversing one tree:')
print(dgl.topological_nodes_generator(a_tree))
print('Traversing many trees at the same time:')
print(dgl.topological_nodes_generator(graph))
##############################################################################
# We then call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:
#
# .. note::
#
# Before we call :meth:`~dgl.DGLGraph.prop_nodes`, we must specify a
# `message_func` and `reduce_func` in advance. Here we use the built-in
# copy-from-source and sum functions as our message and reduce functions
# for demonstration.
import dgl.function as fn
import torch as th
graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)
graph.register_message_func(fn.copy_src('a', 'a'))
graph.register_reduce_func(fn.sum('a', 'a'))
traversal_order = dgl.topological_nodes_generator(graph)
graph.prop_nodes(traversal_order)
# the following is syntactic sugar that does the same
# dgl.prop_nodes_topo(graph)
##############################################################################
# Putting it together
# -------------------
#
# Here is the complete code that specifies the ``Tree-LSTM`` class:
#
class TreeLSTM(nn.Module):
    def __init__(self,
                 num_vocabs,
                 x_size,
                 h_size,
                 num_classes,
                 dropout,
                 pretrained_emb=None):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.embedding = nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print('Using glove')
            self.embedding.weight.data.copy_(pretrained_emb)
            self.embedding.weight.requires_grad = True
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(h_size, num_classes)
        self.cell = TreeLSTMCell(x_size, h_size)

    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        g = batch.graph
        g.register_message_func(self.cell.message_func)
        g.register_reduce_func(self.cell.reduce_func)
        g.register_apply_node_func(self.cell.apply_node_func)
        # feed embedding
        embeds = self.embedding(batch.wordid * batch.mask)
        g.ndata['iou'] = self.cell.W_iou(embeds) * batch.mask.float().unsqueeze(-1)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # propagate
        dgl.prop_nodes_topo(g)
        # compute logits
        h = self.dropout(g.ndata.pop('h'))
        logits = self.linear(h)
        return logits
##############################################################################
# Main Loop
# ---------
#
# Finally, we can write the training loop in PyTorch:
#
from torch.utils.data import DataLoader
import torch.nn.functional as F
device = th.device('cpu')
# hyper parameters
x_size = 256
h_size = 256
dropout = 0.5
lr = 0.05
weight_decay = 1e-4
epochs = 10
# create the model
model = TreeLSTM(trainset.num_vocabs,
                 x_size,
                 h_size,
                 trainset.num_classes,
                 dropout)
print(model)

# create the optimizer
optimizer = th.optim.Adagrad(model.parameters(),
                             lr=lr,
                             weight_decay=weight_decay)

train_loader = DataLoader(dataset=tiny_sst,
                          batch_size=5,
                          collate_fn=data.SST.batcher(device),
                          shuffle=False,
                          num_workers=0)

# training loop
for epoch in range(epochs):
    for step, batch in enumerate(train_loader):
        g = batch.graph
        n = g.number_of_nodes()
        h = th.zeros((n, h_size))
        c = th.zeros((n, h_size))
        logits = model(batch, h, c)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp, batch.label, reduction='elementwise_mean')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pred = th.argmax(logits, 1)
        acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)
        print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format(
            epoch, step, loss.item(), acc))
##############################################################################
# To train the model on the full dataset with different settings (CPU/GPU,
# etc.), please refer to our repo's
# `example <https://github.com/jermainewang/dgl/tree/master/examples/pytorch/tree_lstm>`__.
Requirements file (this commit adds `graphviz`; entries earlier in the file are not shown in the diff):

numpy
seaborn
matplotlib
pygraphviz
graphviz