Commit 41114b87 authored by John Andrilla, committed by Zihao Ye

[Doc] Tree-LSTM in DGL, edit pass for readability (#1031)

* [Doc] Tree-LSTM in DGL, edit pass for readability

Edit for grammar and style. 
How about deleting the last line? "Besides, you..." It needs a transition or some context.

* Update tutorials/models/2_small_graph/3_tree-lstm.py

* Update tutorials/models/2_small_graph/3_tree-lstm.py
parent 632a9af8
""" """
.. _model-tree-lstm: .. _model-tree-lstm:
Tree LSTM DGL Tutorial Tutorial: Tree-LSTM in DGL
========================= ==========================
**Author**: Zihao Ye, Qipeng Guo, `Minjie Wang **Author**: Zihao Ye, Qipeng Guo, `Minjie Wang
<https://jermainewang.github.io/>`_, `Jake Zhao <https://jermainewang.github.io/>`_, `Jake Zhao
@@ -11,29 +11,32 @@ Tree LSTM DGL Tutorial
##############################################################################
#
# In this tutorial, you learn to use Tree-LSTM networks for sentiment analysis.
# The Tree-LSTM is a generalization of long short-term memory (LSTM) networks
# to tree-structured network topologies.
#
# The Tree-LSTM structure was first introduced by Tai et al. in an ACL 2015
# paper: `Improved Semantic Representations From Tree-Structured Long
# Short-Term Memory Networks <https://arxiv.org/pdf/1503.00075.pdf>`__.
# The core idea is to introduce syntactic information for language tasks by
# extending the chain-structured LSTM to a tree-structured LSTM. The dependency
# tree and constituency tree techniques are leveraged to obtain a "latent tree".
#
# The challenge in training Tree-LSTMs is batching, a standard
# technique in machine learning to accelerate optimization. However, since trees
# generally have different shapes by nature, parallelization is non-trivial.
# DGL offers an alternative. Pool all the trees into one single graph, then
# induce message passing over them, guided by the structure of each tree.
#
# The task and the dataset
# ------------------------
#
# The steps here use the
# `Stanford Sentiment Treebank <https://nlp.stanford.edu/sentiment/>`__ in
# ``dgl.data``. The dataset provides a fine-grained, tree-level sentiment
# annotation. There are five classes: very negative, negative, neutral,
# positive, and very positive, which indicate the sentiment in the current
# subtree. Non-leaf nodes in a constituency tree do not contain words, so a
# special ``PAD_WORD`` token denotes them. During training and inference,
# their embeddings are masked to all-zero.
#
# .. figure:: https://i.loli.net/2018/11/08/5be3d4bfe031b.png
@@ -41,8 +44,8 @@ Tree LSTM DGL Tutorial
#
# The figure displays one sample of the SST dataset, which is a
# constituency parse tree with its nodes labeled with sentiment. To
# speed things up, build a tiny set with five sentences and take a look
# at the first one.
#
import dgl
@@ -50,10 +53,10 @@ from dgl.data.tree import SST
from dgl.data import SSTBatch

# Each sample in the dataset is a constituency tree. The leaf nodes
# represent words. The word is an int value stored in the "x" field.
# The non-leaf nodes have a special word PAD_WORD. The sentiment
# label is stored in the "y" feature field.
trainset = SST(mode='tiny')  # the "tiny" set has only five trees
tiny_sst = trainset.trees
num_vocabs = trainset.num_vocabs
num_classes = trainset.num_classes
@@ -67,10 +70,10 @@ for token in a_tree.ndata['x'].tolist():
    print(inv_vocab[token], end=" ")
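##############################################################################
# As a quick sanity check, count how many nodes carry real words and how many
# are ``PAD_WORD`` placeholders. This is a minimal sketch that assumes leaves
# are exactly the nodes with zero in-degree, since edges in these trees point
# from children to parents.

num_leaves = int((a_tree.in_degrees() == 0).sum().item())
print('leaf (word) nodes:', num_leaves)
print('internal (PAD_WORD) nodes:', a_tree.number_of_nodes() - num_leaves)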
##############################################################################
# Step 1: Batching
# ----------------
#
# Add all the trees to one graph, using
# the :func:`~dgl.batched_graph.batch` API.
#
@@ -87,20 +90,19 @@ def plot_tree(g):
plot_tree(graph.to_networkx())
##############################################################################
# You can read more about the definition of :func:`~dgl.batched_graph.batch`, or
# skip ahead to the next step:
#
# .. note::
#
#    **Definition**: A :class:`~dgl.batched_graph.BatchedDGLGraph` is a
#    :class:`~dgl.DGLGraph` that unions a list of :class:`~dgl.DGLGraph`\ s.
#
#    - The union includes all the nodes,
#      edges, and their features. The order of nodes, edges, and features is
#      preserved.
#
#    - Given that you have :math:`V_i` nodes for graph
#      :math:`\mathcal{G}_i`, the node ID :math:`j` in graph
#      :math:`\mathcal{G}_i` corresponds to node ID
#      :math:`j + \sum_{k=1}^{i-1} V_k` in the batched graph, as the short
#      sketch below illustrates.
@@ -113,8 +115,8 @@ plot_tree(graph.to_networkx())
#      treated as deep copies; the nodes, edges, and features are duplicated,
#      and mutation on one reference does not affect the other.
#    - Currently, ``BatchedDGLGraph`` is immutable in
#      graph structure. You can't add
#      nodes and edges to it. Support for mutable batched graphs is planned
#      for the (far) future.
#    - The ``BatchedDGLGraph`` keeps track of the meta
#      information of the constituents so it can be
@@ -123,19 +125,19 @@ plot_tree(graph.to_networkx())
# For more details about the :class:`~dgl.batched_graph.BatchedDGLGraph`
# module in DGL, you can click the class name.
#
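# For example, here is a minimal sketch of the node-ID offset rule, built with
# the classic ``DGLGraph``/``add_nodes`` API of this DGL version. The toy graph
# sizes are arbitrary assumptions.

g1 = dgl.DGLGraph()
g1.add_nodes(3)
g2 = dgl.DGLGraph()
g2.add_nodes(2)
bg = dgl.batch([g1, g2])
print(bg.number_of_nodes())  # 5 nodes in total; node j of g2 maps to j + 3

##############################################################################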
# Step 2: Tree-LSTM cell with message-passing APIs
# ------------------------------------------------
#
# Researchers have proposed two types of Tree-LSTMs: Child-Sum
# Tree-LSTMs and :math:`N`-ary Tree-LSTMs. In this tutorial, you focus
# on applying a *Binary* Tree-LSTM to binarized constituency trees. This
# application is also known as *Constituency Tree-LSTM*. Use PyTorch
# as the backend framework to set up the network.
#
# In an :math:`N`-ary Tree-LSTM, each unit at node :math:`j` maintains a hidden
# representation :math:`h_j` and a memory cell :math:`c_j`. The unit
# :math:`j` takes the input vector :math:`x_j` and the hidden
# representations of its child units, :math:`h_{jl}, 1\leq l\leq N`, as
# input, then updates its hidden representation :math:`h_j` and memory
# cell :math:`c_j` by:
#
@@ -152,11 +154,11 @@ plot_tree(graph.to_networkx())
# ``reduce_func`` and ``apply_node_func``.
#
# .. note::
#
#    ``apply_node_func`` is a new node UDF that has not been introduced before. In
#    ``apply_node_func``, a user specifies what to do with node features,
#    without considering edge features and messages. In the Tree-LSTM case,
#    ``apply_node_func`` is a must, since there are (leaf) nodes with
#    :math:`0` incoming edges, which would not be updated with
#    ``reduce_func``.
#
@@ -195,10 +197,10 @@ class TreeLSTMCell(nn.Module):
        return {'h' : h, 'c' : c}
##############################################################################
# Step 3: Define traversal
# ------------------------
#
# After you define the message-passing functions, induce the
# right order to trigger them. This is a significant departure from models
# such as GCN, where all nodes pull messages from upstream ones
# *simultaneously*.
@@ -223,7 +225,7 @@ print('Traversing many trees at the same time:')
print(dgl.topological_nodes_generator(graph))
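##############################################################################
# As a minimal illustration of the traversal order, consider a hypothetical
# four-node chain built with the classic ``DGLGraph`` API. The generator yields
# one frontier per step, starting from the nodes with zero in-degree.

chain = dgl.DGLGraph()
chain.add_nodes(4)
chain.add_edges([0, 1, 2], [1, 2, 3])  # a chain 0 -> 1 -> 2 -> 3
print(dgl.topological_nodes_generator(chain))  # expect frontiers [0], [1], [2], [3]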
##############################################################################
# Call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:
import dgl.function as fn
import torch as th
@@ -241,15 +243,15 @@ graph.prop_nodes(traversal_order)
##############################################################################
# .. note::
#
#    Before you call :meth:`~dgl.DGLGraph.prop_nodes`, specify a
#    ``message_func`` and ``reduce_func`` in advance. Here the example uses the
#    built-in copy-from-source function as the message function and the sum
#    function as the reduce function, for demonstration.
#
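# For instance, here is a hedged sketch of that pattern. The ``'a'`` feature
# name, its values, and passing the built-ins directly as keyword arguments
# are illustrative assumptions.

graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)
traversal_order = dgl.topological_nodes_generator(graph)
graph.prop_nodes(traversal_order,
                 message_func=fn.copy_src('a', 'a'),
                 reduce_func=fn.sum('a', 'a'))

##############################################################################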
# Putting it together
# -------------------
#
# Here is the complete code that specifies the ``Tree-LSTM`` class.
#
class TreeLSTM(nn.Module):
@@ -308,7 +310,7 @@ class TreeLSTM(nn.Module):
# Main loop
# ---------
#
# Finally, you could write a training paradigm in PyTorch.
#
from torch.utils.data import DataLoader
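# A minimal sketch of wiring the dataset into a DataLoader follows. This
# collate function, which simply merges the trees of a minibatch with
# ``dgl.batch``, is an illustrative assumption; the full example also carries
# masks, word IDs, and labels in an ``SSTBatch``.

def collate(trees):
    return dgl.batch(trees)

data_loader = DataLoader(dataset=tiny_sst, batch_size=5, collate_fn=collate)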
@@ -370,7 +372,6 @@ for epoch in range(epochs):
epoch, step, loss.item(), acc))
##############################################################################
# To train the model on a full dataset with different settings (such as CPU or GPU),
# refer to the `PyTorch example <https://github.com/dmlc/dgl/tree/master/examples/pytorch/tree_lstm>`__.
# There is also an implementation of the Child-Sum Tree-LSTM.