"graphbolt/vscode:/vscode.git/clone" did not exist on "7eb4de4bed6115e9d87be77586593dd4364b9be2"
Commit 1e50cd2e authored by Sheng Zha, committed by Da Zheng

[Model][MXNet] MXNet Tree LSTM example (#279)

* TreeLSTM MXNet example

* hybridize

* add glove download

* usability

* Update README.md
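README.md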
# Tree-LSTM
This is a re-implementation of the following paper:

> [**Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks**](http://arxiv.org/abs/1503.00075)
> *Kai Sheng Tai, Richard Socher, and Christopher Manning*.

The provided implementation can achieve a test accuracy of 51.72, which is comparable to the result reported in the original paper: 51.0 (±0.5).
## Data
The script will download the [SST dataset](http://nlp.stanford.edu/sentiment/index.html) automatically; the GloVe 840B.300d embedding is downloaded as well if `--use-glove` is specified (note: the download may take a while).
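For reference, the GloVe step boils down to DGL's `download`/`check_sha1` utilities; below is a minimal sketch of the same logic that `prepare_glove` in `train.py` performs (the URL and SHA-1 hashes are copied from the script):
```
from dgl.data.utils import download, check_sha1
import zipfile

# fetch the archive, verify its checksum, then unpack it next to train.py
zip_path = download('http://nlp.stanford.edu/data/glove.840B.300d.zip',
                    sha1_hash='8084fbacc2dee3b1fd1ca4cc534cbfff3519ed0d')
with zipfile.ZipFile(zip_path, 'r') as zf:
    zf.extractall()
assert check_sha1('glove.840B.300d.txt',
                  sha1_hash='294b9f37fa64cce31f9ebb409c266fc379527708')
```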
## Usage
```
python train.py --gpu 0
```
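All flags accepted by the script are defined in the argument parser at the bottom of `train.py`; for example, to train on CPU with the pretrained GloVe embedding and the child-sum cell variant:
```
python train.py --gpu -1 --use-glove --child-sum
```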
## Speed Test
See https://docs.google.com/spreadsheets/d/1eCQrVn7g0uWriz63EbEDdes2ksMdKdlbWMyT8PSU4rc .
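train.py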
import argparse
import time
import warnings
import zipfile
import os
# select the DGL backend (and the MXNet GPU memory pool) before the libraries are imported
os.environ['DGLBACKEND'] = 'mxnet'
os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
import numpy as np
import mxnet as mx
from mxnet import gluon
import dgl
import dgl.data as data
from tree_lstm import TreeLSTM
def batcher(ctx):
    def batcher_dev(batch):
        # merge a list of trees into one batched graph for a single forward pass
        batch_trees = dgl.batch(batch)
        return data.SSTBatch(graph=batch_trees,
                             mask=batch_trees.ndata['mask'].as_in_context(ctx),
                             wordid=batch_trees.ndata['x'].as_in_context(ctx),
                             label=batch_trees.ndata['y'].as_in_context(ctx))
    return batcher_dev
def prepare_glove():
    if not (os.path.exists('glove.840B.300d.txt')
            and data.utils.check_sha1('glove.840B.300d.txt',
                                      sha1_hash='294b9f37fa64cce31f9ebb409c266fc379527708')):
        zip_path = data.utils.download('http://nlp.stanford.edu/data/glove.840B.300d.zip',
                                       sha1_hash='8084fbacc2dee3b1fd1ca4cc534cbfff3519ed0d')
        with zipfile.ZipFile(zip_path, 'r') as zf:
            zf.extractall()
        if not data.utils.check_sha1('glove.840B.300d.txt',
                                     sha1_hash='294b9f37fa64cce31f9ebb409c266fc379527708'):
            warnings.warn('The downloaded glove embedding file checksum mismatch. File content '
                          'may be corrupted.')
def main(args):
    np.random.seed(args.seed)
    mx.random.seed(args.seed)

    best_epoch = -1
    best_dev_acc = 0

    cuda = args.gpu >= 0
    if cuda:
        if args.gpu in mx.test_utils.list_gpus():
            ctx = mx.gpu(args.gpu)
        else:
            print('Requested GPU id {} was not found. Defaulting to CPU implementation'.format(args.gpu))
            ctx = mx.cpu()
    else:
        ctx = mx.cpu()

    if args.use_glove:
        prepare_glove()

    trainset = data.SST()
    train_loader = gluon.data.DataLoader(dataset=trainset,
                                         batch_size=args.batch_size,
                                         batchify_fn=batcher(ctx),
                                         shuffle=True,
                                         num_workers=0)
    devset = data.SST(mode='dev')
    dev_loader = gluon.data.DataLoader(dataset=devset,
                                       batch_size=100,
                                       batchify_fn=batcher(ctx),
                                       shuffle=True,
                                       num_workers=0)
    testset = data.SST(mode='test')
    test_loader = gluon.data.DataLoader(dataset=testset,
                                        batch_size=100,
                                        batchify_fn=batcher(ctx),
                                        shuffle=False,
                                        num_workers=0)

    model = TreeLSTM(trainset.num_vocabs,
                     args.x_size,
                     args.h_size,
                     trainset.num_classes,
                     args.dropout,
                     cell_type='childsum' if args.child_sum else 'nary',
                     pretrained_emb=trainset.pretrained_emb,
                     ctx=ctx)
    print(model)

    # all trainable parameters except the embedding table
    params_ex_emb = [x for x in model.collect_params().values()
                     if x.grad_req != 'null' and x.shape[0] != trainset.num_vocabs]
    # train the (pretrained) embedding with a 10x smaller learning rate
    params_emb = list(model.embedding.collect_params().values())
    for p in params_emb:
        p.lr_mult = 0.1

    model.initialize(mx.init.Xavier(magnitude=1), ctx=ctx)
    model.hybridize()
    # separate trainers so that weight decay is not applied to the embedding
    trainer = gluon.Trainer(model.collect_params('^(?!embedding).*$'), 'adagrad',
                            {'learning_rate': args.lr, 'wd': args.weight_decay})
    trainer_emb = gluon.Trainer(model.collect_params('^embedding.*$'), 'adagrad',
                                {'learning_rate': args.lr})

    dur = []
    L = gluon.loss.SoftmaxCrossEntropyLoss(axis=1)
    for epoch in range(args.epochs):
        t_epoch = time.time()
        for step, batch in enumerate(train_loader):
            g = batch.graph
            n = g.number_of_nodes()

            # TODO begin_states function?
            h = mx.nd.zeros((n, args.h_size), ctx=ctx)
            c = mx.nd.zeros((n, args.h_size), ctx=ctx)
            if step >= 3:
                t0 = time.time()  # tik; the first steps are skipped as warm-up

            with mx.autograd.record():
                pred = model(batch, h, c)
                loss = L(pred, batch.label)
            loss.backward()
            trainer.step(args.batch_size)
            trainer_emb.step(args.batch_size)

            if step >= 3:
                dur.append(time.time() - t0)  # tok

            if step > 0 and step % args.log_every == 0:
                pred = pred.argmax(axis=1)
                acc = (batch.label == pred).sum()
                # root nodes are the only ones without outgoing (child -> parent) edges
                root_ids = [i for i in range(batch.graph.number_of_nodes())
                            if batch.graph.out_degree(i) == 0]
                root_acc = np.sum(batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids])

                print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}".format(
                    epoch, step, loss.sum().asscalar(),
                    1.0 * acc.asscalar() / len(batch.label),
                    1.0 * root_acc / len(root_ids), np.mean(dur)))
        print('Epoch {:05d} training time {:.4f}s'.format(epoch, time.time() - t_epoch))
        # eval on dev set
        accs = []
        root_accs = []
        for step, batch in enumerate(dev_loader):
            g = batch.graph
            n = g.number_of_nodes()
            h = mx.nd.zeros((n, args.h_size), ctx=ctx)
            c = mx.nd.zeros((n, args.h_size), ctx=ctx)
            pred = model(batch, h, c).argmax(axis=1)

            acc = (batch.label == pred).sum().asscalar()
            accs.append([acc, len(batch.label)])
            root_ids = [i for i in range(batch.graph.number_of_nodes())
                        if batch.graph.out_degree(i) == 0]
            root_acc = np.sum(batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids])
            root_accs.append([root_acc, len(root_ids)])

        dev_acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
        dev_root_acc = 1.0 * np.sum([x[0] for x in root_accs]) / np.sum([x[1] for x in root_accs])
        print("Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
            epoch, dev_acc, dev_root_acc))

        # keep the checkpoint with the best dev root accuracy;
        # stop early after 10 epochs without improvement
        if dev_root_acc > best_dev_acc:
            best_dev_acc = dev_root_acc
            best_epoch = epoch
            model.save_parameters('best_{}.params'.format(args.seed))
        elif best_epoch <= epoch - 10:
            break

        # lr decay
        trainer.set_learning_rate(max(1e-5, trainer.learning_rate * 0.99))
        print(trainer.learning_rate)
        trainer_emb.set_learning_rate(max(1e-5, trainer_emb.learning_rate * 0.99))
        print(trainer_emb.learning_rate)
    # test
    model.load_parameters('best_{}.params'.format(args.seed))
    accs = []
    root_accs = []
    for step, batch in enumerate(test_loader):
        g = batch.graph
        n = g.number_of_nodes()
        h = mx.nd.zeros((n, args.h_size), ctx=ctx)
        c = mx.nd.zeros((n, args.h_size), ctx=ctx)
        pred = model(batch, h, c).argmax(axis=1)

        acc = (batch.label == pred).sum().asscalar()
        accs.append([acc, len(batch.label)])
        root_ids = [i for i in range(batch.graph.number_of_nodes())
                    if batch.graph.out_degree(i) == 0]
        root_acc = np.sum(batch.label.asnumpy()[root_ids] == pred.asnumpy()[root_ids])
        root_accs.append([root_acc, len(root_ids)])

    test_acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
    test_root_acc = 1.0 * np.sum([x[0] for x in root_accs]) / np.sum([x[1] for x in root_accs])
    print('------------------------------------------------------------------------------------')
    print("Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
        best_epoch, test_acc, test_root_acc))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--seed', type=int, default=41)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--child-sum', action='store_true')
    parser.add_argument('--x-size', type=int, default=300)
    parser.add_argument('--h-size', type=int, default=150)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--log-every', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.05)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--use-glove', action='store_true')
    args = parser.parse_args()
    print(args)
    main(args)
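tree_lstm.py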
"""
Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
https://arxiv.org/abs/1503.00075
"""
import time
import itertools
import networkx as nx
import numpy as np
import mxnet as mx
from mxnet import gluon
import dgl
class _TreeLSTMCellNodeFunc(gluon.HybridBlock):
    def hybrid_forward(self, F, iou, b_iou, c):
        # add the shared bias, then split into input, output and update gates
        iou = F.broadcast_add(iou, b_iou)
        i, o, u = iou.split(num_outputs=3, axis=1)
        i, o, u = i.sigmoid(), o.sigmoid(), u.tanh()
        # standard LSTM state update
        c = i * u + c
        h = o * c.tanh()
        return h, c
class _TreeLSTMCellReduceFunc(gluon.HybridBlock):
    def __init__(self, U_iou, U_f):
        super(_TreeLSTMCellReduceFunc, self).__init__()
        self.U_iou = U_iou
        self.U_f = U_f

    def hybrid_forward(self, F, h, c):
        # h, c: (num_nodes, num_children, h_size); concatenate the children's states
        h_cat = h.reshape((0, -1))
        # per-child forget gates, then a forget-weighted sum of the child cells
        f = self.U_f(h_cat).sigmoid().reshape_like(h)
        c = (f * c).sum(axis=1)
        iou = self.U_iou(h_cat)
        return iou, c
class _TreeLSTMCell(gluon.HybridBlock):
    def __init__(self, h_size):
        super(_TreeLSTMCell, self).__init__()
        self._apply_node_func = _TreeLSTMCellNodeFunc()
        # bias shared by the i, o, u gates of every node
        self.b_iou = self.params.get('bias', shape=(1, 3 * h_size),
                                     init='zeros')

    def message_func(self, edges):
        # each child sends its hidden and cell state to its parent
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def apply_node_func(self, nodes):
        iou = nodes.data['iou']
        b_iou, c = self.b_iou.data(iou.context), nodes.data['c']
        h, c = self._apply_node_func(iou, b_iou, c)
        return {'h': h, 'c': c}
class TreeLSTMCell(_TreeLSTMCell):
    """N-ary Tree-LSTM cell: one forget gate per child position
    (N = 2 for the binarized SST trees)."""

    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__(h_size)
        self._reduce_func = _TreeLSTMCellReduceFunc(
            gluon.nn.Dense(3 * h_size, use_bias=False),
            gluon.nn.Dense(2 * h_size))
        self.W_iou = gluon.nn.Dense(3 * h_size, use_bias=False)

    def reduce_func(self, nodes):
        h, c = nodes.mailbox['h'], nodes.mailbox['c']
        iou, c = self._reduce_func(h, c)
        return {'iou': iou, 'c': c}
class ChildSumTreeLSTMCell(_TreeLSTMCell):
    def __init__(self, x_size, h_size):
        super(ChildSumTreeLSTMCell, self).__init__(h_size)
        self.W_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
        self.U_iou = gluon.nn.Dense(3 * h_size, use_bias=False)
        self.U_f = gluon.nn.Dense(h_size)

    def reduce_func(self, nodes):
        # sum the children's hidden states, but forget-gate each
        # child's cell state individually
        h_tild = nodes.mailbox['h'].sum(axis=1)
        f = self.U_f(nodes.mailbox['h']).sigmoid()
        c = (f * nodes.mailbox['c']).sum(axis=1)
        return {'iou': self.U_iou(h_tild), 'c': c}
class TreeLSTM(gluon.nn.Block):
    def __init__(self,
                 num_vocabs,
                 x_size,
                 h_size,
                 num_classes,
                 dropout,
                 cell_type='nary',
                 pretrained_emb=None,
                 ctx=None):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.embedding = gluon.nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print('Using glove')
            self.embedding.initialize(ctx=ctx)
            self.embedding.weight.set_data(pretrained_emb)
        self.dropout = gluon.nn.Dropout(dropout)
        self.linear = gluon.nn.Dense(num_classes)
        cell = TreeLSTMCell if cell_type == 'nary' else ChildSumTreeLSTMCell
        self.cell = cell(x_size, h_size)

    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        g = batch.graph
        g.register_message_func(self.cell.message_func)
        g.register_reduce_func(self.cell.reduce_func)
        g.register_apply_node_func(self.cell.apply_node_func)
        # feed embedding; the mask zeroes out non-leaf nodes, which carry no word
        embeds = self.embedding(batch.wordid * batch.mask)
        g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.expand_dims(-1)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # propagate messages from the leaves to the roots in topological order
        dgl.prop_nodes_topo(g)
        # compute logits
        h = self.dropout(g.ndata.pop('h'))
        logits = self.linear(h)
        return logits
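# A minimal smoke test of the block above -- a sketch only: the three-node
# tree, the vocabulary size and the word ids are made up for illustration.
# It assumes the DGL 0.x graph API already used in this file (DGLGraph,
# prop_nodes_topo) and the SSTBatch namedtuple fields used in train.py
# (graph, mask, wordid, label).
if __name__ == '__main__':
    import dgl.data as data

    ctx = mx.cpu()
    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1], [2, 2])  # edges point child -> parent; node 2 is the root
    batch = data.SSTBatch(graph=g,
                          mask=mx.nd.array([1, 1, 0], ctx=ctx),    # only leaves carry words
                          wordid=mx.nd.array([4, 7, 0], ctx=ctx),
                          label=mx.nd.array([0, 1, 2], ctx=ctx))
    model = TreeLSTM(num_vocabs=10, x_size=300, h_size=150,
                     num_classes=5, dropout=0.5)
    model.initialize(ctx=ctx)
    h = mx.nd.zeros((3, 150), ctx=ctx)
    c = mx.nd.zeros((3, 150), ctx=ctx)
    print(model(batch, h, c).shape)  # (3, 5): one row of logits per node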
The commit also updates the corresponding import in the PyTorch Tree-LSTM example:

@@ -9,7 +9,7 @@ import torch.optim as optim
 from torch.utils.data import DataLoader
 import dgl
-from dgl.data.tree import SST
+from dgl.data.tree import SST, SSTBatch
 from tree_lstm import TreeLSTM