Commit 53b9a4bd authored by Zihao Ye, committed by Minjie Wang

[DOC][Model] Update Tree-LSTM (#152)

* update tree lstm

* tree_lstm (new interface)

* simplify pop

* merge qipeng(root)

* upd tree-lstm & tutorial

* upd model

* new capsule tutorial

* capsule for new API

* fix deprecated API

* New tutorial and example

* investigate gc problem

* add viz code

* new capsule tutorial

* remove ipynb

* move u_hat

* add link

* add requirements.txt

* remove ani.save

* update ci to install requirements

* utf-8

* change seed

* graphviz requirement

* accelerate

* little format

* update some markup
parent 2389df81
...@@ -47,6 +47,7 @@ def main(args):
                     args.h_size,
                     trainset.num_classes,
                     args.dropout,
                     cell_type='childsum' if args.child_sum else 'nary',
                     pretrained_emb=trainset.pretrained_emb).to(device)
    print(model)
    params_ex_emb = [x for x in list(model.parameters()) if x.requires_grad and x.size(0) != trainset.num_vocabs]
...@@ -55,6 +56,7 @@ def main(args):
    optimizer = optim.Adagrad([
        {'params': params_ex_emb, 'lr': args.lr, 'weight_decay': args.weight_decay},
        {'params': params_emb, 'lr': 0.1 * args.lr}])
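The two parameter groups above give the fine-tuned embeddings a 10x smaller learning rate than the rest of the model. A minimal standalone sketch of the same pattern (toy modules and values, not the script's):

import torch as th
import torch.optim as optim

emb = th.nn.Embedding(100, 8)   # stand-in for the pretrained embeddings
head = th.nn.Linear(8, 5)       # stand-in for the remaining parameters
lr = 0.05
optimizer = optim.Adagrad([
    {'params': head.parameters(), 'lr': lr, 'weight_decay': 1e-4},
    {'params': emb.parameters(), 'lr': 0.1 * lr}])
print([g['lr'] for g in optimizer.param_groups])  # [0.05, 0.005]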
    dur = []
    for epoch in range(args.epochs):
        t_epoch = time.time()
...@@ -106,6 +108,7 @@ def main(args):
            root_accs.append([root_acc, len(root_ids)])
        for param_group in optimizer.param_groups:
            param_group['lr'] = max(1e-5, param_group['lr']*0.99) #10
        dev_acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
        dev_root_acc = 1.0 * np.sum([x[0] for x in root_accs]) / np.sum([x[1] for x in root_accs])
        print("Epoch {:05d} | Dev Acc {:.4f} | Root Acc {:.4f}".format(
...@@ -132,6 +135,7 @@ def main(args):
        # lr decay
        for param_group in optimizer.param_groups:
            param_group['lr'] = max(1e-5, param_group['lr']*0.99) #10
        test_acc = 1.0 * np.sum([x[0] for x in accs]) / np.sum([x[1] for x in accs])
        test_root_acc = 1.0 * np.sum([x[0] for x in root_accs]) / np.sum([x[1] for x in root_accs])
        print("Epoch {:05d} | Test Acc {:.4f} | Root Acc {:.4f}".format(
...@@ -142,6 +146,7 @@ if __name__ == '__main__':
    parser.add_argument('--gpu', type=int, default=-1)
    parser.add_argument('--seed', type=int, default=12110)
    parser.add_argument('--batch-size', type=int, default=25)
    parser.add_argument('--child-sum', action='store_true')
    parser.add_argument('--x-size', type=int, default=300)
    parser.add_argument('--h-size', type=int, default=150)
    parser.add_argument('--epochs', type=int, default=100)
...
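For reference, a quick self-contained sketch of how the new flag drives the model choice in ``main`` (the ``parse_args`` list simulates a command line; the selection logic mirrors the diff above):

import argparse

parser = argparse.ArgumentParser()
# store_true: defaults to False, flips to True when --child-sum is passed
parser.add_argument('--child-sum', action='store_true')

args = parser.parse_args(['--child-sum'])
cell_type = 'childsum' if args.child_sum else 'nary'
print(cell_type)  # -> 'childsum'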
...@@ -35,6 +35,30 @@ class TreeLSTMCell(nn.Module):
        h = o * th.tanh(c)
        return {'h' : h, 'c' : c}
class ChildSumTreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(ChildSumTreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size)
        self.U_iou = nn.Linear(h_size, 3 * h_size)
        self.U_f = nn.Linear(h_size, h_size)

    def message_func(self, edges):
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        h_tild = th.sum(nodes.mailbox['h'], 1)
        f = th.sigmoid(self.U_f(nodes.mailbox['h']))
        c = th.sum(f * nodes.mailbox['c'], 1)
        return {'iou': self.U_iou(h_tild), 'c': c}

    def apply_node_func(self, nodes):
        iou = nodes.data['iou']
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        c = i * u + nodes.data['c']
        h = o * th.tanh(c)
        return {'h': h, 'c': c}
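This cell follows the Child-Sum Tree-LSTM of the Tai et al. paper: children's hidden states are summed rather than concatenated, and each child gets its own forget gate. A sketch of the update it computes (the input terms :math:`W x_j` are expected to be injected separately through ``W_iou`` on the node features, as with the N-ary cell):

.. math::

   \tilde{h}_j = \sum_{k \in C(j)} h_k, \qquad
   f_{jk} = \sigma\left(U^{(f)} h_k\right), \qquad
   c_j = i_j \odot u_j + \sum_{k \in C(j)} f_{jk} \odot c_k, \qquad
   h_j = o_j \odot \textrm{tanh}(c_j),

where :math:`i_j, o_j, u_j` come from chunking ``U_iou(h_tild)`` in ``apply_node_func``.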
class TreeLSTM(nn.Module):
    def __init__(self,
                 num_vocabs,
...@@ -42,6 +66,7 @@ class TreeLSTM(nn.Module):
                 h_size,
                 num_classes,
                 dropout,
                 cell_type='nary',
                 pretrained_emb=None):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
...@@ -52,7 +77,8 @@ class TreeLSTM(nn.Module):
        self.embedding.weight.requires_grad = True
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(h_size, num_classes)
        cell = TreeLSTMCell if cell_type == 'nary' else ChildSumTreeLSTMCell
        self.cell = cell(x_size, h_size)

    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.
...
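A hypothetical instantiation showing the new switch (all values below are illustrative placeholders; the real script takes them from the dataset and command-line arguments, and ``TreeLSTM`` is the class defined above):

model = TreeLSTM(num_vocabs=10000,      # placeholder vocabulary size
                 x_size=300,
                 h_size=150,
                 num_classes=5,
                 dropout=0.5,
                 cell_type='childsum')  # or 'nary' (the default)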
...@@ -10,22 +10,23 @@ Tree LSTM DGL Tutorial
##############################################################################
#
# The Tree-LSTM structure was first introduced by Tai et al. in an ACL 2015
# paper: `Improved Semantic Representations From Tree-Structured Long
# Short-Term Memory Networks <https://arxiv.org/pdf/1503.00075.pdf>`__.
# The core idea is to introduce syntactic information for language tasks by
# extending the chain-structured LSTM to a tree-structured LSTM. Dependency
# trees and constituency trees are leveraged to obtain the latent tree
# structure.
#
# A major difficulty of training Tree-LSTMs is batching, a standard technique
# in machine learning to accelerate optimization. Since trees generally have
# different shapes by nature, parallelization becomes non-trivial. DGL offers
# an alternative: pool all the trees into one single graph, then induce
# message passing over it, guided by the structure of each tree.
#
# The task and the dataset
# ------------------------
#
# In this tutorial, we will use Tree-LSTMs for sentiment analysis.
# We have wrapped the
# `Stanford Sentiment Treebank <https://nlp.stanford.edu/sentiment/>`__ in
# ``dgl.data``. The dataset provides a fine-grained, tree-level sentiment
# annotation: 5 classes (very negative, negative, neutral, positive, and
...@@ -123,37 +124,38 @@ plot_tree(graph.to_networkx())
# Step 2: Tree-LSTM Cell with message-passing APIs
# ------------------------------------------------
#
# The authors proposed two types of Tree-LSTMs: Child-Sum
# Tree-LSTMs and :math:`N`-ary Tree-LSTMs. In this tutorial we focus
# on applying a *binary* Tree-LSTM to binarized constituency trees (this
# application is also known as a *Constituency Tree-LSTM*). We use PyTorch
# as our backend framework to set up the network.
#
# In the :math:`N`-ary Tree-LSTM, each unit at node :math:`j` maintains a
# hidden representation :math:`h_j` and a memory cell :math:`c_j`. The unit
# :math:`j` takes the input vector :math:`x_j` and the hidden
# representations of its child units, :math:`h_{jl}, 1\leq l\leq N`, as
# input, then updates its hidden representation :math:`h_j` and memory
# cell :math:`c_j` by:
#
# .. math::
#
#    i_j & = & \sigma\left(W^{(i)}x_j + \sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\right), & (1)\\
#    f_{jk} & = & \sigma\left(W^{(f)}x_j + \sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \right), & (2)\\
#    o_j & = & \sigma\left(W^{(o)}x_j + \sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \right), & (3)\\
#    u_j & = & \textrm{tanh}\left(W^{(u)}x_j + \sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \right), & (4)\\
#    c_j & = & i_j \odot u_j + \sum_{l=1}^{N} f_{jl} \odot c_{jl}, & (5)\\
#    h_j & = & o_j \odot \textrm{tanh}(c_j). & (6)
#
# The computation can be decomposed into three phases: ``message_func``,
# ``reduce_func`` and ``apply_node_func``.
#
# .. note::
#
#    ``apply_node_func`` is a new node UDF we have not introduced before. In
#    ``apply_node_func``, the user specifies what to do with node features,
#    without considering edge features and messages. In the Tree-LSTM case,
#    ``apply_node_func`` is a must, since there exist (leaf) nodes with
#    :math:`0` incoming edges, which would not be updated via
#    ``reduce_func``.
#
import torch as th
...@@ -170,16 +172,22 @@ class TreeLSTMCell(nn.Module):
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        # concatenate h_jl for equation (1), (2), (3), (4)
        h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
        # equation (2)
        f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
        # second term of equation (5)
        c = th.sum(f * nodes.mailbox['c'], 1)
        return {'iou': self.U_iou(h_cat), 'c': c}

    def apply_node_func(self, nodes):
        # equation (1), (3), (4)
        iou = nodes.data['iou']
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        # equation (5)
        c = i * u + nodes.data['c']
        # equation (6)
        h = o * th.tanh(c)
        return {'h' : h, 'c' : c}
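To make the reshaping in ``reduce_func`` concrete: DGL batches nodes of equal in-degree together, so ``nodes.mailbox['h']`` is a dense ``(batch, num_children, h_size)`` tensor. A standalone shape check with made-up sizes:

import torch as th

B, N, h_size = 4, 2, 5              # hypothetical bucket: 4 nodes, 2 children each
mailbox_h = th.randn(B, N, h_size)  # plays the role of nodes.mailbox['h']

h_cat = mailbox_h.view(mailbox_h.size(0), -1)
print(h_cat.shape)                  # torch.Size([4, 10]): children concatenated

U_f = th.nn.Linear(N * h_size, N * h_size)  # one forget gate per child
f = th.sigmoid(U_f(h_cat)).view(*mailbox_h.size())
print(f.shape)                      # torch.Size([4, 2, 5]): back to per-child shape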
...@@ -213,13 +221,6 @@ print(dgl.topological_nodes_generator(graph))

##############################################################################
# We then call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:

import dgl.function as fn
import torch as th

...@@ -235,6 +236,13 @@ graph.prop_nodes(traversal_order)
# dgl.prop_nodes_topo(graph)
##############################################################################
# .. note::
#
#    Before we call :meth:`~dgl.DGLGraph.prop_nodes`, we must specify a
#    `message_func` and `reduce_func` in advance. Here we use the built-in
#    copy-from-source and sum functions as our message function and reduce
#    function for demonstration.
#
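# As a tiny illustration of this note, here is a standalone sketch (assuming
# the DGL 0.x-era API used in this tutorial, where message and reduce
# functions can be registered on the graph beforehand):

import networkx as nx
import torch as th
import dgl
import dgl.function as fn

# a toy tree: edges point from the children (1, 2) to the root (0)
g = dgl.DGLGraph(nx.DiGraph([(1, 0), (2, 0)]))
g.ndata['h'] = th.ones(3, 1)

g.register_message_func(fn.copy_src('h', 'm'))  # built-in copy-from-source
g.register_reduce_func(fn.sum('m', 'h'))        # built-in sum

g.prop_nodes(dgl.topological_nodes_generator(g))
print(g.ndata['h'])  # the root's h is now the sum of its children's h

##############################################################################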
# Putting it together
# -------------------
#
...@@ -353,3 +361,4 @@ for epoch in range(epochs):
# To train the model on the full dataset with different settings (CPU/GPU,
# etc.), please refer to our repo's
# `example <https://github.com/jermainewang/dgl/tree/master/examples/pytorch/tree_lstm>`__.
# We also provide an implementation of the Child-Sum Tree-LSTM there.