""" .. _model-tree-lstm: Tree LSTM DGL Tutorial ========================= **Author**: Zihao Ye, Qipeng Guo, `Minjie Wang `_, `Jake Zhao `_, Zheng Zhang """ ############################################################################## # # Tree-LSTM structure was first introduced by Kai et. al in an ACL 2015 # paper: `Improved Semantic Representations From Tree-Structured Long # Short-Term Memory Networks `__. # The core idea is to introduce syntactic information for language tasks by # extending the chain-structured LSTM to a tree-structured LSTM. The Dependency # Tree/Constituency Tree techniques were leveraged to obtain a ''latent tree''. # # One, if not all, difficulty of training Tree-LSTMs is batching --- a standard # technique in machine learning to accelerate optimization. However, since trees # generally have different shapes by nature, parallization becomes non trivial. # DGL offers an alternative: to pool all the trees into one single graph then # induce the message passing over them guided by the structure of each tree. # # The task and the dataset # ------------------------ # In this tutorial, we will use Tree-LSTMs for sentiment analysis. # We have wrapped the # `Stanford Sentiment Treebank `__ in # ``dgl.data``. The dataset provides a fine-grained tree level sentiment # annotation: 5 classes(very negative, negative, neutral, positive, and # very positive) that indicates the sentiment in current subtree. Non-leaf # nodes in constituency tree does not contain words, we use a special # ``PAD_WORD`` token to denote them, during the training/inferencing, # their embeddings would be masked to all-zero. # # .. figure:: https://i.loli.net/2018/11/08/5be3d4bfe031b.png # :alt: # # The figure displays one sample of the SST dataset, which is a # constituency parse tree with their nodes labeled with sentiment. To # speed up things, let's build a tiny set with 5 sentences and take a look # at the first one: # import dgl from dgl.data.tree import SST # Each sample in the dataset is a constituency tree. The leaf nodes # represent words. The word is a int value stored in the "x" field. # The non-leaf nodes has a special word PAD_WORD. The sentiment # label is stored in the "y" feature field. trainset = SST(mode='tiny') # the "tiny" set has only 5 trees tiny_sst = trainset.trees num_vocabs = trainset.num_vocabs num_classes = trainset.num_classes vocab = trainset.vocab # vocabulary dict: key -> id inv_vocab = {v: k for k, v in vocab.items()} # inverted vocabulary dict: id -> word a_tree = tiny_sst[0] for token in a_tree.ndata['x'].tolist(): if token != trainset.PAD_WORD: print(inv_vocab[token], end=" ") ############################################################################## # Step 1: batching # ---------------- # # The first step is to throw all the trees into one graph, using # the :func:`~dgl.batched_graph.batch` API. # import networkx as nx import matplotlib.pyplot as plt graph = dgl.batch(tiny_sst) def plot_tree(g): # this plot requires pygraphviz package pos = nx.nx_agraph.graphviz_layout(g, prog='dot') nx.draw(g, pos, with_labels=False, node_size=10, node_color=[[.5, .5, .5]], arrowsize=4) plt.show() plot_tree(graph.to_networkx()) ############################################################################## # You can read more about the definition of :func:`~dgl.batched_graph.batch` # (by clicking the API), or can skip ahead to the next step: # # .. 
##############################################################################
# You can read more about the definition of :func:`~dgl.batched_graph.batch`
# (by clicking the API), or skip ahead to the next step:
#
# .. note::
#
#    **Definition**: a :class:`~dgl.batched_graph.BatchedDGLGraph` is a
#    :class:`~dgl.DGLGraph` that unions a list of :class:`~dgl.DGLGraph`\ s.
#
#    - The union includes all the nodes, edges, and their features. The order
#      of nodes, edges, and features is preserved.
#
#    - Given that graph :math:`\mathcal{G}_i` has :math:`V_i` nodes, the node
#      ID :math:`j` in graph :math:`\mathcal{G}_i` corresponds to node ID
#      :math:`j + \sum_{k=1}^{i-1} V_k` in the batched graph.
#
#    - Therefore, performing feature transformation and message passing on a
#      ``BatchedDGLGraph`` is equivalent to doing those on all ``DGLGraph``
#      constituents in parallel.
#
#    - Duplicate references to the same graph are treated as deep copies; the
#      nodes, edges, and features are duplicated, and mutation on one
#      reference does not affect the others.
#
#    - Currently, a ``BatchedDGLGraph`` is immutable in graph structure
#      (i.e. one cannot add nodes and edges to it). Support for mutable
#      batched graphs is planned for the (far) future.
#
#    - The ``BatchedDGLGraph`` keeps track of the meta information of its
#      constituents, so it can be :func:`~dgl.batched_graph.unbatch`\ ed back
#      to a list of ``DGLGraph``\ s.
#
# For more details about the :class:`~dgl.batched_graph.BatchedDGLGraph`
# module in DGL, you can click the class name.
#
# Step 2: Tree-LSTM Cell with message-passing APIs
# --------------------------------------------------
#
# The authors proposed two types of Tree-LSTM: the Child-Sum
# Tree-LSTM, and the :math:`N`-ary Tree-LSTM. In this tutorial we focus
# on applying the *binary* Tree-LSTM to binarized constituency trees (this
# application is also known as the *Constituency Tree-LSTM*). We use PyTorch
# as our backend framework to set up the network.
#
# In the :math:`N`-ary Tree-LSTM, each unit at node :math:`j` maintains a
# hidden representation :math:`h_j` and a memory cell :math:`c_j`. The unit
# :math:`j` takes the input vector :math:`x_j` and the hidden
# representations of its child units, :math:`h_{jl}, 1\leq l\leq N`, as
# input, then updates its hidden representation :math:`h_j` and memory
# cell :math:`c_j` by:
#
# .. math::
#
#    i_j & = & \sigma\left(W^{(i)}x_j + \sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\right),  & (1)\\
#    f_{jk} & = & \sigma\left(W^{(f)}x_j + \sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \right),  & (2)\\
#    o_j & = & \sigma\left(W^{(o)}x_j + \sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \right),  & (3) \\
#    u_j & = & \textrm{tanh}\left(W^{(u)}x_j + \sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \right),  & (4)\\
#    c_j & = & i_j \odot u_j + \sum_{l=1}^{N} f_{jl} \odot c_{jl},  &(5) \\
#    h_j & = & o_j \cdot \textrm{tanh}(c_j),  &(6) \\
#
# The computation can be decomposed into three phases: ``message_func``,
# ``reduce_func`` and ``apply_node_func``.
#
# .. note::
#    ``apply_node_func`` is a new node UDF we have not introduced before. In
#    ``apply_node_func``, the user specifies what to do with node features,
#    without considering edge features and messages. In the Tree-LSTM case,
#    ``apply_node_func`` is a must, since there exist (leaf) nodes with
#    :math:`0` incoming edges, which would not be updated via
#    ``reduce_func``.
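##############################################################################
# Before mapping these equations onto DGL's message-passing APIs, the sketch
# below (an addition to the original tutorial) spells out a single binary
# (:math:`N=2`) update in plain PyTorch for one node and its two children.
# The toy dimensions and randomly initialized parameters are assumptions made
# purely for illustration; in the real model the parameters are learned.

import torch as th

x_dim, h_dim, n_ary = 4, 5, 2        # toy sizes, illustration only
x_j = th.randn(x_dim)                # input vector of the parent node j
h_children = th.randn(n_ary, h_dim)  # h_{j1}, h_{j2} from the two children
c_children = th.randn(n_ary, h_dim)  # c_{j1}, c_{j2} from the two children

# one W and b per gate; one U per gate and child slot; U^{(f)}_{kl} is
# indexed by both the forget gate k and the child l
W = {gate: th.randn(h_dim, x_dim) for gate in 'ifou'}
b = {gate: th.zeros(h_dim) for gate in 'ifou'}
U = {gate: th.randn(n_ary, h_dim, h_dim) for gate in 'iou'}
U_f = th.randn(n_ary, n_ary, h_dim, h_dim)

def gate_value(gate, act):
    # W^{(g)} x_j + sum_l U^{(g)}_l h_{jl} + b^{(g)}, followed by activation
    s = W[gate] @ x_j + b[gate]
    for l in range(n_ary):
        s = s + U[gate][l] @ h_children[l]
    return act(s)

i_j = gate_value('i', th.sigmoid)    # equation (1)
o_j = gate_value('o', th.sigmoid)    # equation (3)
u_j = gate_value('u', th.tanh)       # equation (4)
f_j = [th.sigmoid(W['f'] @ x_j + b['f'] +
                  sum(U_f[k, l] @ h_children[l] for l in range(n_ary)))
       for k in range(n_ary)]        # equation (2), one forget gate per child
c_j = i_j * u_j + sum(f_j[k] * c_children[k] for k in range(n_ary))  # eq. (5)
h_j = o_j * th.tanh(c_j)             # equation (6)
print(h_j.shape, c_j.shape)          # both are (h_dim,)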
import torch as th
import torch.nn as nn

class TreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
        self.U_f = nn.Linear(2 * h_size, 2 * h_size)

    def message_func(self, edges):
        return {'h': edges.src['h'], 'c': edges.src['c']}

    def reduce_func(self, nodes):
        # concatenate h_jl for equations (1), (2), (3), (4)
        h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
        # equation (2)
        f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
        # second term of equation (5)
        c = th.sum(f * nodes.mailbox['c'], 1)
        return {'iou': self.U_iou(h_cat), 'c': c}

    def apply_node_func(self, nodes):
        # equations (1), (3), (4)
        iou = nodes.data['iou'] + self.b_iou
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        # equation (5)
        c = i * u + nodes.data['c']
        # equation (6)
        h = o * th.tanh(c)
        return {'h': h, 'c': c}

##############################################################################
# Step 3: define traversal
# ------------------------
#
# After defining the message-passing functions, we then need to induce the
# right order to trigger them. This is a significant departure from models
# such as GCN, where all nodes pull messages from their upstream neighbors
# *simultaneously*.
#
# In the case of Tree-LSTM, messages start from the leaves of the tree and
# are propagated and processed upwards until they reach the roots. A
# visualization is as follows:
#
# .. figure:: https://i.loli.net/2018/11/09/5be4b5d2df54d.gif
#    :alt:
#
# DGL defines a generator to perform the topological sort; each item is a
# tensor recording the nodes from the bottom level up to the roots. One can
# appreciate the degree of parallelism by comparing the two outputs below:
#

print('Traversing one tree:')
print(dgl.topological_nodes_generator(a_tree))

print('Traversing many trees at the same time:')
print(dgl.topological_nodes_generator(graph))

##############################################################################
# We then call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:

import dgl.function as fn
import torch as th

graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)
graph.register_message_func(fn.copy_src('a', 'a'))
graph.register_reduce_func(fn.sum('a', 'a'))

traversal_order = dgl.topological_nodes_generator(graph)
graph.prop_nodes(traversal_order)

# the following is syntactic sugar that does the same thing
# dgl.prop_nodes_topo(graph)

##############################################################################
# .. note::
#
#    Before we call :meth:`~dgl.DGLGraph.prop_nodes`, we must specify a
#    `message_func` and a `reduce_func` in advance. Here we use the built-in
#    copy-from-source and sum functions as our message function and reduce
#    function for demonstration.
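##############################################################################
# As a quick check of what that demonstration propagation computed (this
# snippet is an addition to the tutorial, and assumes, as the message flow
# above does, that edges point from children to parents): each internal
# node's ``'a'`` is overwritten with the sum of its children's values, so
# internal nodes end up counting the leaf (word) nodes below them, while
# leaves keep their initial 1.

root_ids = [i for i in range(graph.number_of_nodes())
            if graph.out_degree(i) == 0]  # roots have no parent edge
print('Leaf (word) count per tree:', graph.ndata['a'][root_ids].squeeze())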
##############################################################################
# Putting it together
# -------------------
#
# Here is the complete code that specifies the ``Tree-LSTM`` class:
#

class TreeLSTM(nn.Module):
    def __init__(self,
                 num_vocabs,
                 x_size,
                 h_size,
                 num_classes,
                 dropout,
                 pretrained_emb=None):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.embedding = nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print('Using glove')
            self.embedding.weight.data.copy_(pretrained_emb)
            self.embedding.weight.requires_grad = True
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(h_size, num_classes)
        self.cell = TreeLSTMCell(x_size, h_size)

    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        g = batch.graph
        g.register_message_func(self.cell.message_func)
        g.register_reduce_func(self.cell.reduce_func)
        g.register_apply_node_func(self.cell.apply_node_func)
        # feed embedding
        embeds = self.embedding(batch.wordid * batch.mask)
        g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
        g.ndata['h'] = h
        g.ndata['c'] = c
        # propagate
        dgl.prop_nodes_topo(g)
        # compute logits
        h = self.dropout(g.ndata.pop('h'))
        logits = self.linear(h)
        return logits

##############################################################################
# Main Loop
# ---------
#
# Finally, we can write the training loop in PyTorch:
#

from torch.utils.data import DataLoader
import torch.nn.functional as F

device = th.device('cpu')
# hyper parameters
x_size = 256
h_size = 256
dropout = 0.5
lr = 0.05
weight_decay = 1e-4
epochs = 10

# create the model
model = TreeLSTM(trainset.num_vocabs,
                 x_size,
                 h_size,
                 trainset.num_classes,
                 dropout)
print(model)

# create the optimizer
optimizer = th.optim.Adagrad(model.parameters(),
                             lr=lr,
                             weight_decay=weight_decay)

train_loader = DataLoader(dataset=tiny_sst,
                          batch_size=5,
                          collate_fn=SST.batcher(device),
                          shuffle=False,
                          num_workers=0)

# training loop
for epoch in range(epochs):
    for step, batch in enumerate(train_loader):
        g = batch.graph
        n = g.number_of_nodes()
        h = th.zeros((n, h_size))
        c = th.zeros((n, h_size))
        logits = model(batch, h, c)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp, batch.label, reduction='sum')
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pred = th.argmax(logits, 1)
        acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)
        print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format(
            epoch, step, loss.item(), acc))

##############################################################################
# To train the model on the full dataset with different settings (CPU/GPU,
# etc.), please refer to our repo's `example`.
# We also provide an implementation of the Child-Sum Tree-LSTM.
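##############################################################################
# As a small follow-up on the training loop: the accuracy above is computed
# over every node in the trees, whereas SST results are usually reported on
# root nodes only. The sketch below is an addition to the tutorial; it reuses
# ``g``, ``batch``, and ``pred`` from the last iteration of the loop and the
# same child-to-parent edge convention as before.

root_ids = [i for i in range(g.number_of_nodes())
            if g.out_degree(i) == 0]  # roots have no parent edge
root_acc = float(th.sum(th.eq(batch.label[root_ids],
                              pred[root_ids]))) / len(root_ids)
print("Root accuracy on the tiny set: {:.4f}".format(root_acc))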