3_tree-lstm.py 14.3 KB
Newer Older
1
2
3
"""
.. _model-tree-lstm:

4
5
Tutorial: Tree-LSTM in DGL
==========================
6
7
8
9
10
11
12
13

**Author**: Zihao Ye, Qipeng Guo, `Minjie Wang
<https://jermainewang.github.io/>`_, `Jake Zhao
<https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang
"""
 
##############################################################################
#
14
15
16
17
# In this tutorial, you learn to use Tree-LSTM networks for sentiment analysis. 
# The Tree-LSTM is a generalization of long short-term memory (LSTM) networks to tree-structured network topologies.
# 
# The Tree-LSTM structure was first introduced by Kai et. al in an ACL 2015 
18
19
20
# paper: `Improved Semantic Representations From Tree-Structured Long
# Short-Term Memory Networks <https://arxiv.org/pdf/1503.00075.pdf>`__.
# The core idea is to introduce syntactic information for language tasks by 
21
22
# extending the chain-structured LSTM to a tree-structured LSTM. The dependency 
# tree and constituency tree techniques are leveraged to obtain a ''latent tree''.
23
#
24
# The challenge in training Tree-LSTMs is batching --- a standard 
25
# technique in machine learning to accelerate optimization. However, since trees 
26
27
28
# generally have different shapes by nature, parallization is non-trivial. 
# DGL offers an alternative. Pool all the trees into one single graph then 
# induce the message passing over them, guided by the structure of each tree.
29
30
31
#
# The task and the dataset
# ------------------------
32
33
# 
# The steps here use the
34
# `Stanford Sentiment Treebank <https://nlp.stanford.edu/sentiment/>`__ in
35
36
37
38
39
# ``dgl.data``. The dataset provides a fine-grained, tree-level sentiment
# annotation. There are five classes: Very negative, negative, neutral, positive, and
# very positive, which indicate the sentiment in the current subtree. Non-leaf
# nodes in a constituency tree do not contain words, so use a special
# ``PAD_WORD`` token to denote them. During training and inference
40
41
42
43
44
45
46
# their embeddings would be masked to all-zero.
#
# .. figure:: https://i.loli.net/2018/11/08/5be3d4bfe031b.png
#    :alt: 
#
# The figure displays one sample of the SST dataset, which is a
# constituency parse tree with their nodes labeled with sentiment. To
47
48
# speed up things, build a tiny set with five sentences and take a look
# at the first one.
49
50
51
#

import dgl
52
from dgl.data.tree import SST
Da Zheng's avatar
Da Zheng committed
53
from dgl.data import SSTBatch
54
55

# Each sample in the dataset is a constituency tree. The leaf nodes
56
57
# represent words. The word is an int value stored in the "x" field.
# The non-leaf nodes have a special word PAD_WORD. The sentiment
58
# label is stored in the "y" feature field.
59
trainset = SST(mode='tiny')  # the "tiny" set has only five trees
60
61
62
63
64
65
66
67
68
69
70
71
72
tiny_sst = trainset.trees
num_vocabs = trainset.num_vocabs
num_classes = trainset.num_classes

vocab = trainset.vocab # vocabulary dict: key -> id
inv_vocab = {v: k for k, v in vocab.items()} # inverted vocabulary dict: id -> word

a_tree = tiny_sst[0]
for token in a_tree.ndata['x'].tolist():
    if token != trainset.PAD_WORD:
        print(inv_vocab[token], end=" ")

##############################################################################
73
# Step 1: Batching
74
75
# ----------------
#
76
# Add all the trees to one graph, using
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# the :func:`~dgl.batched_graph.batch` API.
#

import networkx as nx
import matplotlib.pyplot as plt

graph = dgl.batch(tiny_sst)
def plot_tree(g):
    # this plot requires pygraphviz package
    pos = nx.nx_agraph.graphviz_layout(g, prog='dot')
    nx.draw(g, pos, with_labels=False, node_size=10,
            node_color=[[.5, .5, .5]], arrowsize=4)
    plt.show()

plot_tree(graph.to_networkx())

93
#################################################################################
94
# You can read more about the definition of :func:`~dgl.batch`, or
95
# skip ahead to the next step:
96
97
# .. note::
#
98
99
100
#    **Definition**: :func:`~dgl.batch` unions a list of :math:`B`
#      :class:`~dgl.DGLGraph`\ s and returns a :class:`~dgl.DGLGraph` of batch 
#      size :math:`B`. 
101
102
#    
#    - The union includes all the nodes,
103
#      edges, and their features. The order of nodes, edges, and features are
104
105
#      preserved. 
#     
106
#        - Given that you have :math:`V_i` nodes for graph
107
108
109
110
111
#          :math:`\mathcal{G}_i`, the node ID :math:`j` in graph
#          :math:`\mathcal{G}_i` correspond to node ID
#          :math:`j + \sum_{k=1}^{i-1} V_k` in the batched graph. 
#    
#        - Therefore, performing feature transformation and message passing on
112
#          the batched graph is equivalent to doing those
113
114
115
116
117
#          on all ``DGLGraph`` constituents in parallel. 
#
#    - Duplicate references to the same graph are
#      treated as deep copies; the nodes, edges, and features are duplicated,
#      and mutation on one reference does not affect the other. 
118
#    - The batched graph keeps track of the meta
119
120
121
#      information of the constituents so it can be
#      :func:`~dgl.batched_graph.unbatch`\ ed to list of ``DGLGraph``\ s.
#
122
# Step 2: Tree-LSTM cell with message-passing APIs
123
124
# ------------------------------------------------
#
125
126
127
128
129
# Researchers have proposed two types of Tree-LSTMs: Child-Sum
# Tree-LSTMs, and :math:`N`-ary Tree-LSTMs. In this tutorial you focus 
# on applying *Binary* Tree-LSTM to binarized constituency trees. This 
# application is also known as *Constituency Tree-LSTM*. Use PyTorch 
# as a backend framework to set up the network.
130
#
131
# In `N`-ary Tree-LSTM, each unit at node :math:`j` maintains a hidden
132
133
# representation :math:`h_j` and a memory cell :math:`c_j`. The unit
# :math:`j` takes the input vector :math:`x_j` and the hidden
134
# representations of the child units: :math:`h_{jl}, 1\leq l\leq N` as
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# input, then update its new hidden representation :math:`h_j` and memory
# cell :math:`c_j` by: 
#
# .. math::
#
#    i_j & = & \sigma\left(W^{(i)}x_j + \sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\right),  & (1)\\
#    f_{jk} & = & \sigma\left(W^{(f)}x_j + \sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \right), &  (2)\\
#    o_j & = & \sigma\left(W^{(o)}x_j + \sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \right), & (3)  \\
#    u_j & = & \textrm{tanh}\left(W^{(u)}x_j + \sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \right), & (4)\\
#    c_j & = & i_j \odot u_j + \sum_{l=1}^{N} f_{jl} \odot c_{jl}, &(5) \\
#    h_j & = & o_j \cdot \textrm{tanh}(c_j), &(6)  \\
#
# It can be decomposed into three phases: ``message_func``,
# ``reduce_func`` and ``apply_node_func``.
#
# .. note::
151
152
153
#    ``apply_node_func`` is a new node UDF that has not been introduced before. In
#    ``apply_node_func``, a user specifies what to do with node features,
#    without considering edge features and messages. In a Tree-LSTM case,
154
#    ``apply_node_func`` is a must, since there exists (leaf) nodes with
155
#    :math:`0` incoming edges, which would not be updated with 
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#    ``reduce_func``.
#

import torch as th
import torch.nn as nn

class TreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
        self.U_f = nn.Linear(2 * h_size, 2 * h_size)

    def message_func(self, edges):
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        # concatenate h_jl for equation (1), (2), (3), (4)
        h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
        # equation (2)
        f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
        # second term of equation (5)
        c = th.sum(f * nodes.mailbox['c'], 1)
        return {'iou': self.U_iou(h_cat), 'c': c}

    def apply_node_func(self, nodes):
        # equation (1), (3), (4)
        iou = nodes.data['iou'] + self.b_iou
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        # equation (5)
        c = i * u + nodes.data['c']
        # equation (6)
        h = o * th.tanh(c)
        return {'h' : h, 'c' : c}

##############################################################################
194
# Step 3: Define traversal
195
196
# ------------------------
#
197
# After you define the message-passing functions, induce the
198
199
200
201
202
# right order to trigger them. This is a significant departure from models
# such as GCN, where all nodes are pulling messages from upstream ones
# *simultaneously*.
#
# In the case of Tree-LSTM, messages start from leaves of the tree, and
brett koonce's avatar
brett koonce committed
203
# propagate/processed upwards until they reach the roots. A visualization
204
205
206
207
208
209
210
211
212
213
214
# is as follows:
#
# .. figure:: https://i.loli.net/2018/11/09/5be4b5d2df54d.gif
#    :alt:
#
# DGL defines a generator to perform the topological sort, each item is a
# tensor recording the nodes from bottom level to the roots. One can
# appreciate the degree of parallelism by inspecting the difference of the
# followings:
#

215
216
# to heterogenous graph
trv_a_tree = dgl.graph(a_tree.edges())
217
print('Traversing one tree:')
218
print(dgl.topological_nodes_generator(trv_a_tree))
219

220
221
# to heterogenous graph
trv_graph = dgl.graph(graph.edges())
222
print('Traversing many trees at the same time:')
223
print(dgl.topological_nodes_generator(trv_graph))
224
225

##############################################################################
226
# Call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:
227
228
229
230

import dgl.function as fn
import torch as th

231
232
233
234
235
trv_graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)
traversal_order = dgl.topological_nodes_generator(trv_graph)
trv_graph.prop_nodes(traversal_order,
                     message_func=fn.copy_src('a', 'a'),
                     reduce_func=fn.sum('a', 'a'))
236
237
238
239
240
241
242

# the following is a syntax sugar that does the same
# dgl.prop_nodes_topo(graph)

##############################################################################
# .. note::
#
243
244
245
#    Before you call :meth:`~dgl.DGLGraph.prop_nodes`, specify a
#    `message_func` and `reduce_func` in advance. In the example, you can see built-in
#    copy-from-source and sum functions as message functions, and a reduce
246
247
248
249
250
#    function for demonstration.
#
# Putting it together
# -------------------
#
251
# Here is the complete code that specifies the ``Tree-LSTM`` class.
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
#

class TreeLSTM(nn.Module):
    def __init__(self,
                 num_vocabs,
                 x_size,
                 h_size,
                 num_classes,
                 dropout,
                 pretrained_emb=None):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.embedding = nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print('Using glove')
            self.embedding.weight.data.copy_(pretrained_emb)
            self.embedding.weight.requires_grad = True
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(h_size, num_classes)
        self.cell = TreeLSTMCell(x_size, h_size)

    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        g = batch.graph
291
292
        # to heterogenous graph
        g = dgl.graph(g.edges())
293
294
295
296
297
298
        # feed embedding
        embeds = self.embedding(batch.wordid * batch.mask)
        g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # propagate
299
300
301
302
        dgl.prop_nodes_topo(g,
                            message_func=self.cell.message_func,
                            reduce_func=self.cell.reduce_func,
                            apply_node_func=self.cell.apply_node_func)
303
304
305
306
307
308
309
310
311
        # compute logits
        h = self.dropout(g.ndata.pop('h'))
        logits = self.linear(h)
        return logits

##############################################################################
# Main Loop
# ---------
#
312
# Finally, you could write a training paradigm in PyTorch.
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
#

from torch.utils.data import DataLoader
import torch.nn.functional as F

device = th.device('cpu')
# hyper parameters
x_size = 256
h_size = 256
dropout = 0.5
lr = 0.05
weight_decay = 1e-4
epochs = 10

# create the model
model = TreeLSTM(trainset.num_vocabs,
                 x_size,
                 h_size,
                 trainset.num_classes,
                 dropout)
print(model)

# create the optimizer
optimizer = th.optim.Adagrad(model.parameters(),
                          lr=lr,
                          weight_decay=weight_decay)
Da Zheng's avatar
Da Zheng committed
339
340
341
342
343
344
345
346
347
348

def batcher(dev):
    def batcher_dev(batch):
        batch_trees = dgl.batch(batch)
        return SSTBatch(graph=batch_trees,
                        mask=batch_trees.ndata['mask'].to(device),
                        wordid=batch_trees.ndata['x'].to(device),
                        label=batch_trees.ndata['y'].to(device))
    return batcher_dev

349
350
train_loader = DataLoader(dataset=tiny_sst,
                          batch_size=5,
Da Zheng's avatar
Da Zheng committed
351
                          collate_fn=batcher(device),
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
                          shuffle=False,
                          num_workers=0)

# training loop
for epoch in range(epochs):
    for step, batch in enumerate(train_loader):
        g = batch.graph
        n = g.number_of_nodes()
        h = th.zeros((n, h_size))
        c = th.zeros((n, h_size))
        logits = model(batch, h, c)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp, batch.label, reduction='sum') 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pred = th.argmax(logits, 1)
        acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)
        print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format(
            epoch, step, loss.item(), acc))

##############################################################################
374
375
376
# To train the model on a full dataset with different settings (such as CPU or GPU),
# refer to the `PyTorch example <https://github.com/dmlc/dgl/tree/master/examples/pytorch/tree_lstm>`__.
# There is also an implementation of the Child-Sum Tree-LSTM.