Commit c80dc5e0 authored by Zihao Ye, committed by Lingfan Yu

[model] Tree-LSTM update & hotfix(dropbox link) (#202)

* change the signature of node/edge filter

* upd filter

* Support multi-dimension node feature in SPMV

* stable version

* hotfix

* upd tutorial

* upd README
parent 3a0f86a6
@@ -4,7 +4,7 @@ This is a re-implementation of the following paper:
> [**Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks**](http://arxiv.org/abs/1503.00075)
> *Kai Sheng Tai, Richard Socher, and Christopher Manning*.
-The provided implementation can achieve a test accuracy of 50.59 which is comparable with the result reported in the paper 51.0.
+The provided implementation achieves a test accuracy of 51.72, which is comparable to the 51.0 (±0.5) reported in the original paper.
## Data
The script will download the [SST dataset](http://nlp.stanford.edu/sentiment/index.html) automatically; the GloVe word vectors need to be downloaded manually, for example as sketched below.
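A minimal sketch of that manual step, assuming the 840B.300d GloVe release (the variant used in the Tree-LSTM paper); the URL and target paths are assumptions, not part of this commit:

```
# Hypothetical helper: fetch and unpack the GloVe vectors before training.
# Note: this is a large download (~2 GB).
import os
import urllib.request
import zipfile

url = 'http://nlp.stanford.edu/data/glove.840B.300d.zip'  # assumed variant
os.makedirs('data', exist_ok=True)
urllib.request.urlretrieve(url, 'data/glove.zip')
with zipfile.ZipFile('data/glove.zip') as zf:
    zf.extractall('data')  # yields data/glove.840B.300d.txt
```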
@@ -19,11 +19,5 @@ python train.py --gpu 0
```
## Speed Test
To enable a fair comparison with the [DyNet Tree-LSTM implementation](https://github.com/clab/dynet/tree/master/examples/treelstm), we set the batch size to 100:
```
python train.py --gpu 0 --batch-size 100
```
| Device              | Framework | Speed (time per batch) |
|---------------------|-----------|------------------------|
| GeForce GTX TITAN X | DGL       | 7.23 (±0.66) s         |

See https://docs.google.com/spreadsheets/d/1eCQrVn7g0uWriz63EbEDdes2ksMdKdlbWMyT8PSU4rc for the full comparison.
@@ -17,6 +17,9 @@ def main(args):
th.manual_seed(args.seed)
th.cuda.manual_seed(args.seed)
+best_epoch = -1
+best_dev_acc = 0
cuda = args.gpu >= 0
device = th.device('cuda:{}'.format(args.gpu)) if cuda else th.device('cpu')
if cuda:
@@ -37,10 +40,7 @@ def main(args):
testset = data.SST(mode='test')
test_loader = DataLoader(dataset=testset,
-batch_size=100,
-collate_fn=data.SST.batcher(device),
-shuffle=False,
-num_workers=0)
+batch_size=100, collate_fn=data.SST.batcher(device), shuffle=False, num_workers=0)
model = TreeLSTM(trainset.num_vocabs,
args.x_size,
@@ -53,6 +53,10 @@ def main(args):
params_ex_emb =[x for x in list(model.parameters()) if x.requires_grad and x.size(0)!=trainset.num_vocabs]
params_emb = list(model.embedding.parameters())
+for p in params_ex_emb:
+    if p.dim() > 1:
+        INIT.xavier_uniform_(p)
optimizer = optim.Adagrad([
{'params':params_ex_emb, 'lr':args.lr, 'weight_decay':args.weight_decay},
{'params':params_emb, 'lr':0.1*args.lr}])
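The new initialization loop touches only weight matrices; bias vectors and the embedding table keep their defaults, since the embedding is trained in its own parameter group with a 10x smaller learning rate and no weight decay. A toy sketch of the same rule, not part of the commit:

```
# Xavier-uniform init for weight matrices only; 1-D bias vectors keep
# PyTorch's default initialization, mirroring the `p.dim() > 1` check.
import torch.nn as nn
import torch.nn.init as INIT

layer = nn.Linear(4, 3)
for p in layer.parameters():
    if p.dim() > 1:
        INIT.xavier_uniform_(p)  # reinitializes only the 3x4 weight matrix
```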
@@ -71,7 +75,8 @@ def main(args):
logits = model(batch, h, c)
logp = F.log_softmax(logits, 1)
-loss = F.nll_loss(logp, batch.label, reduction='elementwise_mean')
+loss = F.nll_loss(logp, batch.label, reduction='sum')
optimizer.zero_grad()
loss.backward()
optimizer.step()
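`'elementwise_mean'` is the pre-1.0 spelling of `'mean'`, so this change both modernizes the argument and rescales the loss: with `'sum'`, the gradient grows with the number of labeled nodes in a batch instead of being averaged away, presumably a better fit for Adagrad's accumulated per-parameter step sizes. A quick illustration of the relationship, not from the commit:

```
# The sum-reduced loss equals the mean-reduced loss times the element count.
import torch as th
import torch.nn.functional as F

logp = th.log_softmax(th.randn(8, 5), dim=1)   # toy batch: 8 nodes, 5 classes
labels = th.randint(0, 5, (8,))
loss_sum = F.nll_loss(logp, labels, reduction='sum')
loss_mean = F.nll_loss(logp, labels, reduction='mean')
assert th.allclose(loss_sum, loss_mean * 8)
```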
@@ -84,11 +89,12 @@ def main(args):
acc = th.sum(th.eq(batch.label, pred))
root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i)==0]
root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] == pred.cpu().data.numpy()[root_ids])
print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}".format(
epoch, step, loss.item(), 1.0*acc.item()/len(batch.label), 1.0*root_acc/len(root_ids), np.mean(dur)))
+print('Epoch {:05d} training time {:.4f}s'.format(epoch, time.time() - t_epoch))
-# test on dev set
+# eval on dev set
accs = []
root_accs = []
model.eval()
@@ -106,15 +112,27 @@ def main(args):
root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i)==0]
root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] == pred.cpu().data.numpy()[root_ids])
root_accs.append([root_acc, len(root_ids)])
-for param_group in optimizer.param_groups:
-    param_group['lr'] = max(1e-5, param_group['lr']*0.99) #10
dev_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs])
dev_root_acc = 1.0*np.sum([x[0] for x in root_accs])/np.sum([x[1] for x in root_accs])
print("Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
epoch, dev_acc, dev_root_acc))
+if dev_root_acc > best_dev_acc:
+    best_dev_acc = dev_root_acc
+    best_epoch = epoch
+    th.save(model.state_dict(), 'best_{}.pkl'.format(args.seed))
+else:
+    if best_epoch <= epoch - 10:
+        break
+# lr decay
+for param_group in optimizer.param_groups:
+    param_group['lr'] = max(1e-5, param_group['lr']*0.99)  # decay lr by 1% per epoch, floored at 1e-5
+    print(param_group['lr'])
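Checkpointing keys on dev root accuracy, with early stopping after 10 epochs without improvement; the decay, moved here from the evaluation loops, now runs exactly once per epoch. With the default `--lr 0.05`, the non-embedding learning rate after t epochs is max(1e-5, 0.05 * 0.99^t); a quick check, not from the commit:

```
# After 100 epochs the base lr of 0.05 has decayed to ~0.0183; the 1e-5
# floor would only be reached after roughly 850 epochs.
lr = 0.05
for _ in range(100):
    lr = max(1e-5, lr * 0.99)
print(round(lr, 4))  # 0.0183
```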
+# test
+model.load_state_dict(th.load('best_{}.pkl'.format(args.seed)))
accs = []
root_accs = []
model.eval()
@@ -132,19 +150,17 @@ def main(args):
root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i)==0]
root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] == pred.cpu().data.numpy()[root_ids])
root_accs.append([root_acc, len(root_ids)])
-#lr decay
-for param_group in optimizer.param_groups:
-    param_group['lr'] = max(1e-5, param_group['lr']*0.99) #10
test_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs])
test_root_acc = 1.0*np.sum([x[0] for x in root_accs])/np.sum([x[1] for x in root_accs])
print('------------------------------------------------------------------------------------')
print("Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
-epoch, test_acc, test_root_acc))
+best_epoch, test_acc, test_root_acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=-1)
-parser.add_argument('--seed', type=int, default=12110)
+parser.add_argument('--seed', type=int, default=41)
parser.add_argument('--batch-size', type=int, default=25)
parser.add_argument('--child-sum', action='store_true')
parser.add_argument('--x-size', type=int, default=300)
@@ -153,7 +169,7 @@ if __name__ == '__main__':
parser.add_argument('--log-every', type=int, default=5)
parser.add_argument('--lr', type=float, default=0.05)
parser.add_argument('--weight-decay', type=float, default=1e-4)
-parser.add_argument('--dropout', type=float, default=0.3)
+parser.add_argument('--dropout', type=float, default=0.5)
args = parser.parse_args()
print(args)
main(args)
@@ -14,8 +14,9 @@ import dgl
class TreeLSTMCell(nn.Module):
def __init__(self, x_size, h_size):
super(TreeLSTMCell, self).__init__()
-self.W_iou = nn.Linear(x_size, 3 * h_size)
-self.U_iou = nn.Linear(2 * h_size, 3 * h_size)
+self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
+self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
+self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
self.U_f = nn.Linear(2 * h_size, 2 * h_size)
def message_func(self, edges):
@@ -28,7 +29,7 @@ class TreeLSTMCell(nn.Module):
return {'iou': self.U_iou(h_cat), 'c': c}
def apply_node_func(self, nodes):
-iou = nodes.data['iou']
+iou = nodes.data['iou'] + self.b_iou
i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
c = i * u + nodes.data['c']
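With both projections now bias-free, every node appears to receive exactly one bias term: `b_iou` is added once in `apply_node_func`, whether the node's `'iou'` came from `W_iou` (leaves) or `U_iou` (internal nodes), instead of each linear layer carrying its own bias. A standalone sketch of that path, with toy sizes as assumptions:

```
# Toy reproduction of the i/o/u computation with the factored-out shared bias.
import torch as th
import torch.nn as nn

h_size = 3
U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
b_iou = nn.Parameter(th.zeros(1, 3 * h_size))

h_cat = th.randn(1, 2 * h_size)      # concatenated child hidden states
iou = U_iou(h_cat) + b_iou           # the bias is added exactly once
i, o, u = th.chunk(iou, 3, 1)        # split into input/output/update gates
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
```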
@@ -38,8 +39,9 @@ class TreeLSTMCell(nn.Module):
class ChildSumTreeLSTMCell(nn.Module):
def __init__(self, x_size, h_size):
super(ChildSumTreeLSTMCell, self).__init__()
-self.W_iou = nn.Linear(x_size, 3 * h_size)
-self.U_iou = nn.Linear(h_size, 3 * h_size)
+self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
+self.U_iou = nn.Linear(h_size, 3 * h_size, bias=False)
+self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
self.U_f = nn.Linear(h_size, h_size)
def message_func(self, edges):
@@ -52,7 +54,7 @@ class ChildSumTreeLSTMCell(nn.Module):
return {'iou': self.U_iou(h_tild), 'c': c}
def apply_node_func(self, nodes):
-iou = nodes.data['iou']
+iou = nodes.data['iou'] + self.b_iou
i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
c = i * u + nodes.data['c']
@@ -82,7 +84,6 @@ class TreeLSTM(nn.Module):
def forward(self, batch, h, c):
"""Compute tree-lstm prediction given a batch.
Parameters
----------
batch : dgl.data.SSTBatch
@@ -91,7 +92,6 @@ class TreeLSTM(nn.Module):
Initial hidden state.
c : Tensor
Initial cell state.
Returns
-------
logits : Tensor
@@ -103,7 +103,7 @@ class TreeLSTM(nn.Module):
g.register_apply_node_func(self.cell.apply_node_func)
# feed embedding
embeds = self.embedding(batch.wordid * batch.mask)
-g.ndata['iou'] = self.cell.W_iou(embeds) * batch.mask.float().unsqueeze(-1)
+g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
g.ndata['h'] = h
g.ndata['c'] = c
# propagate
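Dropout is now applied to the word embeddings before the `W_iou` projection as well, so whole embedding features are dropped per node; the mask still zeroes the `'iou'` input of internal (non-leaf) nodes afterwards. A toy sketch of the masked feed, with sizes as assumptions:

```
# Embedding dropout precedes the projection; masking follows it.
import torch as th
import torch.nn as nn

dropout = nn.Dropout(0.5)
W_iou = nn.Linear(300, 3 * 150, bias=False)   # 300-d word vectors, h_size = 150
embeds = th.randn(10, 300)                    # one row per node
mask = th.tensor([1, 0] * 5)                  # 1 = leaf, 0 = internal node
iou = W_iou(dropout(embeds)) * mask.float().unsqueeze(-1)
```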
@@ -17,7 +17,7 @@ import dgl.backend as F
from dgl.data.utils import download, extract_archive, get_download_dir
_urls = {
-'sst' : 'https://www.dropbox.com/s/6qa8rm43r2nmbyw/sst.zip?dl=1',
+'sst': 'https://www.dropbox.com/s/aqejdgrs3jkrmc7/sst.zip?dl=1',
}
SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
@@ -107,7 +107,7 @@ class SST(object):
if isinstance(child[0], str) or isinstance(child[0], bytes):
# leaf node
word = self.vocab.get(child[0].lower(), self.UNK_WORD)
-g.add_node(cid, x=word, y=int(child.label()), mask=(word!=self.UNK_WORD))
+g.add_node(cid, x=word, y=int(child.label()), mask=1)
else:
g.add_node(cid, x=SST.PAD_WORD, y=int(child.label()), mask=0)
_rec_build(cid, child)
@@ -165,8 +165,9 @@ import torch.nn as nn
class TreeLSTMCell(nn.Module):
def __init__(self, x_size, h_size):
super(TreeLSTMCell, self).__init__()
-self.W_iou = nn.Linear(x_size, 3 * h_size)
-self.U_iou = nn.Linear(2 * h_size, 3 * h_size)
+self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
+self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
+self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
self.U_f = nn.Linear(2 * h_size, 2 * h_size)
def message_func(self, edges):
@@ -183,7 +184,7 @@ class TreeLSTMCell(nn.Module):
def apply_node_func(self, nodes):
# equation (1), (3), (4)
-iou = nodes.data['iou']
+iou = nodes.data['iou'] + self.b_iou
i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
# equation (5)
@@ -292,7 +293,7 @@ class TreeLSTM(nn.Module):
g.register_apply_node_func(self.cell.apply_node_func)
# feed embedding
embeds = self.embedding(batch.wordid * batch.mask)
-g.ndata['iou'] = self.cell.W_iou(embeds) * batch.mask.float().unsqueeze(-1)
+g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
g.ndata['h'] = h
g.ndata['c'] = c
# propagate
@@ -349,7 +350,7 @@ for epoch in range(epochs):
c = th.zeros((n, h_size))
logits = model(batch, h, c)
logp = F.log_softmax(logits, 1)
-loss = F.nll_loss(logp, batch.label, reduction='elementwise_mean')
+loss = F.nll_loss(logp, batch.label, reduction='sum')
optimizer.zero_grad()
loss.backward()
optimizer.step()