Commit c80dc5e0 authored by Zihao Ye, committed by Lingfan Yu

[model] Tree-LSTM update & hotfix(dropbox link) (#202)

* change the signature of node/edge filter

* upd filter

* Support multi-dimension node feature in SPMV

* stable version

* hotfix

* upd tutorial

* upd README
parent 3a0f86a6
@@ -4,7 +4,7 @@ This is a re-implementation of the following paper:
 > [**Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks**](http://arxiv.org/abs/1503.00075)
 > *Kai Sheng Tai, Richard Socher, and Christopher Manning*.
-The provided implementation can achieve a test accuracy of 50.59, which is comparable to the 51.0 reported in the paper.
+The provided implementation can achieve a test accuracy of 51.72, which is comparable to the result reported in the original paper: 51.0 (±0.5).
 ## Data
 The script will download the [SST dataset](http://nlp.stanford.edu/sentiment/index.html) automatically; you need to download the GloVe word vectors yourself. From the command line, you can use:
@@ -19,11 +19,5 @@ python train.py --gpu 0
 ```
 ## Speed Test
-To enable fair comparison with the [DyNet Tree-LSTM implementation](https://github.com/clab/dynet/tree/master/examples/treelstm), we set the batch size to 100.
-```
-python train.py --gpu 0 --batch-size 100
-```
-| Device              | Framework | Speed (time per batch) |
-|---------------------|-----------|------------------------|
-| GeForce GTX TITAN X | DGL       | 7.23 (±0.66) s         |
+See https://docs.google.com/spreadsheets/d/1eCQrVn7g0uWriz63EbEDdes2ksMdKdlbWMyT8PSU4rc .
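As a convenience for the Data section above, the GloVe vectors can be fetched with a few lines of Python. The 840B.300d variant matches the default `--x-size 300`, though the URL and filenames here are assumptions based on the standard GloVe release, not part of this commit:

```python
# Sketch (not part of the commit): download and unpack the GloVe vectors (~2 GB).
import urllib.request
import zipfile

url = 'http://nlp.stanford.edu/data/glove.840B.300d.zip'  # assumed canonical mirror
urllib.request.urlretrieve(url, 'glove.840B.300d.zip')
with zipfile.ZipFile('glove.840B.300d.zip') as zf:
    zf.extractall('.')  # yields glove.840B.300d.txt in the working directory
```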
@@ -17,6 +17,9 @@ def main(args):
     th.manual_seed(args.seed)
     th.cuda.manual_seed(args.seed)
+    best_epoch = -1
+    best_dev_acc = 0
+
     cuda = args.gpu >= 0
     device = th.device('cuda:{}'.format(args.gpu)) if cuda else th.device('cpu')
     if cuda:
@@ -37,10 +40,7 @@ def main(args):
     testset = data.SST(mode='test')
     test_loader = DataLoader(dataset=testset,
-                             batch_size=100,
-                             collate_fn=data.SST.batcher(device),
-                             shuffle=False,
-                             num_workers=0)
+                             batch_size=100, collate_fn=data.SST.batcher(device), shuffle=False, num_workers=0)
     model = TreeLSTM(trainset.num_vocabs,
                      args.x_size,
@@ -53,6 +53,10 @@ def main(args):
     params_ex_emb = [x for x in list(model.parameters()) if x.requires_grad and x.size(0) != trainset.num_vocabs]
     params_emb = list(model.embedding.parameters())
+
+    for p in params_ex_emb:
+        if p.dim() > 1:
+            INIT.xavier_uniform_(p)
+
     optimizer = optim.Adagrad([
         {'params': params_ex_emb, 'lr': args.lr, 'weight_decay': args.weight_decay},
         {'params': params_emb, 'lr': 0.1*args.lr}])
@@ -71,7 +75,8 @@ def main(args):
             logits = model(batch, h, c)
             logp = F.log_softmax(logits, 1)
-            loss = F.nll_loss(logp, batch.label, reduction='elementwise_mean')
+            loss = F.nll_loss(logp, batch.label, reduction='sum')
+
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
@@ -84,11 +89,12 @@ def main(args):
                 acc = th.sum(th.eq(batch.label, pred))
                 root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i)==0]
                 root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] == pred.cpu().data.numpy()[root_ids])
                 print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}".format(
                     epoch, step, loss.item(), 1.0*acc.item()/len(batch.label), 1.0*root_acc/len(root_ids), np.mean(dur)))
         print('Epoch {:05d} training time {:.4f}s'.format(epoch, time.time() - t_epoch))
-        # test on dev set
+        # eval on dev set
         accs = []
         root_accs = []
         model.eval()
@@ -106,45 +112,55 @@ def main(args):
             root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i)==0]
             root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] == pred.cpu().data.numpy()[root_ids])
             root_accs.append([root_acc, len(root_ids)])
-        for param_group in optimizer.param_groups:
-            param_group['lr'] = max(1e-5, param_group['lr']*0.99) #10
         dev_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs])
         dev_root_acc = 1.0*np.sum([x[0] for x in root_accs])/np.sum([x[1] for x in root_accs])
         print("Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
             epoch, dev_acc, dev_root_acc))
-        # test
-        accs = []
-        root_accs = []
-        model.eval()
-        for step, batch in enumerate(test_loader):
-            g = batch.graph
-            n = g.number_of_nodes()
-            with th.no_grad():
-                h = th.zeros((n, args.h_size)).to(device)
-                c = th.zeros((n, args.h_size)).to(device)
-                logits = model(batch, h, c)
-                pred = th.argmax(logits, 1)
-            acc = th.sum(th.eq(batch.label, pred)).item()
-            accs.append([acc, len(batch.label)])
-            root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i)==0]
-            root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] == pred.cpu().data.numpy()[root_ids])
-            root_accs.append([root_acc, len(root_ids)])
-        #lr decay
+        if dev_root_acc > best_dev_acc:
+            best_dev_acc = dev_root_acc
+            best_epoch = epoch
+            th.save(model.state_dict(), 'best_{}.pkl'.format(args.seed))
+        else:
+            if best_epoch <= epoch - 10:
+                break
+
+        # lr decay
         for param_group in optimizer.param_groups:
             param_group['lr'] = max(1e-5, param_group['lr']*0.99) #10
+            print(param_group['lr'])
+
+    # test
+    model.load_state_dict(th.load('best_{}.pkl'.format(args.seed)))
+    accs = []
+    root_accs = []
+    model.eval()
+    for step, batch in enumerate(test_loader):
+        g = batch.graph
+        n = g.number_of_nodes()
+        with th.no_grad():
+            h = th.zeros((n, args.h_size)).to(device)
+            c = th.zeros((n, args.h_size)).to(device)
+            logits = model(batch, h, c)
+            pred = th.argmax(logits, 1)
+        acc = th.sum(th.eq(batch.label, pred)).item()
+        accs.append([acc, len(batch.label)])
+        root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i)==0]
+        root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] == pred.cpu().data.numpy()[root_ids])
+        root_accs.append([root_acc, len(root_ids)])
-        test_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs])
-        test_root_acc = 1.0*np.sum([x[0] for x in root_accs])/np.sum([x[1] for x in root_accs])
-        print("Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
-            epoch, test_acc, test_root_acc))
+    test_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs])
+    test_root_acc = 1.0*np.sum([x[0] for x in root_accs])/np.sum([x[1] for x in root_accs])
+    print('------------------------------------------------------------------------------------')
+    print("Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
+        best_epoch, test_acc, test_root_acc))

 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--gpu', type=int, default=-1)
-    parser.add_argument('--seed', type=int, default=12110)
+    parser.add_argument('--seed', type=int, default=41)
     parser.add_argument('--batch-size', type=int, default=25)
     parser.add_argument('--child-sum', action='store_true')
     parser.add_argument('--x-size', type=int, default=300)
@@ -153,7 +169,7 @@ if __name__ == '__main__':
     parser.add_argument('--log-every', type=int, default=5)
     parser.add_argument('--lr', type=float, default=0.05)
     parser.add_argument('--weight-decay', type=float, default=1e-4)
-    parser.add_argument('--dropout', type=float, default=0.3)
+    parser.add_argument('--dropout', type=float, default=0.5)
     args = parser.parse_args()
     print(args)
     main(args)
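Both train.py above and the tutorial below switch the loss from reduction='elementwise_mean' to reduction='sum' ('elementwise_mean' is the old spelling of 'mean', deprecated in PyTorch 1.0). The two differ only by a factor of the number of labeled nodes, so the sum variant scales gradients with how much supervision a batch carries. A quick check (a sketch, with made-up shapes):

```python
# Sketch (not part of the commit): sum == mean * N for nll_loss.
import torch as th
import torch.nn.functional as F

logp = F.log_softmax(th.randn(6, 5), 1)   # 6 nodes, 5 sentiment classes
label = th.randint(0, 5, (6,))
s = F.nll_loss(logp, label, reduction='sum')
m = F.nll_loss(logp, label, reduction='mean')
assert th.allclose(s, m * len(label))     # differ only by the node count
```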
@@ -14,8 +14,9 @@ import dgl

 class TreeLSTMCell(nn.Module):
     def __init__(self, x_size, h_size):
         super(TreeLSTMCell, self).__init__()
-        self.W_iou = nn.Linear(x_size, 3 * h_size)
-        self.U_iou = nn.Linear(2 * h_size, 3 * h_size)
+        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
+        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
+        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
         self.U_f = nn.Linear(2 * h_size, 2 * h_size)

     def message_func(self, edges):
@@ -28,7 +29,7 @@ class TreeLSTMCell(nn.Module):
         return {'iou': self.U_iou(h_cat), 'c': c}

     def apply_node_func(self, nodes):
-        iou = nodes.data['iou']
+        iou = nodes.data['iou'] + self.b_iou
         i, o, u = th.chunk(iou, 3, 1)
         i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
         c = i * u + nodes.data['c']
@@ -38,8 +39,9 @@ class TreeLSTMCell(nn.Module):

 class ChildSumTreeLSTMCell(nn.Module):
     def __init__(self, x_size, h_size):
         super(ChildSumTreeLSTMCell, self).__init__()
-        self.W_iou = nn.Linear(x_size, 3 * h_size)
-        self.U_iou = nn.Linear(h_size, 3 * h_size)
+        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
+        self.U_iou = nn.Linear(h_size, 3 * h_size, bias=False)
+        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
         self.U_f = nn.Linear(h_size, h_size)

     def message_func(self, edges):
@@ -52,7 +54,7 @@ class ChildSumTreeLSTMCell(nn.Module):
         return {'iou': self.U_iou(h_tild), 'c': c}

     def apply_node_func(self, nodes):
-        iou = nodes.data['iou']
+        iou = nodes.data['iou'] + self.b_iou
         i, o, u = th.chunk(iou, 3, 1)
         i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
         c = i * u + nodes.data['c']
@@ -82,7 +84,6 @@ class TreeLSTM(nn.Module):
     def forward(self, batch, h, c):
         """Compute tree-lstm prediction given a batch.
-
         Parameters
         ----------
         batch : dgl.data.SSTBatch
@@ -91,7 +92,6 @@ class TreeLSTM(nn.Module):
             Initial hidden state.
         c : Tensor
             Initial cell state.
-
         Returns
         -------
         logits : Tensor
@@ -103,7 +103,7 @@ class TreeLSTM(nn.Module):
         g.register_apply_node_func(self.cell.apply_node_func)
         # feed embedding
         embeds = self.embedding(batch.wordid * batch.mask)
-        g.ndata['iou'] = self.cell.W_iou(embeds) * batch.mask.float().unsqueeze(-1)
+        g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
         g.ndata['h'] = h
         g.ndata['c'] = c
         # propagate
......
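A note on the recurring b_iou change: leaves receive their 'iou' from W_iou (in the forward pass), interior nodes receive theirs from U_iou (in reduce_func), and apply_node_func runs on every node. Making both projections bias-free and adding a single b_iou there gives each node exactly one bias term, matching the paper's equations. A minimal sketch of the pattern (toy sizes, not the commit's code):

```python
# Sketch (not part of the commit): one shared bias, added exactly once per node.
import torch as th
import torch.nn as nn

x_size, h_size = 4, 3
W_iou = nn.Linear(x_size, 3 * h_size, bias=False)      # input projection, no bias
U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)  # children projection, no bias
b_iou = nn.Parameter(th.zeros(1, 3 * h_size))          # the single shared bias

iou_leaf = W_iou(th.randn(1, x_size)) + b_iou          # leaf: word input + bias
iou_node = U_iou(th.randn(1, 2 * h_size)) + b_iou      # interior: children + bias
```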
@@ -17,7 +17,7 @@ import dgl.backend as F
 from dgl.data.utils import download, extract_archive, get_download_dir

 _urls = {
-    'sst' : 'https://www.dropbox.com/s/6qa8rm43r2nmbyw/sst.zip?dl=1',
+    'sst': 'https://www.dropbox.com/s/aqejdgrs3jkrmc7/sst.zip?dl=1',
 }

 SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
@@ -37,7 +37,7 @@ class SST(object):
     .. note::
         All the samples will be loaded and preprocessed in the memory first.

     Parameters
     ----------
     mode : str, optional
@@ -107,7 +107,7 @@ class SST(object):
             if isinstance(child[0], str) or isinstance(child[0], bytes):
                 # leaf node
                 word = self.vocab.get(child[0].lower(), self.UNK_WORD)
-                g.add_node(cid, x=word, y=int(child.label()), mask=(word!=self.UNK_WORD))
+                g.add_node(cid, x=word, y=int(child.label()), mask=1)
             else:
                 g.add_node(cid, x=SST.PAD_WORD, y=int(child.label()), mask=0)
             _rec_build(cid, child)
......
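The mask change above means an UNK leaf word is no longer zeroed out like an internal padded node; only non-leaf nodes (mask=0) lose their word input. A toy illustration of how the mask gates the embedding, mirroring the model's batch.wordid * batch.mask (a sketch with made-up sizes):

```python
# Sketch (not part of the commit): leaves keep their embedding, internal nodes are zeroed.
import torch as th
import torch.nn as nn

emb = nn.Embedding(10, 4)          # toy vocab of 10, dim 4
wordid = th.tensor([3, 7, 0])      # two leaves, one internal PAD node
mask = th.tensor([1, 1, 0])        # leaves stay 1 even if the word is UNK
embeds = emb(wordid * mask) * mask.float().unsqueeze(-1)
# row 2 (the internal node) is all zeros; rows 0-1 carry real embeddings
```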
@@ -165,8 +165,9 @@ import torch.nn as nn
 class TreeLSTMCell(nn.Module):
     def __init__(self, x_size, h_size):
         super(TreeLSTMCell, self).__init__()
-        self.W_iou = nn.Linear(x_size, 3 * h_size)
-        self.U_iou = nn.Linear(2 * h_size, 3 * h_size)
+        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
+        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
+        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
         self.U_f = nn.Linear(2 * h_size, 2 * h_size)

     def message_func(self, edges):
@@ -174,20 +175,20 @@ class TreeLSTMCell(nn.Module):
     def reduce_func(self, nodes):
         # concatenate h_jl for equation (1), (2), (3), (4)
         h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
         # equation (2)
         f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
         # second term of equation (5)
         c = th.sum(f * nodes.mailbox['c'], 1)
         return {'iou': self.U_iou(h_cat), 'c': c}

     def apply_node_func(self, nodes):
         # equation (1), (3), (4)
-        iou = nodes.data['iou']
+        iou = nodes.data['iou'] + self.b_iou
         i, o, u = th.chunk(iou, 3, 1)
         i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
         # equation (5)
         c = i * u + nodes.data['c']
         # equation (6)
         h = o * th.tanh(c)
         return {'h' : h, 'c' : c}
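For reference, the equations these comments number, in the form given for the N-ary (here binary, N = 2) Tree-LSTM in Tai et al.; j is the node and h_jl, c_jl are its children's states. Note the code applies the W-terms only at leaves and the U-terms only at interior nodes, since SST attaches words only to leaves:

```latex
% Binary Tree-LSTM cell, with the tutorial's equation numbering.
\begin{aligned}
i_j &= \sigma\big(W^{(i)} x_j + \textstyle\sum_{l=1}^{2} U^{(i)}_l h_{jl} + b^{(i)}\big) && (1)\\
f_{jk} &= \sigma\big(W^{(f)} x_j + \textstyle\sum_{l=1}^{2} U^{(f)}_{kl} h_{jl} + b^{(f)}\big) && (2)\\
o_j &= \sigma\big(W^{(o)} x_j + \textstyle\sum_{l=1}^{2} U^{(o)}_l h_{jl} + b^{(o)}\big) && (3)\\
u_j &= \tanh\big(W^{(u)} x_j + \textstyle\sum_{l=1}^{2} U^{(u)}_l h_{jl} + b^{(u)}\big) && (4)\\
c_j &= i_j \odot u_j + \textstyle\sum_{l=1}^{2} f_{jl} \odot c_{jl} && (5)\\
h_j &= o_j \odot \tanh(c_j) && (6)
\end{aligned}
```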
@@ -292,7 +293,7 @@ class TreeLSTM(nn.Module):
         g.register_apply_node_func(self.cell.apply_node_func)
         # feed embedding
         embeds = self.embedding(batch.wordid * batch.mask)
-        g.ndata['iou'] = self.cell.W_iou(embeds) * batch.mask.float().unsqueeze(-1)
+        g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
         g.ndata['h'] = h
         g.ndata['c'] = c
         # propagate
@@ -349,7 +350,7 @@ for epoch in range(epochs):
         c = th.zeros((n, h_size))
         logits = model(batch, h, c)
         logp = F.log_softmax(logits, 1)
-        loss = F.nll_loss(logp, batch.label, reduction='elementwise_mean')
+        loss = F.nll_loss(logp, batch.label, reduction='sum')
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
......