Commit a9ffb59e authored by Zihao Ye, committed by Minjie Wang

[DOC][Model] Tree-LSTM example and tutorial (#146)

parent 0452cc3c
# Tree-LSTM
This is a re-implementation of the following paper:
> [**Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks**](http://arxiv.org/abs/1503.00075)
> *Kai Sheng Tai, Richard Socher, and Christopher Manning*.

The provided implementation achieves a test accuracy of 50.59, which is comparable to the 51.0 reported in the paper.
## Data
The script will download the [SST dataset](http://nlp.stanford.edu/sentiment/index.html) automatically; you need to download the GloVe word vectors yourself (the data loader looks for `glove.840B.300d.txt` in the working directory). From the command line:
```
wget http://nlp.stanford.edu/data/glove.840B.300d.zip
unzip glove.840B.300d.zip
```
## Usage
```
python train.py --gpu 0
```
## Speed Test
To enable fair comparison with [DyNet Tree-LSTM implementation](https://github.com/clab/dynet/tree/master/examples/treelstm), we set the batch size to 100.
```
python train.py --gpu 0 --batch-size 100
```
| Device              | Framework | Speed (time per batch) |
|---------------------|-----------|------------------------|
| GeForce GTX TITAN X | DGL       | 7.23 (±0.66) s         |
@@ -3,102 +3,152 @@ import time
import numpy as np
import torch as th
import torch.nn.functional as F
import torch.nn.init as INIT
import torch.optim as optim
from torch.utils.data import DataLoader
import dgl
import dgl.data as data
import dgl.ndarray as nd
from tree_lstm import TreeLSTM
def main(args):
np.random.seed(args.seed)
th.manual_seed(args.seed)
th.cuda.manual_seed(args.seed)
cuda = args.gpu >= 0
device = th.device('cuda:{}'.format(args.gpu)) if cuda else th.device('cpu')
if cuda:
th.cuda.set_device(args.gpu)
trainset = data.SST()
train_loader = DataLoader(dataset=trainset,
batch_size=args.batch_size,
collate_fn=data.SST.batcher(device),
shuffle=True,
num_workers=0)
devset = data.SST(mode='dev')
dev_loader = DataLoader(dataset=devset,
batch_size=100,
collate_fn=data.SST.batcher(device),
shuffle=False,
num_workers=0)
testset = data.SST(mode='test')
test_loader = DataLoader(dataset=testset,
batch_size=100,
collate_fn=data.SST.batcher(device),
shuffle=False,
num_workers=0)
model = TreeLSTM(trainset.num_vocabs,
args.x_size,
args.h_size,
trainset.num_classes,
args.dropout,
pretrained_emb = trainset.pretrained_emb).to(device)
print(model)
params_ex_emb = [x for x in list(model.parameters()) if x.requires_grad and x.size(0) != trainset.num_vocabs]
params_emb = list(model.embedding.parameters())
optimizer = optim.Adagrad([
{'params':params_ex_emb, 'lr':args.lr, 'weight_decay':args.weight_decay},
{'params':params_emb, 'lr':0.1*args.lr}])
dur = []
for epoch in range(args.epochs):
t_epoch = time.time()
model.train()
for step, batch in enumerate(train_loader):
g = batch.graph
n = g.number_of_nodes()
h = th.zeros((n, args.h_size)).to(device)
c = th.zeros((n, args.h_size)).to(device)
t0 = time.time() # tik
logits = model(batch, h, c)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp, batch.label, reduction='elementwise_mean')
optimizer.zero_grad()
loss.backward()
optimizer.step()
dur.append(time.time() - t0) # tok
if step > 0 and step % args.log_every == 0:
pred = th.argmax(logits, 1)
acc = th.sum(th.eq(batch.label, pred))
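# a node with no outgoing edge is a tree root (edges point from child to parent)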
root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i)==0]
root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] == pred.cpu().data.numpy()[root_ids])
print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} | Root Acc {:.4f} | Time(s) {:.4f}".format(
epoch, step, loss.item(), 1.0*acc.item()/len(batch.label), 1.0*root_acc/len(root_ids), np.mean(dur)))
print('Epoch {:05d} training time {:.4f}s'.format(epoch, time.time() - t_epoch))
# test on dev set
accs = []
root_accs = []
model.eval()
for step, batch in enumerate(dev_loader):
g = batch.graph
n = g.number_of_nodes()
with th.no_grad():
h = th.zeros((n, args.h_size)).to(device)
c = th.zeros((n, args.h_size)).to(device)
logits = model(batch, h, c)
pred = th.argmax(logits, 1)
acc = th.sum(th.eq(batch.label, pred)).item()
accs.append([acc, len(batch.label)])
root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i)==0]
root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] == pred.cpu().data.numpy()[root_ids])
root_accs.append([root_acc, len(root_ids)])
for param_group in optimizer.param_groups:
param_group['lr'] = max(1e-5, param_group['lr']*0.99) #10
dev_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs])
dev_root_acc = 1.0*np.sum([x[0] for x in root_accs])/np.sum([x[1] for x in root_accs])
print("Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
epoch, dev_acc, dev_root_acc))
accs = []
root_accs = []
model.eval()
for step, batch in enumerate(test_loader):
g = batch.graph
n = g.number_of_nodes()
with th.no_grad():
h = th.zeros((n, args.h_size)).to(device)
c = th.zeros((n, args.h_size)).to(device)
logits = model(batch, h, c)
pred = th.argmax(logits, 1)
acc = th.sum(th.eq(batch.label, pred)).item()
accs.append([acc, len(batch.label)])
root_ids = [i for i in range(batch.graph.number_of_nodes()) if batch.graph.out_degree(i)==0]
root_acc = np.sum(batch.label.cpu().data.numpy()[root_ids] == pred.cpu().data.numpy()[root_ids])
root_accs.append([root_acc, len(root_ids)])
#lr decay
for param_group in optimizer.param_groups:
param_group['lr'] = max(1e-5, param_group['lr']*0.99) #10
test_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs])
test_root_acc = 1.0*np.sum([x[0] for x in root_accs])/np.sum([x[1] for x in root_accs])
print("Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
epoch, test_acc, test_root_acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--seed', type=int, default=12110)
parser.add_argument('--batch-size', type=int, default=25)
parser.add_argument('--x-size', type=int, default=300)
parser.add_argument('--h-size', type=int, default=150)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--log-every', type=int, default=5)
parser.add_argument('--lr', type=float, default=0.05)
parser.add_argument('--n-ary', type=int, default=2)
parser.add_argument('--weight-decay', type=float, default=1e-4)
parser.add_argument('--dropout', type=float, default=0.3)
args = parser.parse_args()
print(args)
main(args)
@@ -9,41 +9,29 @@ import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl
class TreeLSTMCell(nn.Module):
def __init__(self, x_size, h_size):
super(TreeLSTMCell, self).__init__()
self.W_iou = nn.Linear(x_size, 3 * h_size)
self.U_iou = nn.Linear(2 * h_size, 3 * h_size)
self.U_f = nn.Linear(2 * h_size, 2 * h_size)
def message_func(self, edges):
return {'h': edges.src['h'], 'c': edges.src['c']}
def reduce_func(self, nodes):
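# concatenate the two children's hidden states along the feature axis;
# SST constituency trees are binarized, so every internal node has exactly
# two children, matching the 2 * h_size input of U_iou and U_f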
h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
c = th.sum(f * nodes.mailbox['c'], 1)
return {'iou': self.U_iou(h_cat), 'c': c}
def apply_node_func(self, nodes):
iou = nodes.data['iou']
i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
c = i * u + nodes.data['c']
h = o * th.tanh(c)
return {'h' : h, 'c' : c}
@@ -54,105 +42,47 @@ class TreeLSTM(nn.Module):
h_size,
num_classes,
dropout,
pretrained_emb=None):
super(TreeLSTM, self).__init__()
self.x_size = x_size
self.h_size = h_size
self.embedding = nn.Embedding(num_vocabs, x_size)
if pretrained_emb is not None:
print('Using glove')
self.embedding.weight.data.copy_(pretrained_emb)
self.embedding.weight.requires_grad = True
self.dropout = nn.Dropout(dropout)
self.linear = nn.Linear(h_size, num_classes)
self.cell = TreeLSTMCell(x_size, h_size)
def forward(self, batch, h, c):
"""Compute tree-lstm prediction given a batch.
Parameters
----------
batch : dgl.data.SSTBatch
The data batch.
h : Tensor
Initial hidden state.
c : Tensor
Initial cell state.
Returns
-------
logits : Tensor
The prediction of each node.
"""
g = batch.graph
g.register_message_func(self.cell.message_func)
g.register_reduce_func(self.cell.reduce_func)
g.register_apply_node_func(self.cell.apply_node_func)
# feed embedding
embeds = self.embedding(batch.wordid * batch.mask)
g.ndata['iou'] = self.cell.W_iou(embeds) * batch.mask.float().unsqueeze(-1)
g.ndata['h'] = h
g.ndata['c'] = c
# propagate
dgl.prop_nodes_topo(g)
# compute logits
h = self.dropout(g.ndata.pop('h'))
logits = self.linear(h)
return logits
@@ -5,19 +5,23 @@ Including:
"""
from __future__ import absolute_import
from collections import namedtuple, OrderedDict
from nltk.tree import Tree
from nltk.corpus.reader import BracketParseCorpusReader
import networkx as nx
import numpy as np
import os
import dgl
import dgl.backend as F
from dgl.data.utils import download, extract_archive, get_download_dir
_urls = {
'sst' : 'https://www.dropbox.com/s/6qa8rm43r2nmbyw/sst.zip?dl=1',
}
SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
class SST(object):
"""Stanford Sentiment Treebank dataset.
@@ -42,10 +46,13 @@ class SST(object):
Optional vocabulary file.
"""
PAD_WORD=-1 # special pad word id
UNK_WORD=-1 # out-of-vocabulary word id
def __init__(self, mode='train', vocab_file=None):
self.mode = mode
self.dir = get_download_dir()
self.zip_file_path='{}/sst.zip'.format(self.dir)
self.pretrained_file = 'glove.840B.300d.txt' if mode == 'train' else ''
self.pretrained_emb = None
self.vocab_file = '{}/sst/vocab.txt'.format(self.dir) if vocab_file is None else vocab_file
download(_urls['sst'], path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/sst'.format(self.dir))
@@ -56,37 +63,60 @@ class SST(object):
print('Dataset creation finished. #Trees:', len(self.trees))
def _load(self):
# load vocab file
self.vocab = OrderedDict()
with open(self.vocab_file, encoding='utf-8') as vf:
for line in vf.readlines():
line = line.strip()
self.vocab[line] = len(self.vocab)
# filter glove
if self.pretrained_file != '' and os.path.exists(self.pretrained_file):
glove_emb = {}
with open(self.pretrained_file, 'r', encoding='utf-8') as pf:
for line in pf.readlines():
sp = line.split(' ')
if sp[0].lower() in self.vocab:
glove_emb[sp[0].lower()] = np.array([float(x) for x in sp[1:]])
files = ['{}.txt'.format(self.mode)]
corpus = BracketParseCorpusReader('{}/sst'.format(self.dir), files)
sents = corpus.parsed_sents(files[0])
#initialize with glove
pretrained_emb = []
fail_cnt = 0
for line in self.vocab.keys():
if self.pretrained_file != '' and os.path.exists(self.pretrained_file):
if not line.lower() in glove_emb:
fail_cnt += 1
pretrained_emb.append(glove_emb.get(line.lower(), np.random.uniform(-0.05, 0.05, 300)))
if self.pretrained_file != '' and os.path.exists(self.pretrained_file):
self.pretrained_emb = F.tensor(np.stack(pretrained_emb, 0))
print('Miss word in GloVe {0:.4f}'.format(1.0*fail_cnt/len(self.pretrained_emb)))
# build trees
for sent in sents:
self.trees.append(self._build_tree(sent))
def _build_tree(self, root):
g = nx.DiGraph()
def _rec_build(nid, node):
for child in node:
cid = g.number_of_nodes()
if isinstance(child[0], str) or isinstance(child[0], bytes):
# leaf node
word = self.vocab.get(child[0].lower(), self.UNK_WORD)
g.add_node(cid, x=word, y=int(child.label()), mask=(word!=self.UNK_WORD))
else:
g.add_node(cid, x=SST.PAD_WORD, y=int(child.label()), mask=0)
_rec_build(cid, child)
g.add_edge(cid, nid)
# add root
g.add_node(0, x=SST.PAD_WORD, y=int(root.label()), mask=0)
_rec_build(0, root)
ret = dgl.DGLGraph()
ret.from_networkx(g, node_attrs=['x', 'y', 'mask'])
return ret
def __getitem__(self, idx):
@@ -98,3 +128,13 @@ class SST(object):
@property
def num_vocabs(self):
return len(self.vocab)
@staticmethod
def batcher(device):
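# return a collate_fn for DataLoader: batch a list of trees into a single
# graph and move the node features to the given device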
def batcher_dev(batch):
batch_trees = dgl.batch(batch)
return SSTBatch(graph=batch_trees,
mask=batch_trees.ndata['mask'].to(device),
wordid=batch_trees.ndata['x'].to(device),
label=batch_trees.ndata['y'].to(device))
return batcher_dev
"""
.. _model-tree-lstm:
Tree LSTM DGL Tutorial
=========================
**Author**: `Zihao Ye`, `Qipeng Guo`, `Minjie Wang`, `Zheng Zhang`
"""
##############################################################################
#
# The Tree-LSTM structure was first introduced by Tai et al. in their ACL 2015
# paper: `Improved Semantic Representations From Tree-Structured Long
# Short-Term Memory Networks <https://arxiv.org/pdf/1503.00075.pdf>`__. It
# introduces syntactic information into the network by extending the
# chain-structured LSTM to a tree-structured LSTM, using a dependency tree or
# constituency tree as the latent tree structure.
#
# The difficulty of training Tree-LSTMs is that trees have different shapes,
# which makes batching hard to parallelize. DGL offers a neat alternative:
# the key idea is to pool all the trees into one graph and then induce
# message passing over it.
#
# The task and the dataset
# ------------------------
#
# We will use Tree-LSTM for a sentiment analysis task. We have wrapped the
# `Stanford Sentiment Treebank <https://nlp.stanford.edu/sentiment/>`__ in
# ``dgl.data``. The dataset provides fine-grained, tree-level sentiment
# annotation: 5 classes (very negative, negative, neutral, positive, and
# very positive) indicating the sentiment of the current subtree. Non-leaf
# nodes in the constituency tree do not contain words, so we use a special
# ``PAD_WORD`` token to denote them; during training/inference, their
# embeddings are masked to all-zero.
#
# .. figure:: https://i.loli.net/2018/11/08/5be3d4bfe031b.png
# :alt:
#
# The figure displays one sample of the SST dataset, a constituency parse
# tree whose nodes are labeled with sentiment. To speed things up, let's
# build a tiny set with 5 sentences and take a look at the first one:
#
import dgl
import dgl.data as data
# Each sample in the dataset is a constituency tree. The leaf nodes
# represent words. Each word is an int value stored in the "x" field.
# The non-leaf nodes have a special word PAD_WORD. The sentiment
# label is stored in the "y" feature field.
trainset = data.SST(mode='tiny') # the "tiny" set has only 5 trees
tiny_sst = trainset.trees
num_vocabs = trainset.num_vocabs
num_classes = trainset.num_classes
vocab = trainset.vocab # vocabulary dict: key -> id
inv_vocab = {v: k for k, v in vocab.items()} # inverted vocabulary dict: id -> word
a_tree = tiny_sst[0]
for token in a_tree.ndata['x'].tolist():
if token != trainset.PAD_WORD:
print(inv_vocab[token], end=" ")
##############################################################################
# Step 1: batching
# ----------------
#
# The first step is to throw all the trees into one graph, using
# the :func:`~dgl.batched_graph.batch` API.
#
import networkx as nx
import matplotlib.pyplot as plt
graph = dgl.batch(tiny_sst)
def plot_tree(g):
# this plot requires pygraphviz package
pos = nx.nx_agraph.graphviz_layout(g, prog='dot')
nx.draw(g, pos, with_labels=False, node_size=10,
node_color=[[.5, .5, .5]], arrowsize=4)
plt.show()
plot_tree(graph.to_networkx())
##############################################################################
# You can read more about the definition of :func:`~dgl.batched_graph.batch`
# (by clicking the API), or skip ahead to the next step:
#
# .. note::
#
# **Definition**: a :class:`~dgl.batched_graph.BatchedDGLGraph` is a
# :class:`~dgl.DGLGraph` that unions a list of :class:`~dgl.DGLGraph`\ s.
#
# - The union includes all the nodes,
# edges, and their features. The order of nodes, edges and features is
# preserved.
#
# - Given that we have :math:`V_i` nodes for graph
# :math:`\mathcal{G}_i`, the node ID :math:`j` in graph
# :math:`\mathcal{G}_i` corresponds to node ID
# :math:`j + \sum_{k=1}^{i-1} V_k` in the batched graph
# (a short example follows this note).
#
# - Therefore, performing feature transformation and message passing on
# ``BatchedDGLGraph`` is equivalent to doing those
# on all ``DGLGraph`` constituents in parallel.
#
# - Duplicate references to the same graph are
# treated as deep copies; the nodes, edges, and features are duplicated,
# and mutation on one reference does not affect the other.
# - Currently, ``BatchedDGLGraph`` is immutable in
# graph structure (i.e. one can't add
# nodes and edges to it). We need to support mutable batched graphs in
# the (far) future.
# - The ``BatchedDGLGraph`` keeps track of the meta
# information of the constituents so it can be
# :func:`~dgl.batched_graph.unbatch`\ ed to a list of ``DGLGraph``\ s.
#
# For more details about the :class:`~dgl.batched_graph.BatchedDGLGraph`
# module in DGL, you can click the class name.
#
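# To make the node ID offset rule concrete, here is a minimal, self-contained
# sketch (two tiny hand-built graphs rather than SST trees, purely for
# illustration): the second graph's node IDs are shifted by the size of the
# first.

g1 = dgl.DGLGraph()
g1.add_nodes(3)
g1.add_edges([0, 1], [2, 2])       # two leaves pointing to their parent
g2 = dgl.DGLGraph()
g2.add_nodes(2)
g2.add_edges([0], [1])
bg12 = dgl.batch([g1, g2])
print(bg12.number_of_nodes())      # 5 nodes: IDs 0-2 come from g1, IDs 3-4 from g2

##############################################################################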
# Step 2: Tree-LSTM Cell with message-passing APIs
# ------------------------------------------------
#
# .. note::
# The paper proposes two types of Tree-LSTMs: Child-Sum
# Tree-LSTMs and :math:`N`-ary Tree-LSTMs. In this tutorial we focus on
# the latter. We use PyTorch as our backend framework to set up the
# network.
#
# In Tree-LSTM, each unit at node :math:`j` maintains a hidden
# representation :math:`h_j` and a memory cell :math:`c_j`. The unit
# :math:`j` takes the input vector :math:`x_j` and the hidden
# representations of its child units, :math:`h_k, k\in C(j)`, as
# input, then computes its new hidden representation :math:`h_j` and memory
# cell :math:`c_j` as follows:
#
# .. math::
#
# i_j = \sigma\left(W^{(i)}x_j + \sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\right), \\
# f_{jk} = \sigma\left(W^{(f)}x_j + \sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \right), \\
# o_j = \sigma\left(W^{(o)}x_j + \sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \right), \\
# u_j = \textrm{tanh}\left(W^{(u)}x_j + \sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \right) , \\
# c_j = i_j \odot u_j + \sum_{l=1}^{N} f_{jl} \odot c_{jl}, \\
# h_j = o_j \odot \textrm{tanh}(c_j)
#
# The process can be decomposed into three phases: ``message_func``,
# ``reduce_func`` and ``apply_node_func``.
#
# ``apply_node_func`` is a new node UDF we have not introduced before. In
# ``apply_node_func``, the user specifies what to do with node features,
# without considering edge features and messages. In the Tree-LSTM case,
# ``apply_node_func`` is a must, since there exist (leaf) nodes with
# :math:`0` incoming edges, which would not be updated via
# ``reduce_func``.
#
import torch as th
import torch.nn as nn
class TreeLSTMCell(nn.Module):
def __init__(self, x_size, h_size):
super(TreeLSTMCell, self).__init__()
self.W_iou = nn.Linear(x_size, 3 * h_size)
self.U_iou = nn.Linear(2 * h_size, 3 * h_size)
self.U_f = nn.Linear(2 * h_size, 2 * h_size)
def message_func(self, edges):
return {'h': edges.src['h'], 'c': edges.src['c']}
def reduce_func(self, nodes):
h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
c = th.sum(f * nodes.mailbox['c'], 1)
return {'iou': self.U_iou(h_cat), 'c': c}
def apply_node_func(self, nodes):
iou = nodes.data['iou']
i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
c = i * u + nodes.data['c']
h = o * th.tanh(c)
return {'h' : h, 'c' : c}
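
##############################################################################
# .. note::
#
#    The cell above relies on the SST constituency trees being binarized:
#    every internal node has exactly two children, so the children's hidden
#    states can be concatenated into a single ``2 * h_size`` vector before
#    the ``U_f`` and ``U_iou`` projections.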
##############################################################################
# Step 3: define traversal
# ------------------------
#
# After defining the message passing functions, we need to induce the
# right order to trigger them. This is a significant departure from models
# such as GCN, where all nodes pull messages from upstream ones
# *simultaneously*.
#
# In the case of Tree-LSTM, messages start from the leaves of the tree and
# are propagated/processed upwards until they reach the roots. A
# visualization is as follows:
#
# .. figure:: https://i.loli.net/2018/11/09/5be4b5d2df54d.gif
# :alt:
#
# DGL defines a generator to perform the topological sort; each item is a
# tensor recording the nodes from the bottom level to the roots. One can
# appreciate the degree of parallelism by comparing the outputs of the
# following two calls:
#
print('Traversing one tree:')
print(dgl.topological_nodes_generator(a_tree))
print('Traversing many trees at the same time:')
print(dgl.topological_nodes_generator(graph))
##############################################################################
# We then call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:
#
# .. note::
#
# Before we call :meth:`~dgl.DGLGraph.prop_nodes`, we must specify a
# `message_func` and `reduce_func` in advance; here we use the built-in
# copy-from-source and sum functions as the message function and reduce
# function for demonstration.
import dgl.function as fn
import torch as th
graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)
graph.register_message_func(fn.copy_src('a', 'a'))
graph.register_reduce_func(fn.sum('a', 'a'))
traversal_order = dgl.topological_nodes_generator(graph)
graph.prop_nodes(traversal_order)
# the following is syntactic sugar that does the same
# dgl.prop_nodes_topo(graph)
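# (Optional sanity check: the built-in sum reducer overwrites 'a' with the sum
# of the messages coming from a node's children, so after propagation every
# leaf should still hold 1 while each internal node should hold the number of
# leaves in its subtree; print(graph.ndata['a']) to verify.)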
##############################################################################
# Putting it together
# -------------------
#
# Here is the complete code that specifies the ``Tree-LSTM`` class:
#
class TreeLSTM(nn.Module):
def __init__(self,
num_vocabs,
x_size,
h_size,
num_classes,
dropout,
pretrained_emb=None):
super(TreeLSTM, self).__init__()
self.x_size = x_size
self.embedding = nn.Embedding(num_vocabs, x_size)
if pretrained_emb is not None:
print('Using glove')
self.embedding.weight.data.copy_(pretrained_emb)
self.embedding.weight.requires_grad = True
self.dropout = nn.Dropout(dropout)
self.linear = nn.Linear(h_size, num_classes)
self.cell = TreeLSTMCell(x_size, h_size)
def forward(self, batch, h, c):
"""Compute tree-lstm prediction given a batch.
Parameters
----------
batch : dgl.data.SSTBatch
The data batch.
h : Tensor
Initial hidden state.
c : Tensor
Initial cell state.
Returns
-------
logits : Tensor
The prediction of each node.
"""
g = batch.graph
g.register_message_func(self.cell.message_func)
g.register_reduce_func(self.cell.reduce_func)
g.register_apply_node_func(self.cell.apply_node_func)
# feed embedding
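# non-leaf nodes carry PAD_WORD (-1) with mask 0, so batch.wordid * batch.mask
# maps them to index 0 before the embedding lookup; the mask below then
# zeroes out their W_iou contribution entirely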
embeds = self.embedding(batch.wordid * batch.mask)
g.ndata['iou'] = self.cell.W_iou(embeds) * batch.mask.float().unsqueeze(-1)
g.ndata['h'] = h
g.ndata['c'] = c
# propagate
dgl.prop_nodes_topo(g)
# compute logits
h = self.dropout(g.ndata.pop('h'))
logits = self.linear(h)
return logits
##############################################################################
# Main Loop
# ---------
#
# Finally, we can write the training loop in PyTorch:
#
from torch.utils.data import DataLoader
import torch.nn.functional as F
device = th.device('cpu')
# hyperparameters
x_size = 256
h_size = 256
dropout = 0.5
lr = 0.05
weight_decay = 1e-4
epochs = 10
# create the model
model = TreeLSTM(trainset.num_vocabs,
x_size,
h_size,
trainset.num_classes,
dropout)
print(model)
# create the optimizer
optimizer = th.optim.Adagrad(model.parameters(),
lr=lr,
weight_decay=weight_decay)
train_loader = DataLoader(dataset=tiny_sst,
batch_size=5,
collate_fn=data.SST.batcher(device),
shuffle=False,
num_workers=0)
# training loop
for epoch in range(epochs):
for step, batch in enumerate(train_loader):
g = batch.graph
n = g.number_of_nodes()
h = th.zeros((n, h_size))
c = th.zeros((n, h_size))
logits = model(batch, h, c)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp, batch.label, reduction='elementwise_mean')
optimizer.zero_grad()
loss.backward()
optimizer.step()
pred = th.argmax(logits, 1)
acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)
print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format(
epoch, step, loss.item(), acc))
##############################################################################
# To train the model on the full dataset with different settings (CPU/GPU,
# etc.), please refer to our repo's
# `example <https://github.com/jermainewang/dgl/tree/master/examples/pytorch/tree_lstm>`__.
@@ -4,3 +4,4 @@ numpy
seaborn
matplotlib
pygraphviz
graphviz