Commit 53b9a4bd authored by Zihao Ye's avatar Zihao Ye Committed by Minjie Wang

[DOC][Model] Update Tree-LSTM (#152)

* update tree lstm

* tree_lstm (new interface)

* simplify pop

* merge qipeng(root)

* upd tree-lstm & tutorial

* upd model

* new capsule tutorial

* capsule for new API

* fix deprecated API

* New tutorial and example

* investigate gc problem

* add viz code

* new capsule tutorial

* remove ipynb

* move u_hat

* add link

* add requirements.txt

* remove ani.save

* update ci to install requirements

* utf-8

* change seed

* graphviz requirement

* accelerate

* little format

* update some markup
parent 2389df81
......@@ -47,6 +47,7 @@ def main(args):
args.h_size,
trainset.num_classes,
args.dropout,
cell_type='childsum' if args.child_sum else 'nary',
pretrained_emb = trainset.pretrained_emb).to(device)
print(model)
params_ex_emb = [x for x in list(model.parameters()) if x.requires_grad and x.size(0) != trainset.num_vocabs]
......@@ -55,6 +56,7 @@ def main(args):
optimizer = optim.Adagrad([
{'params':params_ex_emb, 'lr':args.lr, 'weight_decay':args.weight_decay},
{'params':params_emb, 'lr':0.1*args.lr}])
dur = []
for epoch in range(args.epochs):
t_epoch = time.time()
......@@ -106,6 +108,7 @@ def main(args):
root_accs.append([root_acc, len(root_ids)])
for param_group in optimizer.param_groups:
param_group['lr'] = max(1e-5, param_group['lr']*0.99)  # decay lr by 1% per epoch, floored at 1e-5
dev_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs])
dev_root_acc = 1.0*np.sum([x[0] for x in root_accs])/np.sum([x[1] for x in root_accs])
print("Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
......@@ -132,6 +135,7 @@ def main(args):
#lr decay
for param_group in optimizer.param_groups:
param_group['lr'] = max(1e-5, param_group['lr']*0.99)  # decay lr by 1% per epoch, floored at 1e-5
test_acc = 1.0*np.sum([x[0] for x in accs])/np.sum([x[1] for x in accs])
test_root_acc = 1.0*np.sum([x[0] for x in root_accs])/np.sum([x[1] for x in root_accs])
print("Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
......@@ -142,6 +146,7 @@ if __name__ == '__main__':
parser.add_argument('--gpu', type=int, default=-1)
parser.add_argument('--seed', type=int, default=12110)
parser.add_argument('--batch-size', type=int, default=25)
parser.add_argument('--child-sum', action='store_true')
parser.add_argument('--x-size', type=int, default=300)
parser.add_argument('--h-size', type=int, default=150)
parser.add_argument('--epochs', type=int, default=100)
......
......@@ -35,6 +35,30 @@ class TreeLSTMCell(nn.Module):
h = o * th.tanh(c)
return {'h' : h, 'c' : c}

class ChildSumTreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(ChildSumTreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size)
        self.U_iou = nn.Linear(h_size, 3 * h_size)
        self.U_f = nn.Linear(h_size, h_size)

    def message_func(self, edges):
        # pass each child's hidden state and memory cell to the parent
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        # h_tild: sum of the children's hidden states
        h_tild = th.sum(nodes.mailbox['h'], 1)
        # one forget gate per child, computed from that child's own h
        f = th.sigmoid(self.U_f(nodes.mailbox['h']))
        c = th.sum(f * nodes.mailbox['c'], 1)
        return {'iou': self.U_iou(h_tild), 'c': c}

    def apply_node_func(self, nodes):
        iou = nodes.data['iou']
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        c = i * u + nodes.data['c']
        h = o * th.tanh(c)
        return {'h': h, 'c': c}

class TreeLSTM(nn.Module):
def __init__(self,
num_vocabs,
......@@ -42,6 +66,7 @@ class TreeLSTM(nn.Module):
h_size,
num_classes,
dropout,
cell_type='nary',
pretrained_emb=None):
super(TreeLSTM, self).__init__()
self.x_size = x_size
......@@ -52,7 +77,8 @@ class TreeLSTM(nn.Module):
self.embedding.weight.requires_grad = True
self.dropout = nn.Dropout(dropout)
self.linear = nn.Linear(h_size, num_classes)
self.cell = TreeLSTMCell(x_size, h_size)
cell = TreeLSTMCell if cell_type == 'nary' else ChildSumTreeLSTMCell
self.cell = cell(x_size, h_size)
def forward(self, batch, h, c):
"""Compute tree-lstm prediction given a batch.
......
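With the new `cell_type` argument, the cell class is picked once at construction time. A minimal usage sketch (the hyper-parameter values and the module name `tree_lstm` are illustrative assumptions, not part of this diff):

    import torch as th
    from tree_lstm import TreeLSTM  # hypothetical module name for the model file

    # cell_type='nary' selects TreeLSTMCell; any other value selects
    # ChildSumTreeLSTMCell, mirroring the one-line dispatch above
    model = TreeLSTM(num_vocabs=10000,
                     x_size=300,
                     h_size=150,
                     num_classes=5,
                     dropout=0.5,
                     cell_type='childsum')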
......@@ -10,22 +10,23 @@ Tree LSTM DGL Tutorial
##############################################################################
#
# Tree-LSTM structure was first introduced by Tai et al. in an ACL 2015
# paper: `Improved Semantic Representations From Tree-Structured Long
# Short-Term Memory Networks <https://arxiv.org/pdf/1503.00075.pdf>`__.
# The core idea is to introduce syntactic information for language tasks by
# extending the chain-structured LSTM to a tree-structured LSTM. Dependency
# trees or constituency trees are leveraged as the latent tree structure.
#
# One major difficulty in training Tree-LSTMs is batching, a standard
# technique in machine learning to accelerate optimization. However, since
# trees generally have different shapes by nature, parallelization becomes
# non-trivial. DGL offers an alternative: pool all the trees into one single
# graph, then induce message passing over them, guided by the structure of
# each tree.
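#
# For instance, a handful of toy trees can be pooled into one graph with
# ``dgl.batch`` (a minimal sketch, assuming the DGL graph-construction API;
# the two trees below are illustrative, not taken from the dataset):

import dgl

tree1 = dgl.DGLGraph()
tree1.add_nodes(3)
tree1.add_edges([0, 1], [2, 2])     # two leaves feeding one root

tree2 = dgl.DGLGraph()
tree2.add_nodes(2)
tree2.add_edges([0], [1])           # a single leaf feeding its root

# the batched graph is one big, disconnected graph, so message passing
# over it updates all the trees in parallel, level by level
toy_graph = dgl.batch([tree1, tree2])
print(toy_graph.number_of_nodes())  # 5

##############################################################################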
#
# The task and the dataset
# ------------------------
#
# In this tutorial, we will use Tree-LSTMs for sentiment analysis.
# We have wrapped the
# `Stanford Sentiment Treebank <https://nlp.stanford.edu/sentiment/>`__ in
# ``dgl.data``. The dataset provides a fine-grained, tree-level sentiment
# annotation: 5 classes (very negative, negative, neutral, positive, and
......@@ -123,31 +124,32 @@ plot_tree(graph.to_networkx())
# Step 2: Tree-LSTM Cell with message-passing APIs
# ------------------------------------------------
#
# .. note::
# The authors proposed two types of Tree-LSTM: Child-Sum
# Tree-LSTMs and :math:`N`-ary Tree-LSTMs. In this tutorial we focus
# on applying a *Binary* Tree-LSTM to binarized constituency trees (this
# application is also known as *Constituency Tree-LSTM*). We use PyTorch
# as our backend framework to set up the network.
#
# In an :math:`N`-ary Tree-LSTM, each unit at node :math:`j` maintains a
# hidden representation :math:`h_j` and a memory cell :math:`c_j`. The unit
# :math:`j` takes the input vector :math:`x_j` and the hidden
# representations of its child units, :math:`h_{jl}, 1\leq l\leq N`, as
# input, then updates its hidden representation :math:`h_j` and memory
# cell :math:`c_j` by:
#
# .. math::
#
#    i_j & = & \sigma\left(W^{(i)}x_j + \sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\right), & (1)\\
#    f_{jk} & = & \sigma\left(W^{(f)}x_j + \sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \right), & (2)\\
#    o_j & = & \sigma\left(W^{(o)}x_j + \sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \right), & (3)\\
#    u_j & = & \textrm{tanh}\left(W^{(u)}x_j + \sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \right), & (4)\\
#    c_j & = & i_j \odot u_j + \sum_{l=1}^{N} f_{jl} \odot c_{jl}, & (5)\\
#    h_j & = & o_j \odot \textrm{tanh}(c_j). & (6)
#
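# As a sanity check, equations (1)-(6) can be transcribed almost verbatim
# in plain PyTorch for a single node with :math:`N=2` children. The random
# tensors and per-gate weight matrices below are illustrative only; the
# cell implemented further down fuses them into batched projections such
# as ``U_iou`` for speed.

import torch as th

x_size, h_size, N = 4, 3, 2
x = th.randn(x_size)                       # input vector x_j
h = [th.randn(h_size) for _ in range(N)]   # children's hidden states h_{jl}
c = [th.randn(h_size) for _ in range(N)]   # children's memory cells c_{jl}

W = {g: th.randn(h_size, x_size) for g in 'ifou'}   # W^{(i)}, W^{(f)}, ...
U = {g: [th.randn(h_size, h_size) for _ in range(N)] for g in 'iou'}
U_f = [[th.randn(h_size, h_size) for _ in range(N)] for _ in range(N)]
b = {g: th.zeros(h_size) for g in 'ifou'}

i = th.sigmoid(W['i'] @ x + sum(U['i'][l] @ h[l] for l in range(N)) + b['i'])  # (1)
f = [th.sigmoid(W['f'] @ x + sum(U_f[k][l] @ h[l] for l in range(N)) + b['f'])
     for k in range(N)]                                                        # (2)
o = th.sigmoid(W['o'] @ x + sum(U['o'][l] @ h[l] for l in range(N)) + b['o'])  # (3)
u = th.tanh(W['u'] @ x + sum(U['u'][l] @ h[l] for l in range(N)) + b['u'])     # (4)
c_j = i * u + sum(f[k] * c[k] for k in range(N))                               # (5)
h_j = o * th.tanh(c_j)                                                         # (6)

##############################################################################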
# It can be decomposed into three phases: ``message_func``,
# ``reduce_func`` and ``apply_node_func``.
#
# .. note::
# ``apply_node_func`` is a new node UDF we have not introduced before. In
# ``apply_node_func``, the user specifies what to do with node features,
# without considering edge features and messages. In the Tree-LSTM case,
......@@ -170,16 +172,22 @@ class TreeLSTMCell(nn.Module):
return {'h': edges.src['h'], 'c': edges.src['c']}
def reduce_func(self, nodes):
# concatenate h_jl for equation (1), (2), (3), (4)
h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
# equation (2)
f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
# second term of equation (5)
c = th.sum(f * nodes.mailbox['c'], 1)
return {'iou': self.U_iou(h_cat), 'c': c}
def apply_node_func(self, nodes):
# equation (1), (3), (4)
iou = nodes.data['iou']
i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
# equation (5)
c = i * u + nodes.data['c']
# equation (6)
h = o * th.tanh(c)
return {'h' : h, 'c' : c}
......@@ -213,13 +221,6 @@ print(dgl.topological_nodes_generator(graph))
##############################################################################
# We then call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:
#
import dgl.function as fn
import torch as th
......@@ -235,6 +236,13 @@ graph.prop_nodes(traversal_order)
# dgl.prop_nodes_topo(graph)
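
# A self-contained sketch of the same pattern (the three-node toy graph
# and the feature name 'h' are illustrative assumptions):
import dgl

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1], [2, 2])        # two leaves feeding one root
g.ndata['h'] = th.ones(3, 1)
# register the built-in copy-from-source and sum functions as the
# default message and reduce functions
g.register_message_func(fn.copy_src(src='h', out='m'))
g.register_reduce_func(fn.sum(msg='m', out='h'))
# propagate level by level; the root ends up with the sum of its children
g.prop_nodes(dgl.topological_nodes_generator(g))
print(g.ndata['h'])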
##############################################################################
# .. note::
#
# Before we call :meth:`~dgl.DGLGraph.prop_nodes`, we must specify a
# ``message_func`` and a ``reduce_func`` in advance. Here we use the
# built-in copy-from-source and sum functions as our message and reduce
# functions for demonstration.
#
# Putting it together
# -------------------
#
......@@ -353,3 +361,4 @@ for epoch in range(epochs):
# To train the model on the full dataset with different settings (CPU/GPU,
# etc.), please refer to our repo's
# `example <https://github.com/jermainewang/dgl/tree/master/examples/pytorch/tree_lstm>`__.
# We also provide an implementation of the Child-Sum Tree-LSTM.