Commit 16cc5287 authored by VoVAllen's avatar VoVAllen Committed by Minjie Wang

[Doc] Improve Capsule with Jinyang & Fix wrong tutorial level layout (#236)

* improve capsule tutorial with jinyang

* fix wrong layout of second-level tutorial

* delete transformer
parent dafe4671
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
sys.path.insert(0, os.path.abspath('../../python'))

# -- Project information -----------------------------------------------------

project = 'DGL'
copyright = '2018, DGL Team'
author = 'DGL Team'

# The short X.Y version
version = '0.0.1'
# The full version, including alpha/beta/rc tags
release = '0.0.1'

# -- General configuration ---------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.autosummary',
    'sphinx.ext.coverage',
    'sphinx.ext.mathjax',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'sphinx.ext.intersphinx',
    'sphinx.ext.graphviz',
    'sphinx_gallery.gen_gallery',
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = ['.rst', '.md']

# The master toctree document.
master_doc = 'index'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = None

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}

# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'dgldoc'

# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',

    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',

    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',

    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, 'dgl.tex', 'DGL Documentation',
     'DGL Team', 'manual'),
]

# -- Options for manual page output ------------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
    (master_doc, 'dgl', 'DGL Documentation',
     [author], 1)
]

# -- Options for Texinfo output ----------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (master_doc, 'dgl', 'DGL Documentation',
     author, 'dgl', 'Library for deep learning on graphs.',
     'Miscellaneous'),
]

# -- Options for Epub output -------------------------------------------------

# Bibliographic Dublin Core info.
epub_title = project

# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
#
# epub_identifier = ''

# A unique identification for the text.
#
# epub_uid = ''

# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']

# -- Extension configuration -------------------------------------------------

autosummary_generate = True

intersphinx_mapping = {
    'python': ('https://docs.python.org/{.major}'.format(sys.version_info), None),
    'numpy': ('http://docs.scipy.org/doc/numpy/', None),
    'scipy': ('http://docs.scipy.org/doc/scipy/reference', None),
    'matplotlib': ('http://matplotlib.org/', None),
    'networkx': ('https://networkx.github.io/documentation/stable', None),
}

# sphinx gallery configurations
from sphinx_gallery.sorting import FileNameSortKey
examples_dirs = ['../../tutorials/basics', '../../tutorials/models']  # path to find sources
gallery_dirs = ['tutorials/basics', 'tutorials/models']  # path to generate docs
reference_url = {
    'dgl': None,
    'numpy': 'http://docs.scipy.org/doc/numpy/',
    'scipy': 'http://docs.scipy.org/doc/scipy/reference',
    'matplotlib': 'http://matplotlib.org/',
    'networkx': 'https://networkx.github.io/documentation/stable',
}
sphinx_gallery_conf = {
    'backreferences_dir': 'generated/backreferences',
    'doc_module': ('dgl', 'numpy'),
    'examples_dirs': examples_dirs,
    'gallery_dirs': gallery_dirs,
    'within_subsection_order': FileNameSortKey,
    'filename_pattern': '.py',
}
@@ -65,7 +65,8 @@ credit, see `here <https://www.dgl.ai/ack>`_.
   :caption: Tutorials
   :glob:

   tutorials/basics/index
   tutorials/models/index

.. toctree::
   :maxdepth: 2
...
""" """
.. currentmodule:: dgl .. currentmodule:: dgl
DGL at a Glance DGL at a Glance
========================= =========================
**Author**: `Minjie Wang <https://jermainewang.github.io/>`_, Quan Gan, `Jake **Author**: `Minjie Wang <https://jermainewang.github.io/>`_, Quan Gan, `Jake
Zhao <https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang Zhao <https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang
DGL is a Python package dedicated to deep learning on graphs, built atop DGL is a Python package dedicated to deep learning on graphs, built atop
existing tensor DL frameworks (e.g. Pytorch, MXNet) and simplifying the existing tensor DL frameworks (e.g. Pytorch, MXNet) and simplifying the
implementation of graph-based neural networks. implementation of graph-based neural networks.
The goal of this tutorial: The goal of this tutorial:
- Understand how DGL enables computation on graph from a high level. - Understand how DGL enables computation on graph from a high level.
- Train a simple graph neural network in DGL to classify nodes in a graph. - Train a simple graph neural network in DGL to classify nodes in a graph.
At the end of this tutorial, we hope you get a brief feeling of how DGL works. At the end of this tutorial, we hope you get a brief feeling of how DGL works.
*This tutorial assumes basic familiarity with pytorch.* *This tutorial assumes basic familiarity with pytorch.*
""" """
###############################################################################
# Step 0: Problem description
# ---------------------------
#
# We start with the well-known "Zachary's karate club" problem. The karate club
# is a social network of 34 members that documents the pairwise links between
# members who interact outside the club. The club later splits into two
# communities led by the instructor (node 0) and the club president (node 33).
# The network is visualized below, with colors indicating the community:
#
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/img/karate-club.png
#    :align: center
#
# The task is to predict which side (0 or 33) each member tends to join, given
# the social network itself.
###############################################################################
# Step 1: Creating a graph in DGL
# -------------------------------
# Creating the graph for Zachary's karate club goes as follows:

import dgl

def build_karate_club_graph():
    g = dgl.DGLGraph()
    # add 34 nodes into the graph; nodes are labeled from 0 to 33
    g.add_nodes(34)
    # all 78 edges as a list of tuples
    edge_list = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2),
                 (4, 0), (5, 0), (6, 0), (6, 4), (6, 5), (7, 0), (7, 1),
                 (7, 2), (7, 3), (8, 0), (8, 2), (9, 2), (10, 0), (10, 4),
                 (10, 5), (11, 0), (12, 0), (12, 3), (13, 0), (13, 1), (13, 2),
                 (13, 3), (16, 5), (16, 6), (17, 0), (17, 1), (19, 0), (19, 1),
                 (21, 0), (21, 1), (25, 23), (25, 24), (27, 2), (27, 23),
                 (27, 24), (28, 2), (29, 23), (29, 26), (30, 1), (30, 8),
                 (31, 0), (31, 24), (31, 25), (31, 28), (32, 2), (32, 8),
                 (32, 14), (32, 15), (32, 18), (32, 20), (32, 22), (32, 23),
                 (32, 29), (32, 30), (32, 31), (33, 8), (33, 9), (33, 13),
                 (33, 14), (33, 15), (33, 18), (33, 19), (33, 20), (33, 22),
                 (33, 23), (33, 26), (33, 27), (33, 28), (33, 29), (33, 30),
                 (33, 31), (33, 32)]
    # add edges using two lists of nodes: src and dst
    src, dst = tuple(zip(*edge_list))
    g.add_edges(src, dst)
    # edges are directional in DGL; make them bi-directional
    g.add_edges(dst, src)
    return g
###############################################################################
# We can print out the number of nodes and edges in our newly constructed graph:

G = build_karate_club_graph()
print('We have %d nodes.' % G.number_of_nodes())
print('We have %d edges.' % G.number_of_edges())
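# (With the construction above this should report 34 nodes and 156 directed
# edges, since each of the 78 undirected edges is stored in both directions.)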
###############################################################################
# We can also visualize the graph by converting it to a `networkx
# <https://networkx.github.io/documentation/stable/>`_ graph:

import networkx as nx
# Since the actual graph is undirected, we convert it to an undirected graph
# for visualization purposes.
nx_G = G.to_networkx().to_undirected()
# The Kamada-Kawai layout usually looks nice for arbitrary graphs
pos = nx.kamada_kawai_layout(nx_G)
nx.draw(nx_G, pos, with_labels=True, node_color=[[.7, .7, .7]])
###############################################################################
# Step 2: assign features to nodes or edges
# --------------------------------------------
# Graph neural networks associate features with nodes and edges for training.
# For our classification example, we assign each node a one-hot input feature
# vector: node :math:`v_i`'s feature vector is :math:`[0,\ldots,1,\ldots,0]`,
# where the :math:`i^{th}` position is one.
#
# In DGL, we can add features for all nodes at once, using a feature tensor that
# batches node features along the first dimension. The code below adds the
# one-hot features for all nodes:

import torch

G.ndata['feat'] = torch.eye(34)
###############################################################################
# We can print out the node features to verify:

# print out node 2's input feature
print(G.nodes[2].data['feat'])
# print out node 10 and 11's input features
print(G.nodes[[10, 11]].data['feat'])
###############################################################################
# Step 3: define a Graph Convolutional Network (GCN)
# --------------------------------------------------
# To perform node classification, we use the Graph Convolutional Network
# (GCN) developed by `Kipf and Welling <https://arxiv.org/abs/1609.02907>`_. Here
# we provide the simplest definition of a GCN framework; we recommend that the
# reader consult the original paper for more details.
#
# - At layer :math:`l`, each node :math:`v_i^l` carries a feature vector :math:`h_i^l`.
# - Each GCN layer aggregates the features of :math:`v_i`'s neighbors :math:`u_j`
#   into the next-layer representation :math:`h_i^{l+1}` at :math:`v_i`, followed
#   by an affine transformation and a non-linearity (sketched in the equation below).
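#
# Ignoring the normalization constant, one common way to write this update (a
# sketch we add here for concreteness, not the exact formula from the paper) is:
#
# .. math::
#
#    h_i^{l+1} = \sigma\left( W^{l} \sum_{j \in \mathcal{N}(i)} h_j^{l} \right)
#
# where :math:`\sigma` is a non-linearity such as ReLU and :math:`W^{l}` is the
# layer's weight matrix.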
#
# The above definition of GCN fits into a **message-passing** paradigm: each
# node updates its own feature with information sent from neighboring
# nodes. A graphical demonstration is displayed below.
#
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/1_first/mailbox.png
#    :alt: mailbox
#    :align: center
#
# Now, we show that the GCN layer can be easily implemented in DGL.

import torch.nn as nn
import torch.nn.functional as F

# Define the message and reduce functions
# NOTE: we ignore the GCN's normalization constant c_ij for this tutorial.
def gcn_message(edges):
    # The argument is a batch of edges.
    # This computes a (batch of) message called 'msg' using the source node's feature 'h'.
    return {'msg': edges.src['h']}

def gcn_reduce(nodes):
    # The argument is a batch of nodes.
    # This computes the new 'h' features by summing the received 'msg' in each node's mailbox.
    return {'h': torch.sum(nodes.mailbox['msg'], dim=1)}

# Define the GCNLayer module
class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, inputs):
        # g is the graph and inputs is the input node feature tensor
        # first set the node features
        g.ndata['h'] = inputs
        # trigger message passing on all edges
        g.send(g.edges(), gcn_message)
        # trigger aggregation at all nodes
        g.recv(g.nodes(), gcn_reduce)
        # get the resulting node features
        h = g.ndata.pop('h')
        # perform a linear transformation
        return self.linear(h)
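###############################################################################
# As a quick sanity check (our own addition, not part of the original
# tutorial), applying one untrained layer to the one-hot node features should
# produce an output of shape ``(34, 5)``:

layer = GCNLayer(34, 5)
print(layer(G, G.ndata['feat']).shape)  # expected: torch.Size([34, 5])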
###############################################################################
# In general, nodes send information computed via the *message functions*
# and aggregate incoming information with the *reduce functions*.
#
# We then define a deeper GCN model that contains two GCN layers:

# Define a 2-layer GCN model
class GCN(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes):
        super(GCN, self).__init__()
        self.gcn1 = GCNLayer(in_feats, hidden_size)
        self.gcn2 = GCNLayer(hidden_size, num_classes)

    def forward(self, g, inputs):
        h = self.gcn1(g, inputs)
        h = torch.relu(h)
        h = self.gcn2(g, h)
        return h

# The first layer transforms input features of size 34 to a hidden size of 5.
# The second layer transforms the hidden representation and produces output
# features of size 2, corresponding to the two groups of the karate club.
net = GCN(34, 5, 2)
###############################################################################
# Step 4: data preparation and initialization
# -------------------------------------------
#
# We use one-hot vectors to initialize the node features. Since this is a
# semi-supervised setting, only the instructor (node 0) and the club president
# (node 33) are assigned labels. The implementation is as follows.

inputs = torch.eye(34)
labeled_nodes = torch.tensor([0, 33])  # only the instructor and the president nodes are labeled
labels = torch.tensor([0, 1])  # their labels are different
###############################################################################
# Step 5: train then visualize
# ----------------------------
# The training loop is exactly the same as for other PyTorch models.
# We (1) create an optimizer, (2) feed the inputs to the model,
# (3) calculate the loss and (4) use autograd to optimize the model.

optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
all_logits = []
for epoch in range(30):
    logits = net(G, inputs)
    # we save the logits for visualization later
    all_logits.append(logits.detach())
    logp = F.log_softmax(logits, 1)
    # we only compute loss for the labeled nodes
    loss = F.nll_loss(logp[labeled_nodes], labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))
###############################################################################
# This is a rather toy example, so it does not even have a validation or test
# set. Instead, since the model produces an output feature of size 2 for each
# node, we can visualize it by plotting the output features in 2D space.
# The following code animates the training process from the initial guess
# (where the nodes are not classified correctly at all) to the end
# (where the nodes are linearly separable).

import matplotlib.animation as animation
import matplotlib.pyplot as plt

def draw(i):
    cls1color = '#00FFFF'
    cls2color = '#FF00FF'
    pos = {}
    colors = []
    for v in range(34):
        pos[v] = all_logits[i][v].numpy()
        cls = pos[v].argmax()
        colors.append(cls1color if cls else cls2color)
    ax.cla()
    ax.axis('off')
    ax.set_title('Epoch: %d' % i)
    nx.draw_networkx(nx_G.to_undirected(), pos, node_color=colors,
                     with_labels=True, node_size=300, ax=ax)

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()
draw(0)  # draw the prediction of the first epoch
plt.close()
###############################################################################
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/1_first/karate0.png
#    :height: 300px
#    :width: 400px
#    :align: center

###############################################################################
# The following animation shows how the model correctly predicts the community
# after a series of training epochs.

ani = animation.FuncAnimation(fig, draw, frames=len(all_logits), interval=200)
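###############################################################################
# If you want to save the animation to a file, something along the following
# lines should work (an aside from us, assuming matplotlib's Pillow writer is
# available in your environment):

ani.save('karate.gif', writer='pillow', fps=5)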
###############################################################################
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/1_first/karate.gif
#    :height: 300px
#    :width: 400px
#    :align: center

###############################################################################
# Next steps
# ----------
#
# In the :doc:`next tutorial <2_basics>`, we will go through some more basics
# of DGL, such as reading and writing node/edge features.
""" """
.. currentmodule:: dgl .. currentmodule:: dgl
DGL Basics DGL Basics
========== ==========
**Author**: `Minjie Wang <https://jermainewang.github.io/>`_, Quan Gan, Yu Gai, **Author**: `Minjie Wang <https://jermainewang.github.io/>`_, Quan Gan, Yu Gai,
Zheng Zhang Zheng Zhang
The Goal of this tutorial: The Goal of this tutorial:
* To create a graph. * To create a graph.
* To read and write node and edge representations. * To read and write node and edge representations.
""" """
###############################################################################
# Graph Creation
# --------------
# The design of :class:`DGLGraph` was influenced by other graph libraries. Indeed,
# you can create a graph from networkx and convert it into a :class:`DGLGraph`, and
# vice versa:

import networkx as nx
import dgl

g_nx = nx.petersen_graph()
g_dgl = dgl.DGLGraph(g_nx)

import matplotlib.pyplot as plt
plt.subplot(121)
nx.draw(g_nx, with_labels=True)
plt.subplot(122)
nx.draw(g_dgl.to_networkx(), with_labels=True)
plt.show()
###############################################################################
# They are the same graph, except that :class:`DGLGraph` is *always* directional.
#
# One can also create a graph by calling DGL's own interface.
#
# Now let's build a star graph. :class:`DGLGraph` nodes are a consecutive range of
# integers between 0 and :func:`number_of_nodes() <DGLGraph.number_of_nodes>`
# and can grow by calling :func:`add_nodes <DGLGraph.add_nodes>`.
# :class:`DGLGraph` edges are stored in the order of their addition. Note that
# edges are accessed in much the same way as nodes, with one extra feature:
# *edge broadcasting*.

import dgl
import torch as th

g = dgl.DGLGraph()
g.add_nodes(10)
# add a few edges one by one
for i in range(1, 4):
    g.add_edge(i, 0)
# a few more with a paired list
src = list(range(5, 8)); dst = [0]*3
g.add_edges(src, dst)
# finish with a pair of tensors
src = th.tensor([8, 9]); dst = th.tensor([0, 0])
g.add_edges(src, dst)

# edge broadcasting will build the star graph in one go!
g.clear(); g.add_nodes(10)
src = th.tensor(list(range(1, 10)))
g.add_edges(src, 0)

import networkx as nx
import matplotlib.pyplot as plt
nx.draw(g.to_networkx(), with_labels=True)
plt.show()
###############################################################################
# Feature Assignment
# ------------------
# One can also assign features to the nodes and edges of a :class:`DGLGraph`. The
# features are represented as a dictionary that maps names (strings) to tensors,
# called **fields**.
#
# The following code snippet assigns each node a vector (len=3).
#
# .. note::
#
#    DGL aims to be framework-agnostic, and currently it supports PyTorch and
#    MXNet tensors. From now on, we use PyTorch as an example.

import dgl
import torch as th

x = th.randn(10, 3)
g.ndata['x'] = x
###############################################################################
# :func:`ndata <DGLGraph.ndata>` is syntactic sugar for accessing the states of
# all nodes. States are stored in a container ``data`` that hosts a user-defined
# dictionary.

print(g.ndata['x'] == g.nodes[:].data['x'])

# access a node set with an integer, a list, or an integer tensor
g.nodes[0].data['x'] = th.zeros(1, 3)
g.nodes[[0, 1, 2]].data['x'] = th.zeros(3, 3)
g.nodes[th.tensor([0, 1, 2])].data['x'] = th.zeros(3, 3)
###############################################################################
# Assigning edge features works similarly to assigning node features, except
# that one can also do it by specifying the endpoints of the edges.

g.edata['w'] = th.randn(9, 2)

# access an edge set with IDs given as an integer, a list, or an integer tensor
g.edges[1].data['w'] = th.randn(1, 2)
g.edges[[0, 1, 2]].data['w'] = th.zeros(3, 2)
g.edges[th.tensor([0, 1, 2])].data['w'] = th.zeros(3, 2)

# one can also access the edges by giving their endpoints
g.edges[1, 0].data['w'] = th.ones(1, 2)                  # edge 1 -> 0
g.edges[[1, 2, 3], [0, 0, 0]].data['w'] = th.ones(3, 2)  # edges [1, 2, 3] -> 0
###############################################################################
# After assignments, each node/edge field will be associated with a scheme
# containing the shape and data type (dtype) of its field value.

print(g.node_attr_schemes())
g.ndata['x'] = th.zeros((10, 4))
print(g.node_attr_schemes())
###############################################################################
# One can also remove node/edge states from the graph. This is particularly
# useful to save memory during inference.

g.ndata.pop('x')
g.edata.pop('w')
###############################################################################
# Multigraphs
# ~~~~~~~~~~~
# Many graph applications need multi-edges. To enable this, construct :class:`DGLGraph`
# with ``multigraph=True``.

g_multi = dgl.DGLGraph(multigraph=True)
g_multi.add_nodes(10)
g_multi.ndata['x'] = th.randn(10, 2)

g_multi.add_edges(list(range(1, 10)), 0)
g_multi.add_edge(1, 0)  # two edges on 1->0

g_multi.edata['w'] = th.randn(10, 2)
g_multi.edges[1].data['w'] = th.zeros(1, 2)
print(g_multi.edges())
###############################################################################
# An edge in a multigraph cannot be uniquely identified by its incident nodes
# :math:`u` and :math:`v`; query its edge IDs with the ``edge_id`` interface.

eid_10 = g_multi.edge_id(1, 0)
g_multi.edges[eid_10].data['w'] = th.ones(len(eid_10), 2)
print(g_multi.edata['w'])
###############################################################################
# .. note::
#
#    * Nodes and edges can be added but not removed; we will support removal in
#      the future.
#    * Updating a feature with a different scheme on an individual node (or a
#      node subset) raises an error, as the snippet below illustrates.
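#
# A minimal illustration of the scheme rule above (our own sketch; the exact
# exception type may vary across DGL versions, so we catch broadly):

g.ndata['x'] = th.zeros((10, 4))               # every node now has scheme (4,)
try:
    g.nodes[[0]].data['x'] = th.zeros(1, 5)    # wrong shape for a node subset
except Exception as e:
    print('Scheme mismatch raised:', type(e).__name__)
g.ndata.pop('x')                               # clean up again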
###############################################################################
# Next steps
# ----------
# In the :doc:`next tutorial <3_pagerank>`, we will go through the
# DGL message passing interface by implementing PageRank.
""" """
.. currentmodule:: dgl .. currentmodule:: dgl
PageRank with DGL Message Passing PageRank with DGL Message Passing
================================= =================================
**Author**: `Minjie Wang <https://jermainewang.github.io/>`_, Quan Gan, Yu Gai, **Author**: `Minjie Wang <https://jermainewang.github.io/>`_, Quan Gan, Yu Gai,
Zheng Zhang Zheng Zhang
In this section we illustrate the usage of different levels of message In this section we illustrate the usage of different levels of message
passing API with PageRank on a small graph. In DGL, the message passing and passing API with PageRank on a small graph. In DGL, the message passing and
feature transformations are all **User-Defined Functions** (UDFs). feature transformations are all **User-Defined Functions** (UDFs).
The goal of this tutorial: to implement PageRank using DGL message passing The goal of this tutorial: to implement PageRank using DGL message passing
interface. interface.
""" """
###############################################################################
# The PageRank Algorithm
# ----------------------
# In each iteration of PageRank, every node (web page) first scatters its
# PageRank value uniformly to its downstream nodes. The new PageRank value of
# each node is computed by aggregating the PageRank values received from its
# neighbors, which is then adjusted by the damping factor:
#
# .. math::
#
#    PV(u) = \frac{1-d}{N} + d \times \sum_{v \in \mathcal{N}(u)}
#    \frac{PV(v)}{D(v)}
#
# where :math:`N` is the number of nodes in the graph; :math:`D(v)` is the
# out-degree of a node :math:`v`; and :math:`\mathcal{N}(u)` is the set of
# nodes that have an edge pointing to :math:`u`.
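#
# To make the formula concrete, here is a small dense NumPy reference (our own
# illustrative sketch, not part of DGL), with ``A[i, j] = 1`` meaning an edge
# from node ``j`` to node ``i``:

import numpy as np

def pagerank_dense(A, d=0.85, num_iters=10):
    n = A.shape[0]
    out_deg = A.sum(axis=0)            # out-degree of every node (column sums)
    M = A / np.maximum(out_deg, 1)     # column-normalized adjacency
    pv = np.ones(n) / n                # uniform initial PageRank
    for _ in range(num_iters):
        pv = (1 - d) / n + d * M @ pv  # the update above, vectorized
    return pv

# tiny usage example: a 3-node cycle 0 -> 1 -> 2 -> 0
A = np.array([[0, 0, 1],
              [1, 0, 0],
              [0, 1, 0]], dtype=float)
print(pagerank_dense(A))  # stays at [1/3, 1/3, 1/3] by symmetry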
###############################################################################
# A naive implementation
# ----------------------
# Let us first create a graph with 100 nodes using NetworkX and convert it to a
# :class:`DGLGraph`:

import networkx as nx
import matplotlib.pyplot as plt
import torch
import dgl

N = 100      # number of nodes
DAMP = 0.85  # damping factor
K = 10       # number of iterations
g = nx.erdos_renyi_graph(N, 0.1)
g = dgl.DGLGraph(g)
nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])
plt.show()
###############################################################################
# According to the algorithm, PageRank consists of two phases in a typical
# scatter-gather pattern. We first initialize the PageRank value of each node
# to :math:`\frac{1}{N}` and store each node's out-degree as a node feature:

g.ndata['pv'] = torch.ones(N) / N
g.ndata['deg'] = g.out_degrees(g.nodes()).float()
###############################################################################
# We then define the message function, which divides every node's PageRank
# value by its out-degree and passes the result as messages to its neighbors:

def pagerank_message_func(edges):
    return {'pv' : edges.src['pv'] / edges.src['deg']}

###############################################################################
# In DGL, the message functions are expressed as **Edge UDFs**. An edge UDF
# takes a single argument ``edges``, which has three members ``src``, ``dst``,
# and ``data`` for accessing source node features, destination node features,
# and edge features, respectively. Here, the function computes messages only
# from the source node features.
#
# Next, we define the reduce function, which removes and aggregates the
# messages in each node's ``mailbox`` and computes the node's new PageRank value:

def pagerank_reduce_func(nodes):
    msgs = torch.sum(nodes.mailbox['pv'], dim=1)
    pv = (1 - DAMP) / N + DAMP * msgs
    return {'pv' : pv}
###############################################################################
# The reduce functions are **Node UDFs**. Node UDFs have a single argument
# ``nodes``, which has two members ``data`` and ``mailbox``. ``data``
# contains the node features while ``mailbox`` contains all incoming message
# features, stacked along the second dimension (hence the ``dim=1`` argument).
#
# The message UDF works on a batch of edges, whereas the reduce UDF works on a
# batch of nodes, consuming the messages sent along their incoming edges. Their
# relationship is illustrated below:
#
# .. image:: https://i.imgur.com/kIMiuFb.png
#
# We register the message function and reduce function, which will be called
# later by DGL.

g.register_message_func(pagerank_message_func)
g.register_reduce_func(pagerank_reduce_func)
###############################################################################
# The algorithm is then very straightforward. Here is the code for one
# PageRank iteration:

def pagerank_naive(g):
    # Phase #1: send out messages along all edges.
    for u, v in zip(*g.edges()):
        g.send((u, v))
    # Phase #2: receive messages to compute new PageRank values.
    for v in g.nodes():
        g.recv(v)
###############################################################################
# Improvement with batching semantics
# -----------------------------------
# The above code does not scale to large graphs because it iterates over the
# nodes one at a time. DGL solves this by letting users compute on a *batch* of
# nodes or edges. For example, the following code triggers the message and
# reduce functions on many nodes and edges at once.

def pagerank_batch(g):
    g.send(g.edges())
    g.recv(g.nodes())
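###############################################################################
# As a usage sketch of our own (not part of the original excerpt), a full
# PageRank run with the batched version is just ``K`` such steps:

for k in range(K):
    pagerank_batch(g)
print(g.ndata['pv'][:5])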
############################################################################### ###############################################################################
# Note that we are still using the same reduce function ``pagerank_reduce_func``, # Note that we are still using the same reduce function ``pagerank_reduce_func``,
# where ``nodes.mailbox['pv']`` is a *single* tensor, stacking the incoming # where ``nodes.mailbox['pv']`` is a *single* tensor, stacking the incoming
# messages along the second dimension. # messages along the second dimension.
# #
# Naturally, one will wonder if this is even possible to perform reduce on all # Naturally, one will wonder if this is even possible to perform reduce on all
# nodes in parallel, since each node may have different number of incoming # nodes in parallel, since each node may have different number of incoming
# messages and one cannot really "stack" tensors of different lengths together. # messages and one cannot really "stack" tensors of different lengths together.
# In general, DGL solves the problem by grouping the nodes by the number of # In general, DGL solves the problem by grouping the nodes by the number of
# incoming messages, and calling the reduce function for each group. # incoming messages, and calling the reduce function for each group.
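#
# As a minimal illustration of this grouping (a debugging sketch, not a step of
# the PageRank algorithm; it assumes the message field is named ``pv`` as above),
# the reduce UDF below only prints the shape of the stacked mailbox tensor.
# On a graph whose nodes have several distinct in-degrees, DGL may call it once
# per in-degree group, each time with a different size along the second dimension
# (e.g. by passing it as ``reduce_func`` to ``update_all`` in place of the
# registered one).

def inspect_reduce_func(nodes):
    # nodes.mailbox['pv'] has shape (number_of_nodes_in_this_group, in_degree)
    print('mailbox shape:', nodes.mailbox['pv'].shape)
    return {'pv_sum': nodes.mailbox['pv'].sum(1)}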
############################################################################### ###############################################################################
# More improvement with higher level APIs # More improvement with higher level APIs
# --------------------------------------- # ---------------------------------------
# DGL provides many routines that combine basic ``send`` and ``recv`` in
# various ways. They are called **level-2 APIs**. For example, the PageRank # various ways. They are called **level-2 APIs**. For example, the PageRank
# example can be further simplified as follows: # example can be further simplified as follows:
def pagerank_level2(g): def pagerank_level2(g):
g.update_all() g.update_all()
############################################################################### ###############################################################################
# Besides ``update_all``, we also have ``pull``, ``push``, and ``send_and_recv`` # Besides ``update_all``, we also have ``pull``, ``push``, and ``send_and_recv``
# in this level-2 category. Please refer to the :doc:`API reference <../api/python/graph>` # in this level-2 category. Please refer to the :doc:`API reference <../../api/python/graph>`
# for more details. # for more details.
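# As a rough sketch of how the other level-2 routines read (reusing the UDFs
# registered earlier; the function names below are only illustrative), the same
# iteration can also be phrased as every node pulling from its in-edges, or as a
# fused send-and-receive over all edges:
def pagerank_pull(g):
    g.pull(g.nodes())            # each node gathers and reduces its incoming messages
def pagerank_send_and_recv(g):
    g.send_and_recv(g.edges())   # send along all edges, then receive at the destinations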
############################################################################### ###############################################################################
# Even more improvement with DGL builtin functions # Even more improvement with DGL builtin functions
# ------------------------------------------------ # ------------------------------------------------
# As some of the message and reduce functions are very commonly used, DGL also # As some of the message and reduce functions are very commonly used, DGL also
# provides **builtin functions**. For example, two builtin functions can be # provides **builtin functions**. For example, two builtin functions can be
# used in the PageRank example. # used in the PageRank example.
# #
# * :func:`dgl.function.copy_src(src, out) <function.copy_src>` # * :func:`dgl.function.copy_src(src, out) <function.copy_src>`
# is an edge UDF that computes the # is an edge UDF that computes the
# output using the source node feature data. The user needs to specify the name of
# the source feature data (``src``) and the output name (``out``). # the source feature data (``src``) and the output name (``out``).
# #
# * :func:`dgl.function.sum(msg, out) <function.sum>` is a node UDF # * :func:`dgl.function.sum(msg, out) <function.sum>` is a node UDF
# that sums the messages in # that sums the messages in
# the node's mailbox. The user needs to specify the message name (``msg``) and the
# output name (``out``). # output name (``out``).
# #
# For example, the PageRank example can be rewritten as follows:
import dgl.function as fn import dgl.function as fn
def pagerank_builtin(g): def pagerank_builtin(g):
g.ndata['pv'] = g.ndata['pv'] / g.ndata['deg'] g.ndata['pv'] = g.ndata['pv'] / g.ndata['deg']
g.update_all(message_func=fn.copy_src(src='pv', out='m'), g.update_all(message_func=fn.copy_src(src='pv', out='m'),
reduce_func=fn.sum(msg='m',out='m_sum')) reduce_func=fn.sum(msg='m',out='m_sum'))
g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['m_sum'] g.ndata['pv'] = (1 - DAMP) / N + DAMP * g.ndata['m_sum']
############################################################################### ###############################################################################
# Here, we directly provide the UDFs to the :func:`update_all <DGLGraph.update_all>` # Here, we directly provide the UDFs to the :func:`update_all <DGLGraph.update_all>`
# as its arguments. # as its arguments.
# This will override the previously registered UDFs. # This will override the previously registered UDFs.
# #
# In addition to cleaner code, using builtin functions also gives DGL the # In addition to cleaner code, using builtin functions also gives DGL the
# opportunity to fuse operations together, resulting in faster execution. For # opportunity to fuse operations together, resulting in faster execution. For
# example, DGL will fuse the ``copy_src`` message function and ``sum`` reduce # example, DGL will fuse the ``copy_src`` message function and ``sum`` reduce
# function into one sparse matrix-vector (spMV) multiplication. # function into one sparse matrix-vector (spMV) multiplication.
# #
# `This section <spmv_>`_ describes why spMV can speed up the scatter-gather # `This section <spmv_>`_ describes why spMV can speed up the scatter-gather
# phase in PageRank. For more details about the builtin functions in DGL, # phase in PageRank. For more details about the builtin functions in DGL,
# please read the :doc:`API reference <../api/python/function>`. # please read the :doc:`API reference <../../api/python/function>`.
# #
# You can also download and run the codes to feel the difference. # You can also download and run the codes to feel the difference.
for k in range(K): for k in range(K):
# Uncomment the corresponding line to select different version. # Uncomment the corresponding line to select different version.
# pagerank_naive(g) # pagerank_naive(g)
# pagerank_batch(g) # pagerank_batch(g)
# pagerank_level2(g) # pagerank_level2(g)
pagerank_builtin(g) pagerank_builtin(g)
print(g.ndata['pv']) print(g.ndata['pv'])
############################################################################### ###############################################################################
# .. _spmv: # .. _spmv:
# #
# Using spMV for PageRank # Using spMV for PageRank
# ----------------------- # -----------------------
# Using builtin functions allows DGL to understand the semantics of UDFs and # Using builtin functions allows DGL to understand the semantics of UDFs and
# thus enables a more efficient implementation for you. For example, in the case
# of PageRank, one common trick to accelerate it is to use its linear algebra
# form.
# #
# .. math:: # .. math::
# #
# \mathbf{R}^{k} = \frac{1-d}{N} \mathbf{1} + d \mathbf{A}*\mathbf{R}^{k-1} # \mathbf{R}^{k} = \frac{1-d}{N} \mathbf{1} + d \mathbf{A}*\mathbf{R}^{k-1}
# #
# Here, :math:`\mathbf{R}^k` is the vector of the PageRank values of all nodes # Here, :math:`\mathbf{R}^k` is the vector of the PageRank values of all nodes
# at iteration :math:`k`; :math:`\mathbf{A}` is the sparse adjacency matrix # at iteration :math:`k`; :math:`\mathbf{A}` is the sparse adjacency matrix
# of the graph. # of the graph.
# Computing this equation is quite efficient because there exist efficient
# GPU kernels for *sparse-matrix-vector multiplication* (spMV). DGL
# detects whether such an optimization is available through the builtin
# functions. If a certain combination of builtins can be mapped to an spMV
# kernel (e.g. the PageRank example), DGL will use it automatically. As a
# result, *we recommend using builtin functions whenever it is possible*. # result, *we recommend using builtin functions whenever it is possible*.
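#
# To make the linear-algebra view concrete, here is a small self-contained
# sketch (independent of the DGL graph used above, with its own node count and
# damping factor) that performs one PageRank step as an explicit sparse
# matrix-vector product; the column-normalized matrix ``A_hat`` below plays the
# role of :math:`\mathbf{A}` in the formula above.
import networkx as nx
import numpy as np

toy_g = nx.erdos_renyi_graph(100, 0.1, seed=0, directed=True)
A = nx.to_scipy_sparse_matrix(toy_g, format='csr').T.astype(np.float64)
out_deg = np.asarray(A.sum(axis=0)).clip(min=1)   # out-degrees; clip guards dangling nodes
A_hat = A.multiply(1.0 / out_deg).tocsr()         # column-normalize by out-degree
r = np.full(100, 1.0 / 100)                       # uniform initial PageRank vector
r = (1 - 0.85) / 100 + 0.85 * A_hat.dot(r)        # one spMV-based PageRank step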
############################################################################### ###############################################################################
# Next steps # Next steps
# ---------- # ----------
# Check out :doc:`GCN <models/1_gcn>` and :doc:`Capsule <models/2_capsule>` # Check out :doc:`GCN <../models/1_gnn/1_gcn>` and :doc:`Capsule <../models/4_old_wines/2_capsule>`
# for more model implementations in DGL.
Basic Tutorials Basic Tutorials
=============== ===============
These tutorials conver the basics of DGL. These tutorials cover the basics of DGL.
""" """
.. _model-gcn: .. _model-gcn:
Graph Convolutional Network Graph Convolutional Network
==================================== ====================================
**Author:** `Qi Huang <https://github.com/HQ01>`_, `Minjie Wang <https://jermainewang.github.io/>`_, **Author:** `Qi Huang <https://github.com/HQ01>`_, `Minjie Wang <https://jermainewang.github.io/>`_,
Yu Gai, Quan Gan, Zheng Zhang Yu Gai, Quan Gan, Zheng Zhang
This is a gentle introduction to using DGL to implement Graph Convolutional
Networks (Kipf & Welling et al., `Semi-Supervised Classification with Graph
Convolutional Networks <https://arxiv.org/pdf/1609.02907.pdf>`_). We build upon
the :doc:`earlier tutorial <../3_pagerank>` on DGLGraph and demonstrate the :doc:`earlier tutorial <../../basics/3_pagerank>` on DGLGraph and demonstrate
how DGL combines graphs with deep neural networks to learn structural representations.
""" """
############################################################################### ###############################################################################
# Model Overview # Model Overview
# ------------------------------------------ # ------------------------------------------
# GCN from the perspective of message passing # GCN from the perspective of message passing
# ``````````````````````````````````````````````` # ```````````````````````````````````````````````
# We describe a layer of graph convolutional neural network from a message # We describe a layer of graph convolutional neural network from a message
# passing perspective; the math can be found `here <math_>`_. # passing perspective; the math can be found `here <math_>`_.
# It boils down to the following step, for each node :math:`u`: # It boils down to the following step, for each node :math:`u`:
# #
# 1) Aggregate neighbors' representations :math:`h_{v}` to produce an # 1) Aggregate neighbors' representations :math:`h_{v}` to produce an
# intermediate representation :math:`\hat{h}_u`. 2) Transform the aggregated # intermediate representation :math:`\hat{h}_u`. 2) Transform the aggregated
# representation :math:`\hat{h}_{u}` with a linear projection followed by a # representation :math:`\hat{h}_{u}` with a linear projection followed by a
# non-linearity: :math:`h_{u} = f(W_{u} \hat{h}_u)`. # non-linearity: :math:`h_{u} = f(W_{u} \hat{h}_u)`.
# #
# We will implement step 1 with DGL message passing, and step 2 with the # We will implement step 1 with DGL message passing, and step 2 with the
# ``apply_nodes`` method, whose node UDF will be a PyTorch ``nn.Module``. # ``apply_nodes`` method, whose node UDF will be a PyTorch ``nn.Module``.
# #
# GCN implementation with DGL # GCN implementation with DGL
# `````````````````````````````````````````` # ``````````````````````````````````````````
# We first define the message and reduce function as usual. Since the # We first define the message and reduce function as usual. Since the
# aggregation on a node :math:`u` only involves summing over the neighbors' # aggregation on a node :math:`u` only involves summing over the neighbors'
# representations :math:`h_v`, we can simply use builtin functions: # representations :math:`h_v`, we can simply use builtin functions:
import dgl import dgl
import dgl.function as fn import dgl.function as fn
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from dgl import DGLGraph from dgl import DGLGraph
gcn_msg = fn.copy_src(src='h', out='m') gcn_msg = fn.copy_src(src='h', out='m')
gcn_reduce = fn.sum(msg='m', out='h') gcn_reduce = fn.sum(msg='m', out='h')
############################################################################### ###############################################################################
# We then define the node UDF for ``apply_nodes``, which is a fully-connected layer: # We then define the node UDF for ``apply_nodes``, which is a fully-connected layer:
class NodeApplyModule(nn.Module): class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation): def __init__(self, in_feats, out_feats, activation):
super(NodeApplyModule, self).__init__() super(NodeApplyModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats) self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation self.activation = activation
def forward(self, node): def forward(self, node):
h = self.linear(node.data['h']) h = self.linear(node.data['h'])
h = self.activation(h) h = self.activation(h)
return {'h' : h} return {'h' : h}
############################################################################### ###############################################################################
# We then proceed to define the GCN module. A GCN layer essentially performs # We then proceed to define the GCN module. A GCN layer essentially performs
# message passing on all the nodes and then applies the ``NodeApplyModule``. Note
# that we omit the dropout used in the paper for simplicity.
class GCN(nn.Module): class GCN(nn.Module):
def __init__(self, in_feats, out_feats, activation): def __init__(self, in_feats, out_feats, activation):
super(GCN, self).__init__() super(GCN, self).__init__()
self.apply_mod = NodeApplyModule(in_feats, out_feats, activation) self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)
def forward(self, g, feature): def forward(self, g, feature):
g.ndata['h'] = feature g.ndata['h'] = feature
g.update_all(gcn_msg, gcn_reduce) g.update_all(gcn_msg, gcn_reduce)
g.apply_nodes(func=self.apply_mod) g.apply_nodes(func=self.apply_mod)
return g.ndata.pop('h') return g.ndata.pop('h')
############################################################################### ###############################################################################
# The forward function is essentially the same as in any other commonly seen NN
# model in PyTorch. We can initialize GCN like any ``nn.Module``. For example,
# let's define a simple neural network consisting of two GCN layers. Suppose we # let's define a simple neural network consisting of two GCN layers. Suppose we
# are training the classifier for the cora dataset (the input feature size is # are training the classifier for the cora dataset (the input feature size is
# 1433 and the number of classes is 7). # 1433 and the number of classes is 7).
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
self.gcn1 = GCN(1433, 16, F.relu) self.gcn1 = GCN(1433, 16, F.relu)
self.gcn2 = GCN(16, 7, F.relu) self.gcn2 = GCN(16, 7, F.relu)
def forward(self, g, features): def forward(self, g, features):
x = self.gcn1(g, features) x = self.gcn1(g, features)
x = self.gcn2(g, x) x = self.gcn2(g, x)
return x return x
net = Net() net = Net()
print(net) print(net)
############################################################################### ###############################################################################
# We load the cora dataset using DGL's built-in data module. # We load the cora dataset using DGL's built-in data module.
from dgl.data import citation_graph as citegrh from dgl.data import citation_graph as citegrh
def load_cora_data(): def load_cora_data():
data = citegrh.load_cora() data = citegrh.load_cora()
features = th.FloatTensor(data.features) features = th.FloatTensor(data.features)
labels = th.LongTensor(data.labels) labels = th.LongTensor(data.labels)
mask = th.ByteTensor(data.train_mask) mask = th.ByteTensor(data.train_mask)
g = DGLGraph(data.graph) g = DGLGraph(data.graph)
return g, features, labels, mask return g, features, labels, mask
############################################################################### ###############################################################################
# We then train the network as follows: # We then train the network as follows:
import time import time
import numpy as np import numpy as np
g, features, labels, mask = load_cora_data() g, features, labels, mask = load_cora_data()
optimizer = th.optim.Adam(net.parameters(), lr=1e-3) optimizer = th.optim.Adam(net.parameters(), lr=1e-3)
dur = [] dur = []
for epoch in range(30): for epoch in range(30):
if epoch >=3: if epoch >=3:
t0 = time.time() t0 = time.time()
logits = net(g, features) logits = net(g, features)
logp = F.log_softmax(logits, 1) logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp[mask], labels[mask]) loss = F.nll_loss(logp[mask], labels[mask])
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
if epoch >=3: if epoch >=3:
dur.append(time.time() - t0) dur.append(time.time() - t0)
print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format( print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
epoch, loss.item(), np.mean(dur))) epoch, loss.item(), np.mean(dur)))
############################################################################### ###############################################################################
# .. _math: # .. _math:
# #
# GCN in one formula # GCN in one formula
# ------------------ # ------------------
# Mathematically, the GCN model follows this formula: # Mathematically, the GCN model follows this formula:
# #
# :math:`H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)})` # :math:`H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)})`
# #
# Here, :math:`H^{(l)}` denotes the :math:`l^{th}` layer in the network, # Here, :math:`H^{(l)}` denotes the :math:`l^{th}` layer in the network,
# :math:`\sigma` is the non-linearity, and :math:`W` is the weight matrix for # :math:`\sigma` is the non-linearity, and :math:`W` is the weight matrix for
# this layer. :math:`D` and :math:`A`, as commonly seen, represent the degree
# matrix and the adjacency matrix, respectively. The tilde denotes a renormalization trick
# in which we add a self-connection to each node of the graph and build the
# corresponding degree and adjacency matrices. The shape of the input
# :math:`H^{(0)}` is :math:`N \times D`, where :math:`N` is the number of nodes # :math:`H^{(0)}` is :math:`N \times D`, where :math:`N` is the number of nodes
# and :math:`D` is the number of input features. We can chain up multiple # and :math:`D` is the number of input features. We can chain up multiple
# layers as such to produce a node-level representation output with shape # layers as such to produce a node-level representation output with shape
# :math:`N \times F`, where :math:`F` is the dimension of the output node
# feature vector. # feature vector.
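#
# As a minimal sketch of the renormalization trick on a toy graph (dense
# tensors purely for readability; real implementations keep the matrices
# sparse, and the names below are only illustrative):
import torch as th

A = th.tensor([[0., 1., 0.],
               [1., 0., 1.],
               [0., 1., 0.]])                 # adjacency matrix of a 3-node path graph
A_tilde = A + th.eye(3)                       # add self-connections
D_inv_sqrt = th.diag(A_tilde.sum(1).pow(-0.5))
A_norm = D_inv_sqrt @ A_tilde @ D_inv_sqrt    # \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}
H0 = th.randn(3, 4)                           # toy input features H^{(0)}
W0 = th.randn(4, 2)                           # toy layer weight W^{(0)}
H1 = th.relu(A_norm @ H0 @ W0)                # one GCN layer in matrix form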
# #
# The equation can be efficiently implemented using sparse matrix # The equation can be efficiently implemented using sparse matrix
# multiplication kernels (such as Kipf's # multiplication kernels (such as Kipf's
# `pygcn <https://github.com/tkipf/pygcn>`_ code). The above DGL implementation # `pygcn <https://github.com/tkipf/pygcn>`_ code). The above DGL implementation
# in fact has already used this trick due to the use of builtin functions. To # in fact has already used this trick due to the use of builtin functions. To
# understand what is under the hood, please read our tutorial on :doc:`PageRank <../3_pagerank>`. # understand what is under the hood, please read our tutorial on :doc:`PageRank <../../basics/3_pagerank>`.
""" """
.. _model-rgcn: .. _model-rgcn:
Relational Graph Convolutional Network Tutorial Relational Graph Convolutional Network Tutorial
================================================ ================================================
**Author:** Lingfan Yu, Mufei Li, Zheng Zhang **Author:** Lingfan Yu, Mufei Li, Zheng Zhang
The vanilla Graph Convolutional Network (GCN) The vanilla Graph Convolutional Network (GCN)
(`paper <https://arxiv.org/pdf/1609.02907.pdf>`_, (`paper <https://arxiv.org/pdf/1609.02907.pdf>`_,
`DGL tutorial <http://doc.dgl.ai/tutorials/index.html>`_) exploits `DGL tutorial <http://doc.dgl.ai/tutorials/index.html>`_) exploits
structural information of the dataset (i.e. the graph connectivity) to structural information of the dataset (i.e. the graph connectivity) to
improve the extraction of node representations. Graph edges are left
untyped.
A knowledge graph is made up of a collection of triples of the form
(subject, relation, object). Edges thus encode important information and (subject, relation, object). Edges thus encode important information and
have their own embeddings to be learned. Furthermore, there may exist have their own embeddings to be learned. Furthermore, there may exist
multiple edges between any given pair of entities.
A recent model Relational-GCN (R-GCN) from the paper A recent model Relational-GCN (R-GCN) from the paper
`Modeling Relational Data with Graph Convolutional `Modeling Relational Data with Graph Convolutional
Networks <https://arxiv.org/pdf/1703.06103.pdf>`_ is one effort to Networks <https://arxiv.org/pdf/1703.06103.pdf>`_ is one effort to
generalize GCN to handle different relations between entities in knowledge generalize GCN to handle different relations between entities in knowledge
base. This tutorial shows how to implement R-GCN with DGL. base. This tutorial shows how to implement R-GCN with DGL.
""" """
############################################################################### ###############################################################################
# R-GCN: a brief introduction # R-GCN: a brief introduction
# --------------------------- # ---------------------------
# In *statistical relational learning* (SRL), there are two fundamental # In *statistical relational learning* (SRL), there are two fundamental
# tasks: # tasks:
# #
# - **Entity classification**, i.e., assign types and categorical # - **Entity classification**, i.e., assign types and categorical
# properties to entities. # properties to entities.
# - **Link prediction**, i.e., recover missing triples. # - **Link prediction**, i.e., recover missing triples.
# #
# In both cases, missing information is expected to be recovered from the
# neighborhood structure of the graph. Here is the example from the R-GCN
# paper: # paper:
# #
# "Knowing that Mikhail Baryshnikov was educated at the Vaganova Academy # "Knowing that Mikhail Baryshnikov was educated at the Vaganova Academy
# implies both that Mikhail Baryshnikov should have the label person, and # implies both that Mikhail Baryshnikov should have the label person, and
# that the triple (Mikhail Baryshnikov, lived in, Russia) must belong to the # that the triple (Mikhail Baryshnikov, lived in, Russia) must belong to the
# knowledge graph." # knowledge graph."
# #
# R-GCN solves these two problems using a common graph convolutional network # R-GCN solves these two problems using a common graph convolutional network
# extended with multi-edge encoding to compute embeddings of the entities, but
# with different downstream processing: # with different downstream processing:
# #
# - Entity classification is done by attaching a softmax classifier at the # - Entity classification is done by attaching a softmax classifier at the
# final embedding of an entity (node). Training uses a standard
# cross-entropy loss.
# - Link prediction is done by reconstructing an edge with an autoencoder # - Link prediction is done by reconstructing an edge with an autoencoder
# architecture, using a parameterized score function. Training uses negative # architecture, using a parameterized score function. Training uses negative
# sampling. # sampling.
# #
# This tutorial will focus on the first task to show how to generate entity # This tutorial will focus on the first task to show how to generate entity
# representation. `Complete # representation. `Complete
# code <https://github.com/jermainewang/dgl/tree/rgcn/examples/pytorch/rgcn>`_ # code <https://github.com/jermainewang/dgl/tree/rgcn/examples/pytorch/rgcn>`_
# for both tasks can be found in DGL's github repository. # for both tasks can be found in DGL's github repository.
# #
# Key ideas of R-GCN # Key ideas of R-GCN
# ------------------- # -------------------
# Recall that in GCN, the hidden representation for each node :math:`i` at # Recall that in GCN, the hidden representation for each node :math:`i` at
# :math:`(l+1)^{th}` layer is computed by: # :math:`(l+1)^{th}` layer is computed by:
# #
# .. math:: h_i^{l+1} = \sigma\left(\sum_{j\in N_i}\frac{1}{c_i} W^{(l)} h_j^{(l)}\right)~~~~~~~~~~(1)\\ # .. math:: h_i^{l+1} = \sigma\left(\sum_{j\in N_i}\frac{1}{c_i} W^{(l)} h_j^{(l)}\right)~~~~~~~~~~(1)\\
# #
# where :math:`c_i` is a normalization constant. # where :math:`c_i` is a normalization constant.
# #
# The key difference between R-GCN and GCN is that in R-GCN, edges can # The key difference between R-GCN and GCN is that in R-GCN, edges can
# represent different relations. In GCN, weight :math:`W^{(l)}` in equation # represent different relations. In GCN, weight :math:`W^{(l)}` in equation
# :math:`(1)` is shared by all edges in layer :math:`l`. In contrast, in # :math:`(1)` is shared by all edges in layer :math:`l`. In contrast, in
# R-GCN, different edge types use different weights and only edges of the # R-GCN, different edge types use different weights and only edges of the
# same relation type :math:`r` are associated with the same projection weight # same relation type :math:`r` are associated with the same projection weight
# :math:`W_r^{(l)}`. # :math:`W_r^{(l)}`.
# #
# So the hidden representation of entities in :math:`(l+1)^{th}` layer in # So the hidden representation of entities in :math:`(l+1)^{th}` layer in
# R-GCN can be formulated as the following equation: # R-GCN can be formulated as the following equation:
# #
# .. math:: h_i^{l+1} = \sigma\left(W_0^{(l)}h_i^{(l)}+\sum_{r\in R}\sum_{j\in N_i^r}\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}\right)~~~~~~~~~~(2)\\ # .. math:: h_i^{l+1} = \sigma\left(W_0^{(l)}h_i^{(l)}+\sum_{r\in R}\sum_{j\in N_i^r}\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}\right)~~~~~~~~~~(2)\\
# #
# where :math:`N_i^r` denotes the set of neighbor indices of node :math:`i` # where :math:`N_i^r` denotes the set of neighbor indices of node :math:`i`
# under relation :math:`r\in R` and :math:`c_{i,r}` is a normalization # under relation :math:`r\in R` and :math:`c_{i,r}` is a normalization
# constant. In entity classification, the R-GCN paper uses # constant. In entity classification, the R-GCN paper uses
# :math:`c_{i,r}=|N_i^r|`. # :math:`c_{i,r}=|N_i^r|`.
# #
# The problem with applying the above equation directly is the rapid growth of
# the number of parameters, especially with highly multi-relational data. In
# order to reduce model parameter size and prevent overfitting, the original # order to reduce model parameter size and prevent overfitting, the original
# paper proposes to use basis decomposition: # paper proposes to use basis decomposition:
# #
# .. math:: W_r^{(l)}=\sum\limits_{b=1}^B a_{rb}^{(l)}V_b^{(l)}~~~~~~~~~~(3)\\ # .. math:: W_r^{(l)}=\sum\limits_{b=1}^B a_{rb}^{(l)}V_b^{(l)}~~~~~~~~~~(3)\\
# #
# Therefore, the weight :math:`W_r^{(l)}` is a linear combination of basis # Therefore, the weight :math:`W_r^{(l)}` is a linear combination of basis
# transformation :math:`V_b^{(l)}` with coefficients :math:`a_{rb}^{(l)}`. # transformation :math:`V_b^{(l)}` with coefficients :math:`a_{rb}^{(l)}`.
# The number of bases :math:`B` is much smaller than the number of relations # The number of bases :math:`B` is much smaller than the number of relations
# in the knowledge base. # in the knowledge base.
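#
# A tiny standalone sketch of equation (3) (the shapes below are arbitrary and
# only for illustration): every relation-specific weight is a coefficient-weighted
# sum of the shared bases.
import torch

V = torch.randn(2, 5, 3)                 # B = 2 bases V_b, each of shape in_feat x out_feat
a = torch.randn(4, 2)                    # coefficients a_rb for num_rels = 4 relations
W = torch.einsum('rb,bio->rio', a, V)    # W_r = sum_b a_rb * V_b, shape (4, 5, 3)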
# #
# .. note:: # .. note::
# Another weight regularization, block-decomposition, is implemented in # Another weight regularization, block-decomposition, is implemented in
# the `link prediction <link-prediction_>`_. # the `link prediction <link-prediction_>`_.
# #
# Implement R-GCN in DGL # Implement R-GCN in DGL
# ---------------------- # ----------------------
# #
# An R-GCN model is composed of several R-GCN layers. The first R-GCN layer # An R-GCN model is composed of several R-GCN layers. The first R-GCN layer
# also serves as the input layer: it takes in features (e.g. description texts)
# associated with the node entities and projects them to a hidden space. In this
# tutorial, we only use the entity id as the entity feature.
# #
# R-GCN Layers # R-GCN Layers
# ~~~~~~~~~~~~ # ~~~~~~~~~~~~
# #
# For each node, an R-GCN layer performs the following steps: # For each node, an R-GCN layer performs the following steps:
# #
# - Compute outgoing message using node representation and weight matrix # - Compute outgoing message using node representation and weight matrix
# associated with the edge type (message function) # associated with the edge type (message function)
# - Aggregate incoming messages and generate new node representations (reduce # - Aggregate incoming messages and generate new node representations (reduce
# and apply function) # and apply function)
# #
# The following is the definition of an R-GCN hidden layer. # The following is the definition of an R-GCN hidden layer.
# #
# .. note:: # .. note::
# Each relation type is associated with a different weight. Therefore, # Each relation type is associated with a different weight. Therefore,
# the full weight matrix has three dimensions: relation, input_feature, # the full weight matrix has three dimensions: relation, input_feature,
# output_feature. # output_feature.
# #
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from dgl import DGLGraph from dgl import DGLGraph
import dgl.function as fn import dgl.function as fn
from functools import partial from functools import partial
class RGCNLayer(nn.Module): class RGCNLayer(nn.Module):
def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None, def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None,
activation=None, is_input_layer=False): activation=None, is_input_layer=False):
super(RGCNLayer, self).__init__() super(RGCNLayer, self).__init__()
self.in_feat = in_feat self.in_feat = in_feat
self.out_feat = out_feat self.out_feat = out_feat
self.num_rels = num_rels self.num_rels = num_rels
self.num_bases = num_bases self.num_bases = num_bases
self.bias = bias self.bias = bias
self.activation = activation self.activation = activation
self.is_input_layer = is_input_layer self.is_input_layer = is_input_layer
# sanity check # sanity check
if self.num_bases <= 0 or self.num_bases > self.num_rels: if self.num_bases <= 0 or self.num_bases > self.num_rels:
self.num_bases = self.num_rels self.num_bases = self.num_rels
# weight bases in equation (3) # weight bases in equation (3)
self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.in_feat, self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.in_feat,
self.out_feat)) self.out_feat))
if self.num_bases < self.num_rels: if self.num_bases < self.num_rels:
# linear combination coefficients in equation (3) # linear combination coefficients in equation (3)
self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, self.num_bases)) self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, self.num_bases))
# add bias # add bias
if self.bias: if self.bias:
self.bias = nn.Parameter(torch.Tensor(out_feat)) self.bias = nn.Parameter(torch.Tensor(out_feat))
# init trainable parameters # init trainable parameters
nn.init.xavier_uniform_(self.weight, nn.init.xavier_uniform_(self.weight,
gain=nn.init.calculate_gain('relu')) gain=nn.init.calculate_gain('relu'))
if self.num_bases < self.num_rels: if self.num_bases < self.num_rels:
nn.init.xavier_uniform_(self.w_comp, nn.init.xavier_uniform_(self.w_comp,
gain=nn.init.calculate_gain('relu')) gain=nn.init.calculate_gain('relu'))
if self.bias: if self.bias:
nn.init.xavier_uniform_(self.bias, nn.init.xavier_uniform_(self.bias,
gain=nn.init.calculate_gain('relu')) gain=nn.init.calculate_gain('relu'))
def forward(self, g): def forward(self, g):
if self.num_bases < self.num_rels: if self.num_bases < self.num_rels:
# generate all weights from bases (equation (3)) # generate all weights from bases (equation (3))
weight = self.weight.view(self.in_feat, self.num_bases, self.out_feat) weight = self.weight.view(self.in_feat, self.num_bases, self.out_feat)
weight = torch.matmul(self.w_comp, weight).view(self.num_rels, weight = torch.matmul(self.w_comp, weight).view(self.num_rels,
self.in_feat, self.out_feat) self.in_feat, self.out_feat)
else: else:
weight = self.weight weight = self.weight
if self.is_input_layer: if self.is_input_layer:
def message_func(edges): def message_func(edges):
# for input layer, matrix multiply can be converted to be # for input layer, matrix multiply can be converted to be
# an embedding lookup using source node id # an embedding lookup using source node id
embed = weight.view(-1, self.out_feat) embed = weight.view(-1, self.out_feat)
index = edges.data['rel_type'] * self.in_feat + edges.src['id'] index = edges.data['rel_type'] * self.in_feat + edges.src['id']
return {'msg': embed[index] * edges.data['norm']} return {'msg': embed[index] * edges.data['norm']}
else: else:
def message_func(edges): def message_func(edges):
w = weight[edges.data['rel_type']] w = weight[edges.data['rel_type']]
msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze() msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze()
msg = msg * edges.data['norm'] msg = msg * edges.data['norm']
return {'msg': msg} return {'msg': msg}
def apply_func(nodes): def apply_func(nodes):
h = nodes.data['h'] h = nodes.data['h']
if self.bias: if self.bias:
h = h + self.bias h = h + self.bias
if self.activation: if self.activation:
h = self.activation(h) h = self.activation(h)
return {'h': h} return {'h': h}
g.update_all(message_func, fn.sum(msg='msg', out='h'), apply_func) g.update_all(message_func, fn.sum(msg='msg', out='h'), apply_func)
############################################################################### ###############################################################################
# Define full R-GCN model # Define full R-GCN model
# ~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~
class Model(nn.Module): class Model(nn.Module):
def __init__(self, num_nodes, h_dim, out_dim, num_rels, def __init__(self, num_nodes, h_dim, out_dim, num_rels,
num_bases=-1, num_hidden_layers=1): num_bases=-1, num_hidden_layers=1):
super(Model, self).__init__() super(Model, self).__init__()
self.num_nodes = num_nodes self.num_nodes = num_nodes
self.h_dim = h_dim self.h_dim = h_dim
self.out_dim = out_dim self.out_dim = out_dim
self.num_rels = num_rels self.num_rels = num_rels
self.num_bases = num_bases self.num_bases = num_bases
self.num_hidden_layers = num_hidden_layers self.num_hidden_layers = num_hidden_layers
# create rgcn layers # create rgcn layers
self.build_model() self.build_model()
# create initial features # create initial features
self.features = self.create_features() self.features = self.create_features()
def build_model(self): def build_model(self):
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
# input to hidden # input to hidden
i2h = self.build_input_layer() i2h = self.build_input_layer()
self.layers.append(i2h) self.layers.append(i2h)
# hidden to hidden # hidden to hidden
for idx in range(self.num_hidden_layers): for idx in range(self.num_hidden_layers):
h2h = self.build_hidden_layer(idx) h2h = self.build_hidden_layer(idx)
self.layers.append(h2h) self.layers.append(h2h)
# hidden to output # hidden to output
h2o = self.build_output_layer() h2o = self.build_output_layer()
self.layers.append(h2o) self.layers.append(h2o)
# initialize feature for each node # initialize feature for each node
def create_features(self): def create_features(self):
features = torch.arange(self.num_nodes) features = torch.arange(self.num_nodes)
return features return features
def build_input_layer(self): def build_input_layer(self):
return RGCNLayer(self.num_nodes, self.h_dim, self.num_rels, self.num_bases, return RGCNLayer(self.num_nodes, self.h_dim, self.num_rels, self.num_bases,
activation=F.relu, is_input_layer=True) activation=F.relu, is_input_layer=True)
def build_hidden_layer(self, idx):  # accept the layer index passed by build_model (unused here)
return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases, return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases,
activation=F.relu) activation=F.relu)
def build_output_layer(self): def build_output_layer(self):
return RGCNLayer(self.h_dim, self.out_dim, self.num_rels, self.num_bases, return RGCNLayer(self.h_dim, self.out_dim, self.num_rels, self.num_bases,
activation=partial(F.softmax, dim=1)) activation=partial(F.softmax, dim=1))
def forward(self, g): def forward(self, g):
if self.features is not None: if self.features is not None:
g.ndata['id'] = self.features g.ndata['id'] = self.features
for layer in self.layers: for layer in self.layers:
layer(g) layer(g)
return g.ndata.pop('h') return g.ndata.pop('h')
############################################################################### ###############################################################################
# Handle dataset # Handle dataset
# ~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~
# In this tutorial, we use the AIFB dataset from the R-GCN paper:
# load graph data # load graph data
from dgl.contrib.data import load_data from dgl.contrib.data import load_data
import numpy as np import numpy as np
data = load_data(dataset='aifb') data = load_data(dataset='aifb')
num_nodes = data.num_nodes num_nodes = data.num_nodes
num_rels = data.num_rels num_rels = data.num_rels
num_classes = data.num_classes num_classes = data.num_classes
labels = data.labels labels = data.labels
train_idx = data.train_idx train_idx = data.train_idx
# split training and validation set # split training and validation set
val_idx = train_idx[:len(train_idx) // 5] val_idx = train_idx[:len(train_idx) // 5]
train_idx = train_idx[len(train_idx) // 5:] train_idx = train_idx[len(train_idx) // 5:]
# edge type and normalization factor # edge type and normalization factor
edge_type = torch.from_numpy(data.edge_type) edge_type = torch.from_numpy(data.edge_type)
edge_norm = torch.from_numpy(data.edge_norm).unsqueeze(1) edge_norm = torch.from_numpy(data.edge_norm).unsqueeze(1)
labels = torch.from_numpy(labels).view(-1) labels = torch.from_numpy(labels).view(-1)
############################################################################### ###############################################################################
# Create graph and model # Create graph and model
# ~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~
# configurations # configurations
n_hidden = 16 # number of hidden units n_hidden = 16 # number of hidden units
n_bases = -1 # use number of relations as number of bases n_bases = -1 # use number of relations as number of bases
n_hidden_layers = 0 # use 1 input layer, 1 output layer, no hidden layer n_hidden_layers = 0 # use 1 input layer, 1 output layer, no hidden layer
n_epochs = 25 # epochs to train n_epochs = 25 # epochs to train
lr = 0.01 # learning rate lr = 0.01 # learning rate
l2norm = 0 # L2 norm coefficient l2norm = 0 # L2 norm coefficient
# create graph # create graph
g = DGLGraph() g = DGLGraph()
g.add_nodes(num_nodes) g.add_nodes(num_nodes)
g.add_edges(data.edge_src, data.edge_dst) g.add_edges(data.edge_src, data.edge_dst)
g.edata.update({'rel_type': edge_type, 'norm': edge_norm}) g.edata.update({'rel_type': edge_type, 'norm': edge_norm})
# create model # create model
model = Model(len(g), model = Model(len(g),
n_hidden, n_hidden,
num_classes, num_classes,
num_rels, num_rels,
num_bases=n_bases, num_bases=n_bases,
num_hidden_layers=n_hidden_layers) num_hidden_layers=n_hidden_layers)
############################################################################### ###############################################################################
# Training loop # Training loop
# ~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~
# optimizer # optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2norm) optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2norm)
print("start training...") print("start training...")
model.train() model.train()
for epoch in range(n_epochs): for epoch in range(n_epochs):
optimizer.zero_grad() optimizer.zero_grad()
logits = model.forward(g) logits = model.forward(g)
loss = F.cross_entropy(logits[train_idx], labels[train_idx]) loss = F.cross_entropy(logits[train_idx], labels[train_idx])
loss.backward() loss.backward()
optimizer.step() optimizer.step()
train_acc = torch.sum(logits[train_idx].argmax(dim=1) == labels[train_idx]) train_acc = torch.sum(logits[train_idx].argmax(dim=1) == labels[train_idx])
train_acc = train_acc.item() / len(train_idx) train_acc = train_acc.item() / len(train_idx)
val_loss = F.cross_entropy(logits[val_idx], labels[val_idx]) val_loss = F.cross_entropy(logits[val_idx], labels[val_idx])
val_acc = torch.sum(logits[val_idx].argmax(dim=1) == labels[val_idx]) val_acc = torch.sum(logits[val_idx].argmax(dim=1) == labels[val_idx])
val_acc = val_acc.item() / len(val_idx) val_acc = val_acc.item() / len(val_idx)
print("Epoch {:05d} | ".format(epoch) + print("Epoch {:05d} | ".format(epoch) +
"Train Accuracy: {:.4f} | Train Loss: {:.4f} | ".format( "Train Accuracy: {:.4f} | Train Loss: {:.4f} | ".format(
train_acc, loss.item()) + train_acc, loss.item()) +
"Validation Accuracy: {:.4f} | Validation loss: {:.4f}".format( "Validation Accuracy: {:.4f} | Validation loss: {:.4f}".format(
val_acc, val_loss.item())) val_acc, val_loss.item()))
############################################################################### ###############################################################################
# .. _link-prediction: # .. _link-prediction:
# #
# The second task: Link prediction # The second task: Link prediction
# -------------------------------- # --------------------------------
# So far, we have seen how to use DGL to implement entity classification with # So far, we have seen how to use DGL to implement entity classification with
# R-GCN model. In the knowledge base setting, the representations generated by
# R-GCN can be further used to uncover potential relations between nodes. In
# the R-GCN paper, the authors feed the entity representations generated by R-GCN
# into the `DistMult <https://arxiv.org/pdf/1412.6575.pdf>`_ prediction model # into the `DistMult <https://arxiv.org/pdf/1412.6575.pdf>`_ prediction model
# to predict possible relations. # to predict possible relations.
# #
# The implementation is similar to the above but with an extra DistMult layer # The implementation is similar to the above but with an extra DistMult layer
# stacked on top of the R-GCN layers. You may find the complete # stacked on top of the R-GCN layers. You may find the complete
# implementation of link prediction with R-GCN in our `example # implementation of link prediction with R-GCN in our `example
# code <https://github.com/jermainewang/dgl/blob/master/examples/pytorch/rgcn/link_predict.py>`_. # code <https://github.com/jermainewang/dgl/blob/master/examples/pytorch/rgcn/link_predict.py>`_.
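#
# As a rough sketch of the scoring idea only (not the full training pipeline;
# the names below are illustrative), DistMult scores a triple (s, r, o) by the
# sum of the elementwise product of the subject embedding, a learned diagonal
# relation vector, and the object embedding:
def distmult_score(subj_emb, rel_vec, obj_emb):
    # subj_emb / obj_emb: entity embeddings (e.g. produced by the R-GCN layers above)
    # rel_vec: the learned relation-specific diagonal, stored as a vector
    return torch.sum(subj_emb * rel_vec * obj_emb, dim=-1)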
""" """
.. _model-line-graph: .. _model-line-graph:
Line Graph Neural Network Line Graph Neural Network
========================= =========================
**Author**: `Qi Huang <https://github.com/HQ01>`_, Yu Gai, **Author**: `Qi Huang <https://github.com/HQ01>`_, Yu Gai,
`Minjie Wang <https://jermainewang.github.io/>`_, Zheng Zhang `Minjie Wang <https://jermainewang.github.io/>`_, Zheng Zhang
""" """
########################################################################################### ###########################################################################################
# #
# In :doc:`GCN <1_gcn>`, we demonstrate how to classify nodes on an input
# graph in a semi-supervised setting, using a graph convolutional neural network
# as the embedding mechanism for graph features.
# In this tutorial, we shift our focus to the community detection problem. The
# task of community detection, i.e. graph clustering, consists of partitioning # task of community detection, i.e. graph clustering, consists of partitioning
# the vertices in a graph into clusters in which nodes are more "similar" to # the vertices in a graph into clusters in which nodes are more "similar" to
# one another. # one another.
# #
# To generalize GNN to supervised community detection, Chen et al. introduced # To generalize GNN to supervised community detection, Chen et al. introduced
# a line-graph based variation of graph neural network in # a line-graph based variation of graph neural network in
# `Supervised Community Detection with Line Graph Neural Networks <https://arxiv.org/abs/1705.08415>`__. # `Supervised Community Detection with Line Graph Neural Networks <https://arxiv.org/abs/1705.08415>`__.
# One of the highlights of their model is
# to augment the vanilla graph neural network (GNN) architecture to operate on
# the line graph of edge adjacencies, defined with the non-backtracking operator.
# #
# In addition to its high performance, LGNN offers an opportunity to # In addition to its high performance, LGNN offers an opportunity to
# illustrate how DGL can implement an advanced graph algorithm by flexibly # illustrate how DGL can implement an advanced graph algorithm by flexibly
# mixing vanilla tensor operations, sparse-matrix multiplication and message- # mixing vanilla tensor operations, sparse-matrix multiplication and message-
# passing APIs. # passing APIs.
# #
# In the following sections, we will go through community detection, line # In the following sections, we will go through community detection, line
# graph, LGNN, and its implementation. # graph, LGNN, and its implementation.
# #
# Supervised Community Detection Task on CORA # Supervised Community Detection Task on CORA
# -------------------------------------------- # --------------------------------------------
# Community Detection # Community Detection
# ~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~
# In the community detection task, we cluster "similar" nodes instead of
# "labeling" them. Node similarity is typically described as higher inner
# density within each cluster.
# #
# What's the difference between community detection and node classification? # What's the difference between community detection and node classification?
# Compared to node classification, community detection focuses on retrieving
# cluster information from the graph, rather than assigning a specific label to
# a node. For example, as long as a node is clustered with its community
# members, it doesn't matter whether the node is assigned to "community A"
# or "community B", whereas assigning all "great movies" to the label "bad movies"
# would be a disaster in a movie network classification task.
# #
# What's the difference, then, between a community detection algorithm and
# other clustering algorithms such as k-means? Community detection algorithms
# operate on graph-structured data. Compared with k-means, community detection
# leverages the graph structure, instead of simply clustering nodes based on
# their features.
# #
# CORA # CORA
# ~~~~~ # ~~~~~
# To be consistent with the Graph Convolutional Network tutorial,
# we use `CORA <https://linqs.soe.ucsc.edu/data>`__ # we use `CORA <https://linqs.soe.ucsc.edu/data>`__
# to illustrate a simple community detection task. To refresh our memory, # to illustrate a simple community detection task. To refresh our memory,
# CORA is a scientific publication dataset, with 2708 papers belonging to 7 # CORA is a scientific publication dataset, with 2708 papers belonging to 7
# different machine learning sub-fields. Here, we formulate CORA as a
# directed graph, with each node being a paper, and each edge being a # directed graph, with each node being a paper, and each edge being a
# citation link (A->B means A cites B). Here is a visualization of the whole # citation link (A->B means A cites B). Here is a visualization of the whole
# CORA dataset. # CORA dataset.
# #
# .. figure:: https://i.imgur.com/X404Byc.png # .. figure:: https://i.imgur.com/X404Byc.png
# :alt: cora # :alt: cora
# :height: 400px # :height: 400px
# :width: 500px # :width: 500px
# :align: center # :align: center
# #
# CORA naturally contains 7 "classes", and statistics below show that each # CORA naturally contains 7 "classes", and statistics below show that each
# "class" does satisfy our assumption of community, i.e. nodes of same class # "class" does satisfy our assumption of community, i.e. nodes of same class
# class have higher connection probability among them than with nodes of different class. # class have higher connection probability among them than with nodes of different class.
# The following code snippet verifies that there are more intra-class edges # The following code snippet verifies that there are more intra-class edges
# than inter-class: # than inter-class:
import torch import torch
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl import dgl
from dgl.data import citation_graph as citegrh from dgl.data import citation_graph as citegrh
data = citegrh.load_cora() data = citegrh.load_cora()
G = dgl.DGLGraph(data.graph) G = dgl.DGLGraph(data.graph)
labels = th.tensor(data.labels) labels = th.tensor(data.labels)
# find all the nodes labeled with class 0 # find all the nodes labeled with class 0
label0_nodes = th.nonzero(labels == 0).squeeze() label0_nodes = th.nonzero(labels == 0).squeeze()
# find all the edges pointing to class 0 nodes # find all the edges pointing to class 0 nodes
src, _ = G.in_edges(label0_nodes) src, _ = G.in_edges(label0_nodes)
src_labels = labels[src] src_labels = labels[src]
# find all the edges whose both endpoints are in class 0 # find all the edges whose both endpoints are in class 0
intra_src = th.nonzero(src_labels == 0) intra_src = th.nonzero(src_labels == 0)
print('Intra-class edges percent: %.4f' % (len(intra_src) / len(src_labels))) print('Intra-class edges percent: %.4f' % (len(intra_src) / len(src_labels)))
########################################################################################### ###########################################################################################
# Binary Community Subgraph from CORA -- a Toy Dataset # Binary Community Subgraph from CORA -- a Toy Dataset
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Without loss of generality, in this tutorial we limit the scope of our # Without loss of generality, in this tutorial we limit the scope of our
# task to binary community detection. # task to binary community detection.
# #
# .. note:: # .. note::
# #
# To create a toy binary-community dataset from CORA, we first extract
# all two-class pairs from the original 7 classes of CORA. For each pair, we
# treat each class as one community, and find the largest subgraph that
# contains at least one cross-community edge as the training example. As
# a result, there are a total of 21 training samples in this mini-dataset.
# #
# Here we visualize one of the training samples and its community structure: # Here we visualize one of the training samples and its community structure:
import networkx as nx import networkx as nx
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
train_set = dgl.data.CoraBinary() train_set = dgl.data.CoraBinary()
G1, pmpd1, label1 = train_set[1] G1, pmpd1, label1 = train_set[1]
nx_G1 = G1.to_networkx() nx_G1 = G1.to_networkx()
def visualize(labels, g): def visualize(labels, g):
pos = nx.spring_layout(g, seed=1) pos = nx.spring_layout(g, seed=1)
plt.figure(figsize=(8, 8)) plt.figure(figsize=(8, 8))
plt.axis('off') plt.axis('off')
nx.draw_networkx(g, pos=pos, node_size=50, cmap=plt.get_cmap('coolwarm'), nx.draw_networkx(g, pos=pos, node_size=50, cmap=plt.get_cmap('coolwarm'),
node_color=labels, edge_color='k', node_color=labels, edge_color='k',
arrows=False, width=0.5, style='dotted', with_labels=False) arrows=False, width=0.5, style='dotted', with_labels=False)
visualize(label1, nx_G1) visualize(label1, nx_G1)
########################################################################################### ###########################################################################################
# Interested readers can go to the original paper to see how to generalize # Interested readers can go to the original paper to see how to generalize
# to the multi-community case.
# #
# Community Detection in a Supervised Setting # Community Detection in a Supervised Setting
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# The community detection problem can be tackled with both supervised and
# unsupervised approaches. Following the original paper, we formulate
# community detection in a supervised setting as follows:
# #
# - Each training example consists of :math:`(G, L)`, where :math:`G` is a # - Each training example consists of :math:`(G, L)`, where :math:`G` is a
# directed graph :math:`(V, E)`. For each node :math:`v` in :math:`V`, we # directed graph :math:`(V, E)`. For each node :math:`v` in :math:`V`, we
# assign a ground truth community label :math:`z_v \in \{0,1\}`. # assign a ground truth community label :math:`z_v \in \{0,1\}`.
# - The parameterized model :math:`f(G, \theta)` predicts a label set # - The parameterized model :math:`f(G, \theta)` predicts a label set
# :math:`\tilde{Z} = f(G)` for nodes :math:`V`. # :math:`\tilde{Z} = f(G)` for nodes :math:`V`.
# - For each example :math:`(G,L)`, the model learns to minimize a specially
#   designed loss function (equivariant loss) :math:`L_{equivariant}(\tilde{Z}, Z)`.
# #
# .. note:: # .. note::
# #
# In this supervised setting, the model naturally predicts a "label" for # In this supervised setting, the model naturally predicts a "label" for
# each community. However, community assignment should be equivariant to # each community. However, community assignment should be equivariant to
# label permutations. To achieve this, in each forward pass, we take
# the minimum among the losses calculated from all possible permutations of
# labels. # labels.
# #
# Mathematically, this means
# :math:`L_{equivariant} = \underset{\pi \in S_c}{\min} -\log(\hat{\pi}, \pi)`,
# where :math:`S_c` is the set of all permutations of labels,
# :math:`\hat{\pi}` is the set of predicted labels, and
# :math:`-\log(\hat{\pi},\pi)` denotes the negative log likelihood.
# #
# For instance, for a toy graph with nodes :math:`\{1,2,3,4\}` and
# community assignment :math:`\{A, A, A, B\}`, where each node's label is
# :math:`l \in \{0,1\}`, the set of all possible label permutations is
# :math:`S_c = \{\{0,0,0,1\}, \{1,1,1,0\}\}`.
# #
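# Since our toy task is binary, there are only two label permutations to
# consider, so the equivariant loss reduces to taking the minimum of two
# cross-entropy terms. As a quick sketch (``z`` and ``label`` here stand for
# the model output and the ground-truth labels used later in this tutorial's
# training loop):
#
# ::
#
#     loss_perm1 = F.cross_entropy(z, label)
#     loss_perm2 = F.cross_entropy(z, 1 - label)
#     loss = th.min(loss_perm1, loss_perm2)
#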
# Line Graph Neural Network: key ideas
# ------------------------------------
# A key innovation in this paper is the use of the line graph.
# Unlike models in previous tutorials, message passing happens not only on the # Unlike models in previous tutorials, message passing happens not only on the
# original graph, e.g. the binary community subgraph from CORA, but also on the # original graph, e.g. the binary community subgraph from CORA, but also on the
# line-graph associated with the original graph. # line-graph associated with the original graph.
# #
# What is a line graph?
# ~~~~~~~~~~~~~~~~~~~~~~
# In graph theory, a line graph is a graph representation that encodes the
# edge adjacency structure of the original graph.
# #
# Specifically, a line graph :math:`L(G)` turns each edge of the original
# graph :math:`G` into a node. This is illustrated with the graph below
# (taken from the paper).
# #
# .. figure:: https://i.imgur.com/4WO5jEm.png # .. figure:: https://i.imgur.com/4WO5jEm.png
# :alt: lg # :alt: lg
# :align: center # :align: center
# #
# Here, :math:`e_{A}:= (i\rightarrow j)` and :math:`e_{B}:= (j\rightarrow k)` # Here, :math:`e_{A}:= (i\rightarrow j)` and :math:`e_{B}:= (j\rightarrow k)`
# are two edges in the original graph :math:`G`. In line graph :math:`G_L`, # are two edges in the original graph :math:`G`. In line graph :math:`G_L`,
# they correspond to nodes :math:`v^{l}_{A}, v^{l}_{B}`. # they correspond to nodes :math:`v^{l}_{A}, v^{l}_{B}`.
# #
# The next natural question is, how to connect nodes in line-graph? How to # The next natural question is, how to connect nodes in line-graph? How to
# connect two "edges"? Here, we use the following connection rule: # connect two "edges"? Here, we use the following connection rule:
# #
# Two nodes :math:`v^{l}_{A}`, :math:`v^{l}_{B}` in `lg` are connected if # Two nodes :math:`v^{l}_{A}`, :math:`v^{l}_{B}` in `lg` are connected if
# the corresponding two edges :math:`e_{A}, e_{B}` in `g` share one and only # the corresponding two edges :math:`e_{A}, e_{B}` in `g` share one and only
# one node: # one node:
# :math:`e_{A}`'s destination node is :math:`e_{B}`'s source node # :math:`e_{A}`'s destination node is :math:`e_{B}`'s source node
# (:math:`j`). # (:math:`j`).
# #
# .. note:: # .. note::
# #
# Mathematically, this definition corresponds to a notion called non-backtracking # Mathematically, this definition corresponds to a notion called non-backtracking
# operator: # operator:
# :math:`B_{(i \rightarrow j), (\hat{i} \rightarrow \hat{j})}` # :math:`B_{(i \rightarrow j), (\hat{i} \rightarrow \hat{j})}`
# :math:`= \begin{cases} # :math:`= \begin{cases}
# 1 \text{ if } j = \hat{i}, \hat{j} \neq i\\ # 1 \text{ if } j = \hat{i}, \hat{j} \neq i\\
# 0 \text{ otherwise} \end{cases}` # 0 \text{ otherwise} \end{cases}`
# where an edge is formed if :math:`B_{node1, node2} = 1`. # where an edge is formed if :math:`B_{node1, node2} = 1`.
# #
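# In DGL, the line graph of a ``DGLGraph`` can be built directly, and passing
# ``backtracking=False`` applies exactly this non-backtracking rule (this is
# the call used later in the training loop; ``g`` stands for any ``DGLGraph``):
#
# ::
#
#     lg = g.line_graph(backtracking=False)
#     # lg has one node per edge of g, i.e. lg.number_of_nodes() == g.number_of_edges()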
# #
# One layer in LGNN -- algorithm structure
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
# LGNN chains up a series of line-graph neural network layers. The graph # LGNN chains up a series of line-graph neural network layers. The graph
# representation :math:`x` and its line-graph companion :math:`y` evolve with
# the dataflow as follows, # the dataflow as follows,
# #
# .. figure:: https://i.imgur.com/bZGGIGp.png # .. figure:: https://i.imgur.com/bZGGIGp.png
# :alt: alg # :alt: alg
# :align: center # :align: center
# #
# At the :math:`k`-th layer, the :math:`i`-th neuron of the :math:`l`-th # At the :math:`k`-th layer, the :math:`i`-th neuron of the :math:`l`-th
# channel updates its embedding :math:`x^{(k+1)}_{i,l}` with: # channel updates its embedding :math:`x^{(k+1)}_{i,l}` with:
# #
# .. math:: # .. math::
# \begin{split} # \begin{split}
# x^{(k+1)}_{i,l} ={}&\rho[x^{(k)}_{i}\theta^{(k)}_{1,l} # x^{(k+1)}_{i,l} ={}&\rho[x^{(k)}_{i}\theta^{(k)}_{1,l}
# +(Dx^{(k)})_{i}\theta^{(k)}_{2,l} \\ # +(Dx^{(k)})_{i}\theta^{(k)}_{2,l} \\
# &+\sum^{J-1}_{j=0}(A^{2^{j}}x^{k})_{i}\theta^{(k)}_{3+j,l}\\ # &+\sum^{J-1}_{j=0}(A^{2^{j}}x^{k})_{i}\theta^{(k)}_{3+j,l}\\
# &+[\{\text{Pm},\text{Pd}\}y^{(k)}]_{i}\theta^{(k)}_{3+J,l}] \\ # &+[\{\text{Pm},\text{Pd}\}y^{(k)}]_{i}\theta^{(k)}_{3+J,l}] \\
# &+\text{skip-connection} # &+\text{skip-connection}
# \qquad i \in V, l = 1,2,3, ... b_{k+1}/2 # \qquad i \in V, l = 1,2,3, ... b_{k+1}/2
# \end{split} # \end{split}
# #
# Then, the line-graph representation :math:`y^{(k+1)}_{i',l'}` is updated with:
# #
# .. math:: # .. math::
# #
# \begin{split} # \begin{split}
# y^{(k+1)}_{i',l^{'}} = {}&\rho[y^{(k)}_{i^{'}}\gamma^{(k)}_{1,l^{'}}+ # y^{(k+1)}_{i',l^{'}} = {}&\rho[y^{(k)}_{i^{'}}\gamma^{(k)}_{1,l^{'}}+
# (D_{L(G)}y^{(k)})_{i^{'}}\gamma^{(k)}_{2,l^{'}}\\ # (D_{L(G)}y^{(k)})_{i^{'}}\gamma^{(k)}_{2,l^{'}}\\
# &+\sum^{J-1}_{j=0}(A_{L(G)}^{2^{j}}y^{k})_{i}\gamma^{(k)}_{3+j,l^{'}}\\ # &+\sum^{J-1}_{j=0}(A_{L(G)}^{2^{j}}y^{k})_{i}\gamma^{(k)}_{3+j,l^{'}}\\
# &+[\{\text{Pm},\text{Pd}\}^{T}x^{(k+1)}]_{i^{'}}\gamma^{(k)}_{3+J,l^{'}}]\\ # &+[\{\text{Pm},\text{Pd}\}^{T}x^{(k+1)}]_{i^{'}}\gamma^{(k)}_{3+J,l^{'}}]\\
# &+\text{skip-connection} # &+\text{skip-connection}
# \qquad i^{'} \in V_{l}, l^{'} = 1,2,3, ... b^{'}_{k+1}/2 # \qquad i^{'} \in V_{l}, l^{'} = 1,2,3, ... b^{'}_{k+1}/2
# \end{split} # \end{split}
# #
# where :math:`\text{skip-connection}` refers to performing the same operation without the non-linearity
# :math:`\rho`, and with linear projections :math:`\theta_{\frac{b_{k+1}}{2} + 1, ..., b_{k+1}-1, b_{k+1}}`
# and :math:`\gamma_{\frac{b_{k+1}}{2} + 1, ..., b_{k+1}-1, b_{k+1}}`.
# #
# Implement LGNN in DGL # Implement LGNN in DGL
# --------------------- # ---------------------
# General idea # General idea
# ~~~~~~~~~~~~ # ~~~~~~~~~~~~
# The above equations look intimidating. However, we observe the following: # The above equations look intimidating. However, we observe the following:
# #
# - The two equations are symmetric and can be implemented as two instances # - The two equations are symmetric and can be implemented as two instances
# of the same class with different parameters. # of the same class with different parameters.
# Mainly, the first equation operates on graph representation :math:`x`, # Mainly, the first equation operates on graph representation :math:`x`,
# whereas the second operates on line-graph # whereas the second operates on line-graph
# representation :math:`y`. Let us denote this abstraction as :math:`f`. Then # representation :math:`y`. Let us denote this abstraction as :math:`f`. Then
# the first is :math:`f(x,y; \theta_x)`, and the second # the first is :math:`f(x,y; \theta_x)`, and the second
# is :math:`f(y,x, \theta_y)`. That is, they are parameterized to compute # is :math:`f(y,x, \theta_y)`. That is, they are parameterized to compute
# representations of the original graph and its # representations of the original graph and its
# companion line graph, respectively. # companion line graph, respectively.
# #
# - Each equation consists of 4 terms (take the first one as an example):
#
#   - :math:`x^{(k)}\theta^{(k)}_{1,l}`, a linear projection of the previous
#     layer's output :math:`x^{(k)}`, denoted as :math:`\text{prev}(x)`.
#   - :math:`(Dx^{(k)})\theta^{(k)}_{2,l}`, a linear projection of the degree
#     operator applied to :math:`x^{(k)}`, denoted as :math:`\text{deg}(x)`.
#   - :math:`\sum^{J-1}_{j=0}(A^{2^{j}}x^{(k)})\theta^{(k)}_{3+j,l}`,
#     a summation over powers :math:`A^{2^{j}}` of the adjacency operator applied to :math:`x^{(k)}`,
#     denoted as :math:`\text{radius}(x)`.
#   - :math:`[\{Pm,Pd\}y^{(k)}]\theta^{(k)}_{3+J,l}`, which fuses the other
#     graph's embedding information using the incidence matrices
#     :math:`\{Pm, Pd\}`, followed by a linear projection,
#     denoted as :math:`\text{fuse}(y)`.
# #
# - In addition, each of the terms is computed again with different
#   parameters, and without the nonlinearity after the sum.
#   Therefore, :math:`f` could be written as:
# #
# .. math:: # .. math::
# \begin{split} # \begin{split}
# f(x^{(k)},y^{(k)}) = {}\rho[&\text{prev}(x^{(k-1)}) + \text{deg}(x^{(k-1)}) +\text{radius}(x^{k-1}) # f(x^{(k)},y^{(k)}) = {}\rho[&\text{prev}(x^{(k-1)}) + \text{deg}(x^{(k-1)}) +\text{radius}(x^{k-1})
# +\text{fuse}(y^{(k)})]\\ # +\text{fuse}(y^{(k)})]\\
# +&\text{prev}(x^{(k-1)}) + \text{deg}(x^{(k-1)}) +\text{radius}(x^{k-1}) +\text{fuse}(y^{(k)}) # +&\text{prev}(x^{(k-1)}) + \text{deg}(x^{(k-1)}) +\text{radius}(x^{k-1}) +\text{fuse}(y^{(k)})
# \end{split} # \end{split}
# #
# - The two equations are chained up in the following order:
# #
# .. math:: # .. math::
# \begin{split} # \begin{split}
# x^{(k+1)} = {}& f(x^{(k)}, y^{(k)})\\ # x^{(k+1)} = {}& f(x^{(k)}, y^{(k)})\\
# y^{(k+1)} = {}& f(y^{(k)}, x^{(k+1)}) # y^{(k+1)} = {}& f(y^{(k)}, x^{(k+1)})
# \end{split} # \end{split}
# #
# With these observations, we proceed to the implementation.
# The important point is that we use different strategies for these terms.
# #
# .. note:: # .. note::
# For a detailed explanation of :math:`\{Pm, Pd\}`, please go to `Advanced Topic`_. # For a detailed explanation of :math:`\{Pm, Pd\}`, please go to `Advanced Topic`_.
# #
# Implementing :math:`\text{prev}` and :math:`\text{deg}` as tensor operations
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Since the linear projection and the degree operation are both simply matrix
# multiplications, we can write them as PyTorch tensor operations.
# #
# In ``__init__``, we define the projection variables: # In ``__init__``, we define the projection variables:
# #
# :: # ::
# #
# self.linear_prev = nn.Linear(in_feats, out_feats) # self.linear_prev = nn.Linear(in_feats, out_feats)
# self.linear_deg = nn.Linear(in_feats, out_feats) # self.linear_deg = nn.Linear(in_feats, out_feats)
# #
# #
# In ``forward()``, :math:`\text{prev}` and :math:`\text{deg}` are the same # In ``forward()``, :math:`\text{prev}` and :math:`\text{deg}` are the same
# as any other PyTorch tensor operations. # as any other PyTorch tensor operations.
# #
# :: # ::
# #
# prev_proj = self.linear_prev(feat_a) # prev_proj = self.linear_prev(feat_a)
# deg_proj = self.linear_deg(deg * feat_a) # deg_proj = self.linear_deg(deg * feat_a)
# #
# Implementing :math:`\text{radius}` as message passing in DGL # Implementing :math:`\text{radius}` as message passing in DGL
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# As discussed in the GCN tutorial, we can formulate one application of the
# adjacency operator as one step of message passing. As a generalization, the
# :math:`2^j` adjacency operator can be formulated as performing :math:`2^j`
# steps of message passing. Therefore, the summation is equivalent to summing
# nodes' representations after :math:`2^j, j=0, 1, 2, \ldots` steps of message
# passing, i.e. gathering information from the :math:`2^{j}`-hop neighbourhood of each node.
# #
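# As a small illustration of this equivalence (a sketch only; ``g`` is a toy
# ``DGLGraph`` with node features ``x`` and dense adjacency matrix ``A``, none
# of which are defined in this tutorial, and ``fn`` is ``dgl.function``,
# imported in the next code cell), two rounds of copy-source/sum message
# passing compute :math:`A^{2}x` up to the direction convention of :math:`A`:
#
# ::
#
#     g.ndata['z'] = x
#     g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
#     g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
#     # g.ndata['z'] now holds the 2-hop aggregation, i.e. (roughly) A @ A @ x
#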
# In ``__init__``, we define the projection variables used for each of the
# :math:`2^j`-step message-passing results:
# #
# :: # ::
# #
# self.linear_radius = nn.ModuleList( # self.linear_radius = nn.ModuleList(
# [nn.Linear(in_feats, out_feats) for i in range(radius)]) # [nn.Linear(in_feats, out_feats) for i in range(radius)])
# #
# In ``forward()``, we use the following function ``aggregate_radius()`` to
# gather data from multiple hops. Note that ``update_all`` is called
# multiple times. It returns a list containing the features gathered at each
# radius.
import dgl.function as fn import dgl.function as fn
def aggregate_radius(radius, g, z): def aggregate_radius(radius, g, z):
# initializing list to collect message passing result # initializing list to collect message passing result
z_list = [] z_list = []
g.ndata['z'] = z g.ndata['z'] = z
# pulling message from 1-hop neighbourhood # pulling message from 1-hop neighbourhood
g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z')) g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
z_list.append(g.ndata['z']) z_list.append(g.ndata['z'])
for i in range(radius - 1): for i in range(radius - 1):
for j in range(2 ** i): for j in range(2 ** i):
            # take another 2^i message-passing steps to reach the 2^(i+1)-hop neighborhood
g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z')) g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
z_list.append(g.ndata['z']) z_list.append(g.ndata['z'])
return z_list return z_list
######################################################################### #########################################################################
# Implementing :math:`\text{fuse}` as sparse matrix multiplication # Implementing :math:`\text{fuse}` as sparse matrix multiplication
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# :math:`\{Pm, Pd\}` is a sparse matrix with only two non-zero entries on # :math:`\{Pm, Pd\}` is a sparse matrix with only two non-zero entries on
# each column. Therefore, we construct it as a sparse matrix in the dataset, # each column. Therefore, we construct it as a sparse matrix in the dataset,
# and implement :math:`\text{fuse}` as a sparse matrix multiplication. # and implement :math:`\text{fuse}` as a sparse matrix multiplication.
# #
# In ``forward()``:
# #
# :: # ::
# #
# fuse = self.linear_fuse(th.mm(pm_pd, feat_b)) # fuse = self.linear_fuse(th.mm(pm_pd, feat_b))
# #
# Completing :math:`f(x, y)` # Completing :math:`f(x, y)`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~
# Finally, we sum up all the terms, then apply the skip connection and
# batch normalization.
# #
# :: # ::
# #
# result = prev_proj + deg_proj + radius_proj + fuse # result = prev_proj + deg_proj + radius_proj + fuse
# #
# Then pass the result through the skip connection:
# #
# :: # ::
# #
# result = th.cat([result[:, :n], F.relu(result[:, n:])], 1) # result = th.cat([result[:, :n], F.relu(result[:, n:])], 1)
# #
# Then apply batch normalization:
# #
# :: # ::
# #
# result = self.bn(result) #Batch Normalization. # result = self.bn(result) #Batch Normalization.
# #
# #
# Below is the complete code for one LGNN layer's abstraction :math:`f(x,y)` # Below is the complete code for one LGNN layer's abstraction :math:`f(x,y)`
class LGNNCore(nn.Module): class LGNNCore(nn.Module):
def __init__(self, in_feats, out_feats, radius): def __init__(self, in_feats, out_feats, radius):
super(LGNNCore, self).__init__() super(LGNNCore, self).__init__()
self.out_feats = out_feats self.out_feats = out_feats
self.radius = radius self.radius = radius
self.linear_prev = nn.Linear(in_feats, out_feats) self.linear_prev = nn.Linear(in_feats, out_feats)
self.linear_deg = nn.Linear(in_feats, out_feats) self.linear_deg = nn.Linear(in_feats, out_feats)
self.linear_radius = nn.ModuleList( self.linear_radius = nn.ModuleList(
[nn.Linear(in_feats, out_feats) for i in range(radius)]) [nn.Linear(in_feats, out_feats) for i in range(radius)])
self.linear_fuse = nn.Linear(in_feats, out_feats) self.linear_fuse = nn.Linear(in_feats, out_feats)
self.bn = nn.BatchNorm1d(out_feats) self.bn = nn.BatchNorm1d(out_feats)
def forward(self, g, feat_a, feat_b, deg, pm_pd): def forward(self, g, feat_a, feat_b, deg, pm_pd):
# term "prev" # term "prev"
prev_proj = self.linear_prev(feat_a) prev_proj = self.linear_prev(feat_a)
# term "deg" # term "deg"
deg_proj = self.linear_deg(deg * feat_a) deg_proj = self.linear_deg(deg * feat_a)
# term "radius" # term "radius"
# aggregate 2^j-hop features # aggregate 2^j-hop features
hop2j_list = aggregate_radius(self.radius, g, feat_a) hop2j_list = aggregate_radius(self.radius, g, feat_a)
# apply linear transformation # apply linear transformation
hop2j_list = [linear(x) for linear, x in zip(self.linear_radius, hop2j_list)] hop2j_list = [linear(x) for linear, x in zip(self.linear_radius, hop2j_list)]
radius_proj = sum(hop2j_list) radius_proj = sum(hop2j_list)
# term "fuse" # term "fuse"
fuse = self.linear_fuse(th.mm(pm_pd, feat_b)) fuse = self.linear_fuse(th.mm(pm_pd, feat_b))
# sum them together # sum them together
result = prev_proj + deg_proj + radius_proj + fuse result = prev_proj + deg_proj + radius_proj + fuse
# skip connection and batch norm # skip connection and batch norm
n = self.out_feats // 2 n = self.out_feats // 2
result = th.cat([result[:, :n], F.relu(result[:, n:])], 1) result = th.cat([result[:, :n], F.relu(result[:, n:])], 1)
result = self.bn(result) result = self.bn(result)
return result return result
############################################################################################################## ##############################################################################################################
# Chain up LGNN abstractions as an LGNN layer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# To implement:
# #
# .. math:: # .. math::
# \begin{split} # \begin{split}
# x^{(k+1)} = {}& f(x^{(k)}, y^{(k)})\\ # x^{(k+1)} = {}& f(x^{(k)}, y^{(k)})\\
# y^{(k+1)} = {}& f(y^{(k)}, x^{(k+1)}) # y^{(k+1)} = {}& f(y^{(k)}, x^{(k+1)})
# \end{split} # \end{split}
# #
# We chain up two ``LGNNCore`` instances with different parameters in the forward pass.
class LGNNLayer(nn.Module): class LGNNLayer(nn.Module):
def __init__(self, in_feats, out_feats, radius): def __init__(self, in_feats, out_feats, radius):
super(LGNNLayer, self).__init__() super(LGNNLayer, self).__init__()
self.g_layer = LGNNCore(in_feats, out_feats, radius) self.g_layer = LGNNCore(in_feats, out_feats, radius)
self.lg_layer = LGNNCore(in_feats, out_feats, radius) self.lg_layer = LGNNCore(in_feats, out_feats, radius)
def forward(self, g, lg, x, lg_x, deg_g, deg_lg, pm_pd): def forward(self, g, lg, x, lg_x, deg_g, deg_lg, pm_pd):
next_x = self.g_layer(g, x, lg_x, deg_g, pm_pd) next_x = self.g_layer(g, x, lg_x, deg_g, pm_pd)
        # the line-graph side uses the transposed incidence matrices, matching {Pm, Pd}^T in the equations
        pm_pd_y = th.transpose(pm_pd, 0, 1)
next_lg_x = self.lg_layer(lg, lg_x, x, deg_lg, pm_pd_y) next_lg_x = self.lg_layer(lg, lg_x, x, deg_lg, pm_pd_y)
return next_x, next_lg_x return next_x, next_lg_x
######################################################################################## ########################################################################################
# Chain up LGNN layers # Chain up LGNN layers
# ~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~
# We then define an LGNN with three hidden layers. # We then define an LGNN with three hidden layers.
class LGNN(nn.Module): class LGNN(nn.Module):
def __init__(self, radius): def __init__(self, radius):
super(LGNN, self).__init__() super(LGNN, self).__init__()
self.layer1 = LGNNLayer(1, 16, radius) # input is scalar feature self.layer1 = LGNNLayer(1, 16, radius) # input is scalar feature
self.layer2 = LGNNLayer(16, 16, radius) # hidden size is 16 self.layer2 = LGNNLayer(16, 16, radius) # hidden size is 16
self.layer3 = LGNNLayer(16, 16, radius) self.layer3 = LGNNLayer(16, 16, radius)
        self.linear = nn.Linear(16, 2)  # predict two classes
def forward(self, g, lg, pm_pd): def forward(self, g, lg, pm_pd):
# compute the degrees # compute the degrees
deg_g = g.in_degrees().float().unsqueeze(1) deg_g = g.in_degrees().float().unsqueeze(1)
deg_lg = lg.in_degrees().float().unsqueeze(1) deg_lg = lg.in_degrees().float().unsqueeze(1)
# use degree as the input feature # use degree as the input feature
x, lg_x = deg_g, deg_lg x, lg_x = deg_g, deg_lg
x, lg_x = self.layer1(g, lg, x, lg_x, deg_g, deg_lg, pm_pd) x, lg_x = self.layer1(g, lg, x, lg_x, deg_g, deg_lg, pm_pd)
x, lg_x = self.layer2(g, lg, x, lg_x, deg_g, deg_lg, pm_pd) x, lg_x = self.layer2(g, lg, x, lg_x, deg_g, deg_lg, pm_pd)
x, lg_x = self.layer3(g, lg, x, lg_x, deg_g, deg_lg, pm_pd) x, lg_x = self.layer3(g, lg, x, lg_x, deg_g, deg_lg, pm_pd)
return self.linear(x) return self.linear(x)
######################################################################################### #########################################################################################
# Training and Inference # Training and Inference
# ----------------------- # -----------------------
# We first load the data # We first load the data
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
training_loader = DataLoader(train_set, training_loader = DataLoader(train_set,
batch_size=1, batch_size=1,
collate_fn=train_set.collate_fn, collate_fn=train_set.collate_fn,
drop_last=True) drop_last=True)
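# A quick peek at one sample from the loader (just a sanity check; with
# ``batch_size=1`` each item is the (graph, pmpd, label) triple described in
# the next section).
g_sample, pmpd_sample, label_sample = next(iter(training_loader))
print(type(g_sample), pmpd_sample.shape, label_sample.shape)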
####################################################################################### #######################################################################################
# We then define the main training loop. Note that each training sample contains # We then define the main training loop. Note that each training sample contains
# three objects: a :class:`~dgl.DGLGraph`, a SciPy sparse matrix ``pmpd``, and a label
# array stored as a ``numpy.ndarray``. We first generate the line graph using:
# #
# :: # ::
# #
# lg = g.line_graph(backtracking=False) # lg = g.line_graph(backtracking=False)
# #
# Note that ``backtracking=False`` is required to correctly simulate the non-backtracking
# operator. We also define a utility function to convert the SciPy sparse matrix to
# a torch sparse tensor.
# create the model # create the model
model = LGNN(radius=3) model = LGNN(radius=3)
# define the optimizer # define the optimizer
optimizer = th.optim.Adam(model.parameters(), lr=1e-2) optimizer = th.optim.Adam(model.parameters(), lr=1e-2)
# a utility function to convert a scipy.sparse.coo_matrix to a torch.sparse.FloatTensor
def sparse2th(mat): def sparse2th(mat):
value = mat.data value = mat.data
indices = th.LongTensor([mat.row, mat.col]) indices = th.LongTensor([mat.row, mat.col])
tensor = th.sparse.FloatTensor(indices, th.from_numpy(value).float(), mat.shape) tensor = th.sparse.FloatTensor(indices, th.from_numpy(value).float(), mat.shape)
return tensor return tensor
# train for 20 epochs # train for 20 epochs
for i in range(20): for i in range(20):
all_loss = [] all_loss = []
all_acc = [] all_acc = []
for [g, pmpd, label] in training_loader: for [g, pmpd, label] in training_loader:
# Generate the line graph. # Generate the line graph.
lg = g.line_graph(backtracking=False) lg = g.line_graph(backtracking=False)
# Create torch tensors # Create torch tensors
pmpd = sparse2th(pmpd) pmpd = sparse2th(pmpd)
label = th.from_numpy(label) label = th.from_numpy(label)
# Forward # Forward
z = model(g, lg, pmpd) z = model(g, lg, pmpd)
# Calculate loss: # Calculate loss:
# Since there are only two communities, there are only two permutations # Since there are only two communities, there are only two permutations
# of the community labels. # of the community labels.
loss_perm1 = F.cross_entropy(z, label) loss_perm1 = F.cross_entropy(z, label)
loss_perm2 = F.cross_entropy(z, 1 - label) loss_perm2 = F.cross_entropy(z, 1 - label)
loss = th.min(loss_perm1, loss_perm2) loss = th.min(loss_perm1, loss_perm2)
# Calculate accuracy: # Calculate accuracy:
_, pred = th.max(z, 1) _, pred = th.max(z, 1)
acc_perm1 = (pred == label).float().mean() acc_perm1 = (pred == label).float().mean()
acc_perm2 = (pred == 1 - label).float().mean() acc_perm2 = (pred == 1 - label).float().mean()
acc = th.max(acc_perm1, acc_perm2) acc = th.max(acc_perm1, acc_perm2)
all_loss.append(loss.item()) all_loss.append(loss.item())
all_acc.append(acc.item()) all_acc.append(acc.item())
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
niters = len(all_loss) niters = len(all_loss)
print("Epoch %d | loss %.4f | accuracy %.4f" % (i, print("Epoch %d | loss %.4f | accuracy %.4f" % (i,
sum(all_loss) / niters, sum(all_acc) / niters)) sum(all_loss) / niters, sum(all_acc) / niters))
####################################################################################### #######################################################################################
# Visualize training progress # Visualize training progress
# ----------------------------- # -----------------------------
# We visualize the network's community prediction on one training example, # We visualize the network's community prediction on one training example,
# together with the ground truth. # together with the ground truth.
pmpd1 = sparse2th(pmpd1) pmpd1 = sparse2th(pmpd1)
LG1 = G1.line_graph(backtracking=False) LG1 = G1.line_graph(backtracking=False)
z = model(G1, LG1, pmpd1) z = model(G1, LG1, pmpd1)
_, pred = th.max(z, 1) _, pred = th.max(z, 1)
visualize(pred, nx_G1) visualize(pred, nx_G1)
####################################################################################### #######################################################################################
# Compare this with the ground truth. Note that the colors might be swapped between the
# two communities, as the model only needs to predict the correct "partitioning".
visualize(label1, nx_G1) visualize(label1, nx_G1)
######################################### #########################################
# Here is an animation to better understand the process. (40 epochs) # Here is an animation to better understand the process. (40 epochs)
# #
# .. figure:: https://i.imgur.com/KDUyE1S.gif # .. figure:: https://i.imgur.com/KDUyE1S.gif
# :alt: lgnn-anim # :alt: lgnn-anim
# #
# Advanced topic # Advanced topic
# -------------- # --------------
# #
# Batching # Batching
# ~~~~~~~~ # ~~~~~~~~
# LGNN takes a collection of different graphs, so it is natural to use batching
# to exploit parallelism. Why does the model code above not show any batching logic?
#
# As it turns out, we moved batching into the dataloader itself.
# In the ``collate_fn`` for the PyTorch DataLoader, we batch graphs using DGL's
# batched graph API. To refresh our memory, DGL batches graphs by merging them
# into one large graph, with each smaller graph's adjacency matrix being a block
# along the diagonal of the large graph's adjacency matrix. We concatenate
# :math:`\{Pm,Pd\}` as a block diagonal matrix, in correspondence with the DGL
# batched graph.
import numpy as np
import scipy.sparse as sp

def collate_fn(batch):
graphs, pmpds, labels = zip(*batch) graphs, pmpds, labels = zip(*batch)
batched_graphs = dgl.batch(graphs) batched_graphs = dgl.batch(graphs)
batched_pmpds = sp.block_diag(pmpds) batched_pmpds = sp.block_diag(pmpds)
batched_labels = np.concatenate(labels, axis=0) batched_labels = np.concatenate(labels, axis=0)
return batched_graphs, batched_pmpds, batched_labels return batched_graphs, batched_pmpds, batched_labels
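# For illustration (a sketch; the batch size here is arbitrary and this loader
# is not used elsewhere in the tutorial), the dataset could then be loaded in
# larger batches with this collate function:
batched_loader = DataLoader(train_set,
                            batch_size=4,
                            collate_fn=collate_fn,
                            drop_last=True)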
###################################################################################### ######################################################################################
# You can check out the complete code # You can check out the complete code
# `here <https://github.com/jermainewang/dgl/tree/master/examples/pytorch/line_graph>`_. # `here <https://github.com/jermainewang/dgl/tree/master/examples/pytorch/line_graph>`_.
# #
# What's the business with :math:`\{Pm, Pd\}`? # What's the business with :math:`\{Pm, Pd\}`?
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Roughly speaking, there is a relationship between how :math:`g` and
# :math:`lg` (the line graph) work together and loopy belief propagation.
# Here, we implement :math:`\{Pm, Pd\}` as a SciPy COO sparse matrix in the dataset,
# and stack them as tensors when batching. Another batching solution is to
# treat :math:`\{Pm, Pd\}` as the adjacency matrix of a bipartite graph, which maps
# the line graph's features to the original graph's, and vice versa.
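#
# As a concrete (hypothetical) illustration, assuming the definition from the
# original LGNN paper where each directed edge :math:`e = (i \rightarrow j)`
# contributes :math:`Pm_{i,e} = Pm_{j,e} = 1` and :math:`Pd_{i,e} = 1`,
# :math:`Pd_{j,e} = -1`, the two matrices for a toy graph with edges
# :math:`0 \rightarrow 1` and :math:`1 \rightarrow 2` could be built as:
#
# ::
#
#     import numpy as np
#     import scipy.sparse as sp
#
#     src = np.array([0, 1]); dst = np.array([1, 2])
#     eid = np.arange(2)
#     rows = np.concatenate([src, dst])
#     cols = np.concatenate([eid, eid])
#     pm = sp.coo_matrix((np.ones(4), (rows, cols)), shape=(3, 2))
#     pd = sp.coo_matrix((np.concatenate([np.ones(2), -np.ones(2)]), (rows, cols)),
#                        shape=(3, 2))
#     # each column has exactly two non-zero entries, as noted above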
.. _tutorials1-index:
Graph Neural Network and its variants
--------------------------------------
* **GCN** `[paper] <https://arxiv.org/abs/1609.02907>`__ `[tutorial] <models/1_gcn.html>`__
`[code] <https://github.com/jermainewang/dgl/blob/master/examples/pytorch/gcn/gcn.py>`__:
this is the vanilla GCN. The tutorial covers the basic uses of DGL APIs.
* **GAT** `[paper] <https://arxiv.org/abs/1710.10903>`__
`[code] <https://github.com/jermainewang/dgl/blob/master/examples/pytorch/gat/gat.py>`__:
  the key extension of GAT w.r.t. the vanilla GCN is deploying multi-head attention
  over the neighborhood of a node, which greatly enhances the capacity and
  expressiveness of the model.
* **R-GCN** `[paper] <https://arxiv.org/abs/1703.06103>`__ `[tutorial] <models/4_rgcn.html>`__
[code (wip)]: the key
  difference of R-GCN is to allow multiple edges between two entities of a graph, where
edges with distinct relationships are encoded differently. This is an
interesting extension of GCN that can have a lot of applications of its own.
* **LGNN** `[paper] <https://arxiv.org/abs/1705.08415>`__ `[tutorial (wip)]` `[code (wip)]`:
this model focuses on community detection by inspecting graph structures. It
  uses representations of both the original graph and its line-graph companion. In
  addition to demonstrating how an algorithm can harness multiple graphs, our
  implementation shows how one can judiciously mix vanilla tensor operations,
  sparse-matrix tensor operations, and message passing with DGL.
* **SSE** `[paper] <http://proceedings.mlr.press/v80/dai18a/dai18a.pdf>`__ `[tutorial (wip)]`
`[code] <https://github.com/jermainewang/dgl/blob/master/examples/mxnet/sse/sse_batch.py>`__:
  the emphasis here is on *giant* graphs that cannot fit comfortably on one GPU
  card. SSE is an example illustrating the co-design of both algorithm and
  system: sampling to guarantee asymptotic convergence while lowering the
  complexity, and batching across samples for maximum parallelism.
""" """
.. _model-tree-lstm: .. _model-tree-lstm:
Tree LSTM DGL Tutorial Tree LSTM DGL Tutorial
========================= =========================
**Author**: Zihao Ye, Qipeng Guo, `Minjie Wang **Author**: Zihao Ye, Qipeng Guo, `Minjie Wang
<https://jermainewang.github.io/>`_, `Jake Zhao <https://jermainewang.github.io/>`_, `Jake Zhao
<https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang <https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang
""" """
############################################################################## ##############################################################################
# #
# The Tree-LSTM structure was first introduced by Tai et al. in an ACL 2015
# paper: `Improved Semantic Representations From Tree-Structured Long # paper: `Improved Semantic Representations From Tree-Structured Long
# Short-Term Memory Networks <https://arxiv.org/pdf/1503.00075.pdf>`__. # Short-Term Memory Networks <https://arxiv.org/pdf/1503.00075.pdf>`__.
# The core idea is to introduce syntactic information for language tasks by # The core idea is to introduce syntactic information for language tasks by
# extending the chain-structured LSTM to a tree-structured LSTM. The Dependency # extending the chain-structured LSTM to a tree-structured LSTM. The Dependency
# Tree/Constituency Tree techniques were leveraged to obtain a "latent tree".
# #
# One difficulty, if not the main one, of training Tree-LSTMs is batching --- a standard
# technique in machine learning to accelerate optimization. However, since trees
# generally have different shapes by nature, parallelization becomes non-trivial.
# DGL offers an alternative: pool all the trees into one single graph, then
# induce message passing over them, guided by the structure of each tree.
# #
# The task and the dataset # The task and the dataset
# ------------------------ # ------------------------
# In this tutorial, we will use Tree-LSTMs for sentiment analysis. # In this tutorial, we will use Tree-LSTMs for sentiment analysis.
# We have wrapped the # We have wrapped the
# `Stanford Sentiment Treebank <https://nlp.stanford.edu/sentiment/>`__ in # `Stanford Sentiment Treebank <https://nlp.stanford.edu/sentiment/>`__ in
# ``dgl.data``. The dataset provides a fine-grained, tree-level sentiment
# annotation: five classes (very negative, negative, neutral, positive, and
# very positive) that indicate the sentiment of the current subtree. Non-leaf
# nodes in the constituency tree do not contain words, so we use a special
# ``PAD_WORD`` token to denote them; during training and inference,
# their embeddings are masked to all-zero.
# #
# .. figure:: https://i.loli.net/2018/11/08/5be3d4bfe031b.png # .. figure:: https://i.loli.net/2018/11/08/5be3d4bfe031b.png
# :alt: # :alt:
# #
# The figure displays one sample of the SST dataset, which is a
# constituency parse tree with its nodes labeled with sentiment. To
# speed things up, let's build a tiny set with five sentences and take a look
# at the first one:
# #
import dgl import dgl
import dgl.data as data import dgl.data as data
# Each sample in the dataset is a constituency tree. The leaf nodes # Each sample in the dataset is a constituency tree. The leaf nodes
# represent words. The word is an int value stored in the "x" field.
# The non-leaf nodes have the special word PAD_WORD. The sentiment
# label is stored in the "y" feature field. # label is stored in the "y" feature field.
trainset = data.SST(mode='tiny') # the "tiny" set has only 5 trees trainset = data.SST(mode='tiny') # the "tiny" set has only 5 trees
tiny_sst = trainset.trees tiny_sst = trainset.trees
num_vocabs = trainset.num_vocabs num_vocabs = trainset.num_vocabs
num_classes = trainset.num_classes num_classes = trainset.num_classes
vocab = trainset.vocab # vocabulary dict: key -> id vocab = trainset.vocab # vocabulary dict: key -> id
inv_vocab = {v: k for k, v in vocab.items()} # inverted vocabulary dict: id -> word inv_vocab = {v: k for k, v in vocab.items()} # inverted vocabulary dict: id -> word
a_tree = tiny_sst[0] a_tree = tiny_sst[0]
for token in a_tree.ndata['x'].tolist(): for token in a_tree.ndata['x'].tolist():
if token != trainset.PAD_WORD: if token != trainset.PAD_WORD:
print(inv_vocab[token], end=" ") print(inv_vocab[token], end=" ")
############################################################################## ##############################################################################
# Step 1: batching # Step 1: batching
# ---------------- # ----------------
# #
# The first step is to throw all the trees into one graph, using # The first step is to throw all the trees into one graph, using
# the :func:`~dgl.batched_graph.batch` API. # the :func:`~dgl.batched_graph.batch` API.
# #
import networkx as nx import networkx as nx
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
graph = dgl.batch(tiny_sst) graph = dgl.batch(tiny_sst)
def plot_tree(g): def plot_tree(g):
# this plot requires pygraphviz package # this plot requires pygraphviz package
pos = nx.nx_agraph.graphviz_layout(g, prog='dot') pos = nx.nx_agraph.graphviz_layout(g, prog='dot')
nx.draw(g, pos, with_labels=False, node_size=10, nx.draw(g, pos, with_labels=False, node_size=10,
node_color=[[.5, .5, .5]], arrowsize=4) node_color=[[.5, .5, .5]], arrowsize=4)
plt.show() plt.show()
plot_tree(graph.to_networkx()) plot_tree(graph.to_networkx())
############################################################################## ##############################################################################
# You can read more about the definition of :func:`~dgl.batched_graph.batch` # You can read more about the definition of :func:`~dgl.batched_graph.batch`
# (by clicking the API), or skip ahead to the next step:
# #
# .. note:: # .. note::
# #
# **Definition**: a :class:`~dgl.batched_graph.BatchedDGLGraph` is a # **Definition**: a :class:`~dgl.batched_graph.BatchedDGLGraph` is a
# :class:`~dgl.DGLGraph` that unions a list of :class:`~dgl.DGLGraph`\ s. # :class:`~dgl.DGLGraph` that unions a list of :class:`~dgl.DGLGraph`\ s.
# #
# - The union includes all the nodes, # - The union includes all the nodes,
# edges, and their features. The order of nodes, edges and features are # edges, and their features. The order of nodes, edges and features are
# preserved. # preserved.
# #
# - Given that we have :math:`V_i` nodes for graph # - Given that we have :math:`V_i` nodes for graph
# :math:`\mathcal{G}_i`, the node ID :math:`j` in graph # :math:`\mathcal{G}_i`, the node ID :math:`j` in graph
# :math:`\mathcal{G}_i` corresponds to node ID
# :math:`j + \sum_{k=1}^{i-1} V_k` in the batched graph (see the sketch below).
# #
# - Therefore, performing feature transformation and message passing on # - Therefore, performing feature transformation and message passing on
# ``BatchedDGLGraph`` is equivalent to doing those # ``BatchedDGLGraph`` is equivalent to doing those
# on all ``DGLGraph`` constituents in parallel. # on all ``DGLGraph`` constituents in parallel.
# #
# - Duplicate references to the same graph are # - Duplicate references to the same graph are
# treated as deep copies; the nodes, edges, and features are duplicated, # treated as deep copies; the nodes, edges, and features are duplicated,
# and mutation on one reference does not affect the other. # and mutation on one reference does not affect the other.
# - Currently, ``BatchedDGLGraph`` is immutable in # - Currently, ``BatchedDGLGraph`` is immutable in
# graph structure (i.e. one can't add # graph structure (i.e. one can't add
# nodes and edges to it). We may need to support mutable batched graphs in
# the (far) future.
# - The ``BatchedDGLGraph`` keeps track of the meta # - The ``BatchedDGLGraph`` keeps track of the meta
# information of the constituents so it can be # information of the constituents so it can be
# :func:`~dgl.batched_graph.unbatch`\ ed into a list of ``DGLGraph``\ s.
# #
# For more details about the :class:`~dgl.batched_graph.BatchedDGLGraph` # For more details about the :class:`~dgl.batched_graph.BatchedDGLGraph`
# module in DGL, you can click the class name. # module in DGL, you can click the class name.
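#
# A tiny sketch of the node-ID offset rule above (the graphs and sizes here are
# made up purely for illustration):
#
# ::
#
#     g1 = dgl.DGLGraph(); g1.add_nodes(3); g1.add_edges([0, 1], [1, 2])
#     g2 = dgl.DGLGraph(); g2.add_nodes(2); g2.add_edges([0], [1])
#     bg = dgl.batch([g1, g2])
#     print(bg.number_of_nodes())  # 5 = 3 + 2
#     # node 0 of g2 becomes node 3 of bg, node 1 of g2 becomes node 4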
# #
# Step 2: Tree-LSTM Cell with message-passing APIs # Step 2: Tree-LSTM Cell with message-passing APIs
# ------------------------------------------------ # ------------------------------------------------
# #
# The authors proposed two types of Tree-LSTMs: Child-Sum
# Tree-LSTMs and :math:`N`-ary Tree-LSTMs. In this tutorial, we focus
# on applying the *binary* Tree-LSTM to binarized constituency trees (this
# application is also known as the *Constituency Tree-LSTM*). We use PyTorch
# as our backend framework to set up the network. # as our backend framework to set up the network.
# #
# In the :math:`N`-ary Tree-LSTM, each unit at node :math:`j` maintains a hidden
# representation :math:`h_j` and a memory cell :math:`c_j`. The unit
# :math:`j` takes the input vector :math:`x_j` and the hidden
# representations of its child units :math:`h_{jl}, 1\leq l\leq N` as
# input, then updates its hidden representation :math:`h_j` and memory
# cell :math:`c_j` by:
# #
# .. math:: # .. math::
# #
# i_j & = & \sigma\left(W^{(i)}x_j + \sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\right), & (1)\\ # i_j & = & \sigma\left(W^{(i)}x_j + \sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\right), & (1)\\
# f_{jk} & = & \sigma\left(W^{(f)}x_j + \sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \right), & (2)\\ # f_{jk} & = & \sigma\left(W^{(f)}x_j + \sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \right), & (2)\\
# o_j & = & \sigma\left(W^{(o)}x_j + \sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \right), & (3) \\ # o_j & = & \sigma\left(W^{(o)}x_j + \sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \right), & (3) \\
# u_j & = & \textrm{tanh}\left(W^{(u)}x_j + \sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \right), & (4)\\ # u_j & = & \textrm{tanh}\left(W^{(u)}x_j + \sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \right), & (4)\\
# c_j & = & i_j \odot u_j + \sum_{l=1}^{N} f_{jl} \odot c_{jl}, &(5) \\ # c_j & = & i_j \odot u_j + \sum_{l=1}^{N} f_{jl} \odot c_{jl}, &(5) \\
# h_j & = & o_j \cdot \textrm{tanh}(c_j), &(6) \\ # h_j & = & o_j \cdot \textrm{tanh}(c_j), &(6) \\
# #
# It can be decomposed into three phases: ``message_func``, # It can be decomposed into three phases: ``message_func``,
# ``reduce_func`` and ``apply_node_func``. # ``reduce_func`` and ``apply_node_func``.
# #
# .. note:: # .. note::
# ``apply_node_func`` is a new node UDF we have not introduced before. In
# ``apply_node_func``, the user specifies what to do with node features,
# without considering edge features and messages. In the Tree-LSTM case,
# ``apply_node_func`` is a must, since there exist (leaf) nodes with
# :math:`0` incoming edges, which would not be updated via
# ``reduce_func``. # ``reduce_func``.
# #
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
class TreeLSTMCell(nn.Module): class TreeLSTMCell(nn.Module):
def __init__(self, x_size, h_size): def __init__(self, x_size, h_size):
super(TreeLSTMCell, self).__init__() super(TreeLSTMCell, self).__init__()
self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False) self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False) self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size)) self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
self.U_f = nn.Linear(2 * h_size, 2 * h_size) self.U_f = nn.Linear(2 * h_size, 2 * h_size)
def message_func(self, edges): def message_func(self, edges):
return {'h': edges.src['h'], 'c': edges.src['c']} return {'h': edges.src['h'], 'c': edges.src['c']}
def reduce_func(self, nodes): def reduce_func(self, nodes):
# concatenate h_jl for equation (1), (2), (3), (4) # concatenate h_jl for equation (1), (2), (3), (4)
h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1) h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
# equation (2) # equation (2)
f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size()) f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
# second term of equation (5) # second term of equation (5)
c = th.sum(f * nodes.mailbox['c'], 1) c = th.sum(f * nodes.mailbox['c'], 1)
return {'iou': self.U_iou(h_cat), 'c': c} return {'iou': self.U_iou(h_cat), 'c': c}
def apply_node_func(self, nodes): def apply_node_func(self, nodes):
# equation (1), (3), (4) # equation (1), (3), (4)
iou = nodes.data['iou'] + self.b_iou iou = nodes.data['iou'] + self.b_iou
i, o, u = th.chunk(iou, 3, 1) i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u) i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
# equation (5) # equation (5)
c = i * u + nodes.data['c'] c = i * u + nodes.data['c']
# equation (6) # equation (6)
h = o * th.tanh(c) h = o * th.tanh(c)
return {'h' : h, 'c' : c} return {'h' : h, 'c' : c}
############################################################################## ##############################################################################
# Step 3: define traversal # Step 3: define traversal
# ------------------------ # ------------------------
# #
# After defining the message passing functions, we then need to induce the # After defining the message passing functions, we then need to induce the
# right order to trigger them. This is a significant departure from models # right order to trigger them. This is a significant departure from models
# such as GCN, where all nodes are pulling messages from upstream ones # such as GCN, where all nodes are pulling messages from upstream ones
# *simultaneously*. # *simultaneously*.
# #
# In the case of Tree-LSTM, messages start from the leaves of the tree and are
# propagated and processed upwards until they reach the roots. A visualization
# is as follows:
# #
# .. figure:: https://i.loli.net/2018/11/09/5be4b5d2df54d.gif # .. figure:: https://i.loli.net/2018/11/09/5be4b5d2df54d.gif
# :alt: # :alt:
# #
# DGL defines a generator that performs a topological sort; each item is a
# tensor recording the nodes from the bottom level up to the roots. One can
# appreciate the degree of parallelism by comparing the following two
# outputs:
# #
print('Traversing one tree:') print('Traversing one tree:')
print(dgl.topological_nodes_generator(a_tree)) print(dgl.topological_nodes_generator(a_tree))
print('Traversing many trees at the same time:') print('Traversing many trees at the same time:')
print(dgl.topological_nodes_generator(graph)) print(dgl.topological_nodes_generator(graph))
############################################################################## ##############################################################################
# We then call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing: # We then call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:
import dgl.function as fn import dgl.function as fn
import torch as th import torch as th
graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1) graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)
graph.register_message_func(fn.copy_src('a', 'a')) graph.register_message_func(fn.copy_src('a', 'a'))
graph.register_reduce_func(fn.sum('a', 'a')) graph.register_reduce_func(fn.sum('a', 'a'))
traversal_order = dgl.topological_nodes_generator(graph) traversal_order = dgl.topological_nodes_generator(graph)
graph.prop_nodes(traversal_order) graph.prop_nodes(traversal_order)
# the following is syntactic sugar that does the same
# dgl.prop_nodes_topo(graph) # dgl.prop_nodes_topo(graph)
############################################################################## ##############################################################################
# .. note:: # .. note::
# #
#    Before we call :meth:`~dgl.DGLGraph.prop_nodes`, we must register a
#    ``message_func`` and a ``reduce_func`` in advance; here we use the
#    built-in copy-from-source and sum functions as the message and reduce
#    functions for demonstration.
# #
# Putting it together # Putting it together
# ------------------- # -------------------
# #
# Here is the complete code that specifies the ``Tree-LSTM`` class: # Here is the complete code that specifies the ``Tree-LSTM`` class:
# #
class TreeLSTM(nn.Module): class TreeLSTM(nn.Module):
def __init__(self, def __init__(self,
num_vocabs, num_vocabs,
x_size, x_size,
h_size, h_size,
num_classes, num_classes,
dropout, dropout,
pretrained_emb=None): pretrained_emb=None):
super(TreeLSTM, self).__init__() super(TreeLSTM, self).__init__()
self.x_size = x_size self.x_size = x_size
self.embedding = nn.Embedding(num_vocabs, x_size) self.embedding = nn.Embedding(num_vocabs, x_size)
if pretrained_emb is not None: if pretrained_emb is not None:
print('Using glove') print('Using glove')
self.embedding.weight.data.copy_(pretrained_emb) self.embedding.weight.data.copy_(pretrained_emb)
self.embedding.weight.requires_grad = True self.embedding.weight.requires_grad = True
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.linear = nn.Linear(h_size, num_classes) self.linear = nn.Linear(h_size, num_classes)
self.cell = TreeLSTMCell(x_size, h_size) self.cell = TreeLSTMCell(x_size, h_size)
def forward(self, batch, h, c): def forward(self, batch, h, c):
"""Compute tree-lstm prediction given a batch. """Compute tree-lstm prediction given a batch.
Parameters Parameters
---------- ----------
batch : dgl.data.SSTBatch batch : dgl.data.SSTBatch
The data batch. The data batch.
h : Tensor h : Tensor
Initial hidden state. Initial hidden state.
c : Tensor c : Tensor
Initial cell state. Initial cell state.
Returns Returns
------- -------
logits : Tensor logits : Tensor
The prediction of each node. The prediction of each node.
""" """
g = batch.graph g = batch.graph
g.register_message_func(self.cell.message_func) g.register_message_func(self.cell.message_func)
g.register_reduce_func(self.cell.reduce_func) g.register_reduce_func(self.cell.reduce_func)
g.register_apply_node_func(self.cell.apply_node_func) g.register_apply_node_func(self.cell.apply_node_func)
# feed embedding # feed embedding
embeds = self.embedding(batch.wordid * batch.mask) embeds = self.embedding(batch.wordid * batch.mask)
g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1) g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
g.ndata['h'] = h g.ndata['h'] = h
g.ndata['c'] = c g.ndata['c'] = c
# propagate # propagate
dgl.prop_nodes_topo(g) dgl.prop_nodes_topo(g)
# compute logits # compute logits
h = self.dropout(g.ndata.pop('h')) h = self.dropout(g.ndata.pop('h'))
logits = self.linear(h) logits = self.linear(h)
return logits return logits
############################################################################## ##############################################################################
# Main Loop # Main Loop
# --------- # ---------
# #
# Finally, we can write the training loop in PyTorch:
# #
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import torch.nn.functional as F import torch.nn.functional as F
device = th.device('cpu') device = th.device('cpu')
# hyperparameters
x_size = 256 x_size = 256
h_size = 256 h_size = 256
dropout = 0.5 dropout = 0.5
lr = 0.05 lr = 0.05
weight_decay = 1e-4 weight_decay = 1e-4
epochs = 10 epochs = 10
# create the model # create the model
model = TreeLSTM(trainset.num_vocabs, model = TreeLSTM(trainset.num_vocabs,
x_size, x_size,
h_size, h_size,
trainset.num_classes, trainset.num_classes,
dropout) dropout)
print(model) print(model)
# create the optimizer # create the optimizer
optimizer = th.optim.Adagrad(model.parameters(), optimizer = th.optim.Adagrad(model.parameters(),
lr=lr, lr=lr,
weight_decay=weight_decay) weight_decay=weight_decay)
train_loader = DataLoader(dataset=tiny_sst, train_loader = DataLoader(dataset=tiny_sst,
batch_size=5, batch_size=5,
collate_fn=data.SST.batcher(device), collate_fn=data.SST.batcher(device),
shuffle=False, shuffle=False,
num_workers=0) num_workers=0)
# training loop # training loop
for epoch in range(epochs): for epoch in range(epochs):
for step, batch in enumerate(train_loader): for step, batch in enumerate(train_loader):
g = batch.graph g = batch.graph
n = g.number_of_nodes() n = g.number_of_nodes()
h = th.zeros((n, h_size)) h = th.zeros((n, h_size))
c = th.zeros((n, h_size)) c = th.zeros((n, h_size))
logits = model(batch, h, c) logits = model(batch, h, c)
logp = F.log_softmax(logits, 1) logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp, batch.label, reduction='sum') loss = F.nll_loss(logp, batch.label, reduction='sum')
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() loss.backward()
optimizer.step() optimizer.step()
pred = th.argmax(logits, 1) pred = th.argmax(logits, 1)
acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label) acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)
print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format( print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format(
epoch, step, loss.item(), acc)) epoch, step, loss.item(), acc))
############################################################################## ##############################################################################
# To train the model on the full dataset with different settings (CPU/GPU,
# etc.), please refer to our repo's
# `example <https://github.com/jermainewang/dgl/tree/master/examples/pytorch/tree_lstm>`__.
# We also provide an implementation of the Child-Sum Tree-LSTM there.
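#
# For instance (a sketch only, not the full example script), training on a GPU
# mainly requires placing the model and the initial states on the device and
# passing that device to ``data.SST.batcher``:
#
# ::
#
#     device = th.device('cuda:0')
#     model = model.to(device)
#     train_loader = DataLoader(dataset=tiny_sst,
#                               batch_size=5,
#                               collate_fn=data.SST.batcher(device),
#                               shuffle=False,
#                               num_workers=0)
#     # inside the training loop
#     h = th.zeros((n, h_size), device=device)
#     c = th.zeros((n, h_size), device=device)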
.. _tutorials2-index:
Dealing with many small graphs
------------------------------
* **Tree-LSTM** `[paper] <https://arxiv.org/abs/1503.00075>`__ `[tutorial] <models/3_tree-lstm.html>`__
`[code] <https://github.com/jermainewang/dgl/blob/master/examples/pytorch/tree_lstm/tree_lstm.py>`__:
Sentences in natural languages have inherent structures, which are thrown away
by treating them simply as sequences. Tree-LSTM is a powerful model that learns
the representation by leveraging prior syntactic structures (e.g., parse trees).
The challenge in training it well is that simply padding a sentence to the
maximum length no longer works, since trees of different sentences have
different sizes and topologies. DGL solves this problem by throwing the trees
into a bigger "container" graph, and using message passing to explore maximum
parallelism. The key API we use is batching.
""" """
.. _model-dgmg: .. _model-dgmg:
Tutorial for Generative Models of Graphs Tutorial for Generative Models of Graphs
=========================================== ===========================================
**Author**: `Mufei Li <https://github.com/mufeili>`_, **Author**: `Mufei Li <https://github.com/mufeili>`_,
`Lingfan Yu <https://github.com/ylfdq1118>`_, Zheng Zhang `Lingfan Yu <https://github.com/ylfdq1118>`_, Zheng Zhang
""" """
############################################################################## ##############################################################################
# #
# In earlier tutorials we have seen how learned embedding of a graph and/or # In earlier tutorials we have seen how learned embedding of a graph and/or
# a node allow applications such as `semi-supervised classification for nodes # a node allow applications such as `semi-supervised classification for nodes
# <http://docs.dgl.ai/tutorials/models/1_gcn.html#sphx-glr-tutorials-models-1-gcn-py>`__ # <http://docs.dgl.ai/tutorials/models/1_gcn.html#sphx-glr-tutorials-models-1-gcn-py>`__
# or `sentiment analysis # or `sentiment analysis
# <http://docs.dgl.ai/tutorials/models/3_tree-lstm.html#sphx-glr-tutorials-models-3-tree-lstm-py>`__. # <http://docs.dgl.ai/tutorials/models/3_tree-lstm.html#sphx-glr-tutorials-models-3-tree-lstm-py>`__.
# Wouldn't it be interesting to predict the future evolution of the graph and # Wouldn't it be interesting to predict the future evolution of the graph and
# perform the analysis iteratively? # perform the analysis iteratively?
# #
# To do so, we need to generate a variety of graph samples; in other words, we
# need **generative models** of graphs. Instead of, or in addition to, learning
# node and edge features, we want to model the distribution of arbitrary graphs.
# While general generative models can model the density function explicitly or
# implicitly, and generate samples all at once or sequentially, here we focus
# only on explicit generative models for sequential generation. Typical applications
# include drug/material discovery, chemical processes, proteomics, etc. # include drug/material discovery, chemical processes, proteomics, etc.
# #
# Introduction # Introduction
# -------------------- # --------------------
# The primitive actions of mutating a graph in DGL are nothing more than ``add_nodes`` # The primitive actions of mutating a graph in DGL are nothing more than ``add_nodes``
# and ``add_edges``. That is, if we were to draw a circle of 3 nodes, # and ``add_edges``. That is, if we were to draw a circle of 3 nodes,
# #
# .. figure:: https://user-images.githubusercontent.com/19576924/48313438-78baf000-e5f7-11e8-931e-cd00ab34fa50.gif # .. figure:: https://user-images.githubusercontent.com/19576924/48313438-78baf000-e5f7-11e8-931e-cd00ab34fa50.gif
# :alt: # :alt:
# #
# we can simply write the code as: # we can simply write the code as:
# #
import dgl import dgl
g = dgl.DGLGraph() g = dgl.DGLGraph()
g.add_nodes(1) # Add node 0 g.add_nodes(1) # Add node 0
g.add_nodes(1) # Add node 1 g.add_nodes(1) # Add node 1
# Edges in DGLGraph are directed by default. # Edges in DGLGraph are directed by default.
# For undirected edges, we add edges for both directions. # For undirected edges, we add edges for both directions.
g.add_edges([1, 0], [0, 1]) # Add edges (1, 0), (0, 1) g.add_edges([1, 0], [0, 1]) # Add edges (1, 0), (0, 1)
g.add_nodes(1) # Add node 2 g.add_nodes(1) # Add node 2
g.add_edges([2, 1], [1, 2]) # Add edges (2, 1), (1, 2) g.add_edges([2, 1], [1, 2]) # Add edges (2, 1), (1, 2)
g.add_edges([2, 0], [0, 2]) # Add edges (2, 0), (0, 2) g.add_edges([2, 0], [0, 2]) # Add edges (2, 0), (0, 2)
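# A quick sanity check: the cycle above has 3 nodes and 6 directed edges.
print(g.number_of_nodes())  # 3
print(g.number_of_edges())  # 6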
####################################################################################### #######################################################################################
# Real-world graphs are much more complex. There are many families of graphs, # Real-world graphs are much more complex. There are many families of graphs,
# with different sizes, topologies, node types, edge types, and the possibility # with different sizes, topologies, node types, edge types, and the possibility
# of multigraphs. Moreover, the same graph can be generated in many different
# orders. Regardless, the generative process entails a few steps: # orders. Regardless, the generative process entails a few steps:
# #
# - Encode a changing graph, # - Encode a changing graph,
# - Perform actions stochastically, # - Perform actions stochastically,
# - Collect error signals and optimize the model parameters (if we are training).
# #
# When it comes to implementation, another important aspect is speed: how do we # When it comes to implementation, another important aspect is speed: how do we
# parallelize the computation given that generating a graph is fundamentally a # parallelize the computation given that generating a graph is fundamentally a
# sequential process? # sequential process?
# #
# .. note:: # .. note::
# #
#    To be sure, this is not necessarily a hard constraint: one can imagine
#    building subgraphs in parallel and then assembling them. But we will
#    restrict ourselves to sequential processes in this tutorial.
# #
# In this tutorial, we will first focus on how to train and generate one graph at
# a time, exploring parallelism within the graph embedding operation, an # a time, exploring parallelism within the graph embedding operation, an
# essential building block. We will end with a simple optimization that # essential building block. We will end with a simple optimization that
# delivers a 2x speedup by batching across graphs. # delivers a 2x speedup by batching across graphs.
# #
# DGMG: the main flow # DGMG: the main flow
# -------------------- # --------------------
# We pick DGMG ( # We pick DGMG (
# `Learning Deep Generative Models of Graphs <https://arxiv.org/abs/1803.03324>`__ # `Learning Deep Generative Models of Graphs <https://arxiv.org/abs/1803.03324>`__
# ) as an exercise to implement a graph generative model using DGL, primarily # ) as an exercise to implement a graph generative model using DGL, primarily
# because its algorithmic framework is general but also challenging to parallelize. # because its algorithmic framework is general but also challenging to parallelize.
# #
# .. note:: # .. note::
# #
# While it's possible for DGMG to handle complex graphs with typed nodes, # While it's possible for DGMG to handle complex graphs with typed nodes,
# typed edges and multigraphs, we only present a simplified version of it # typed edges and multigraphs, we only present a simplified version of it
# for generating graph topologies. # for generating graph topologies.
# #
# DGMG generates a graph by following a state machine, which is basically a # DGMG generates a graph by following a state machine, which is basically a
# two-level loop: generate one node at a time, and connect it to a subset of # two-level loop: generate one node at a time, and connect it to a subset of
# the existing nodes, one at a time. This is similar to language modeling: the # the existing nodes, one at a time. This is similar to language modeling: the
# generative process is an iterative one that emits one word/character/sentence # generative process is an iterative one that emits one word/character/sentence
# at a time, conditioned on the sequence generated so far. # at a time, conditioned on the sequence generated so far.
# #
# At each time step, we either
#
# - add a new node to the graph, or
# - select two existing nodes and add an edge between them.
# #
# .. figure:: https://user-images.githubusercontent.com/19576924/48605003-7f11e900-e9b6-11e8-8880-87362348e154.png # .. figure:: https://user-images.githubusercontent.com/19576924/48605003-7f11e900-e9b6-11e8-8880-87362348e154.png
# :alt: # :alt:
# #
# The Python code will look as follows; in fact, this is *exactly* how inference # The Python code will look as follows; in fact, this is *exactly* how inference
# with DGMG is implemented in DGL: # with DGMG is implemented in DGL:
# #
def forward_inference(self): def forward_inference(self):
stop = self.add_node_and_update() stop = self.add_node_and_update()
while (not stop) and (self.g.number_of_nodes() < self.v_max + 1): while (not stop) and (self.g.number_of_nodes() < self.v_max + 1):
num_trials = 0 num_trials = 0
to_add_edge = self.add_edge_or_not() to_add_edge = self.add_edge_or_not()
while to_add_edge and (num_trials < self.g.number_of_nodes() - 1): while to_add_edge and (num_trials < self.g.number_of_nodes() - 1):
self.choose_dest_and_update() self.choose_dest_and_update()
num_trials += 1 num_trials += 1
to_add_edge = self.add_edge_or_not() to_add_edge = self.add_edge_or_not()
stop = self.add_node_and_update() stop = self.add_node_and_update()
return self.g return self.g
####################################################################################### #######################################################################################
# Assume we have a pre-trained model for generating cycles with 10-20 nodes. Let's see
# how it generates a cycle on the fly during inference. You can also use the code below
# to create an animation with your own model.
# #
# :: # ::
# #
# import torch # import torch
# import matplotlib.animation as animation # import matplotlib.animation as animation
# import matplotlib.pyplot as plt # import matplotlib.pyplot as plt
# import networkx as nx # import networkx as nx
# from copy import deepcopy # from copy import deepcopy
# #
# if __name__ == '__main__': # if __name__ == '__main__':
# # pre-trained model saved with path ./model.pth # # pre-trained model saved with path ./model.pth
# model = torch.load('./model.pth') # model = torch.load('./model.pth')
# model.eval() # model.eval()
# g = model() # g = model()
# #
# src_list = g.edges()[1] # src_list = g.edges()[1]
# dest_list = g.edges()[0] # dest_list = g.edges()[0]
# #
# evolution = [] # evolution = []
# #
# nx_g = nx.Graph() # nx_g = nx.Graph()
# evolution.append(deepcopy(nx_g)) # evolution.append(deepcopy(nx_g))
# #
# for i in range(0, len(src_list), 2): # for i in range(0, len(src_list), 2):
# src = src_list[i].item() # src = src_list[i].item()
# dest = dest_list[i].item() # dest = dest_list[i].item()
# if src not in nx_g.nodes(): # if src not in nx_g.nodes():
# nx_g.add_node(src) # nx_g.add_node(src)
# evolution.append(deepcopy(nx_g)) # evolution.append(deepcopy(nx_g))
# if dest not in nx_g.nodes(): # if dest not in nx_g.nodes():
# nx_g.add_node(dest) # nx_g.add_node(dest)
# evolution.append(deepcopy(nx_g)) # evolution.append(deepcopy(nx_g))
# nx_g.add_edges_from([(src, dest), (dest, src)]) # nx_g.add_edges_from([(src, dest), (dest, src)])
# evolution.append(deepcopy(nx_g)) # evolution.append(deepcopy(nx_g))
# #
# def animate(i): # def animate(i):
# ax.cla() # ax.cla()
# g_t = evolution[i] # g_t = evolution[i]
# nx.draw_circular(g_t, with_labels=True, ax=ax, # nx.draw_circular(g_t, with_labels=True, ax=ax,
# node_color=['#FEBD69'] * g_t.number_of_nodes()) # node_color=['#FEBD69'] * g_t.number_of_nodes())
# #
# fig, ax = plt.subplots() # fig, ax = plt.subplots()
# ani = animation.FuncAnimation(fig, animate, # ani = animation.FuncAnimation(fig, animate,
# frames=len(evolution), # frames=len(evolution),
# interval=600) # interval=600)
# #
# .. figure:: https://user-images.githubusercontent.com/19576924/48928548-2644d200-ef1b-11e8-8591-da93345382ad.gif # .. figure:: https://user-images.githubusercontent.com/19576924/48928548-2644d200-ef1b-11e8-8591-da93345382ad.gif
# :alt: # :alt:
# #
# DGMG: optimization objective # DGMG: optimization objective
# ------------------------------ # ------------------------------
# Similar to language modeling, DGMG trains the model with *behavior cloning*, # Similar to language modeling, DGMG trains the model with *behavior cloning*,
# or *teacher forcing*. Let's assume for each graph there exists a sequence of # or *teacher forcing*. Let's assume for each graph there exists a sequence of
# *oracle actions* :math:`a_{1},\cdots,a_{T}` that generates it. The model
# simply follows these actions, computes the joint probability of the action
# sequence, and maximizes it.
# #
# By chain rule, the probability of taking :math:`a_{1},\cdots,a_{T}` is: # By chain rule, the probability of taking :math:`a_{1},\cdots,a_{T}` is:
# #
# .. math:: # .. math::
# #
# p(a_{1},\cdots, a_{T}) = p(a_{1})p(a_{2}|a_{1})\cdots p(a_{T}|a_{1},\cdots,a_{T-1}).\\ # p(a_{1},\cdots, a_{T}) = p(a_{1})p(a_{2}|a_{1})\cdots p(a_{T}|a_{1},\cdots,a_{T-1}).\\
# #
# The optimization objective is then simply the typical MLE loss: # The optimization objective is then simply the typical MLE loss:
# #
# .. math:: # .. math::
# #
# -\log p(a_{1},\cdots,a_{T})=-\sum_{t=1}^{T}\log p(a_{t}|a_{1},\cdots, a_{t-1}).\\ # -\log p(a_{1},\cdots,a_{T})=-\sum_{t=1}^{T}\log p(a_{t}|a_{1},\cdots, a_{t-1}).\\
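#
# For example, if a sequence of three actions has conditional probabilities
# :math:`0.8`, :math:`0.5` and :math:`0.9`, the loss is
# :math:`-(\log 0.8 + \log 0.5 + \log 0.9) \approx 1.02`.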
# #
def forward_train(self, actions): def forward_train(self, actions):
""" """
- actions: list - actions: list
- Contains a_1, ..., a_T described above - Contains a_1, ..., a_T described above
- self.prepare_for_train() - self.prepare_for_train()
- Initializes self.action_step to be 0, which will get - Initializes self.action_step to be 0, which will get
        incremented by 1 every time it is accessed.
- Initializes objects recording log p(a_t|a_1,...a_{t-1}) - Initializes objects recording log p(a_t|a_1,...a_{t-1})
Returns Returns
------- -------
- self.get_log_prob(): log p(a_1, ..., a_T) - self.get_log_prob(): log p(a_1, ..., a_T)
""" """
self.prepare_for_train() self.prepare_for_train()
stop = self.add_node_and_update(a=actions[self.action_step]) stop = self.add_node_and_update(a=actions[self.action_step])
while not stop: while not stop:
to_add_edge = self.add_edge_or_not(a=actions[self.action_step]) to_add_edge = self.add_edge_or_not(a=actions[self.action_step])
while to_add_edge: while to_add_edge:
self.choose_dest_and_update(a=actions[self.action_step]) self.choose_dest_and_update(a=actions[self.action_step])
to_add_edge = self.add_edge_or_not(a=actions[self.action_step]) to_add_edge = self.add_edge_or_not(a=actions[self.action_step])
stop = self.add_node_and_update(a=actions[self.action_step]) stop = self.add_node_and_update(a=actions[self.action_step])
return self.get_log_prob() return self.get_log_prob()
####################################################################################### #######################################################################################
# The key difference between ``forward_train`` and ``forward_inference`` is # The key difference between ``forward_train`` and ``forward_inference`` is
# that the training process takes oracle actions as input, and returns log # that the training process takes oracle actions as input, and returns log
# probabilities for evaluating the loss. # probabilities for evaluating the loss.
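#
# A single training step then looks roughly like this (a sketch; ``model``,
# ``optimizer`` and the oracle ``actions`` are assumed to exist):
#
# ::
#
#     model.train()
#     log_prob = model(actions=actions)  # dispatches to forward_train
#     loss = -log_prob                   # negative log-likelihood
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()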
# #
# DGMG: the implementation # DGMG: the implementation
# -------------------------- # --------------------------
# The ``DGMG`` class # The ``DGMG`` class
# `````````````````````````` # ``````````````````````````
# Below one can find the skeleton code for the model. We will gradually # Below one can find the skeleton code for the model. We will gradually
# fill in the details for each function. # fill in the details for each function.
# #
import torch.nn as nn import torch.nn as nn
class DGMGSkeleton(nn.Module): class DGMGSkeleton(nn.Module):
def __init__(self, v_max): def __init__(self, v_max):
""" """
Parameters Parameters
---------- ----------
v_max: int v_max: int
Max number of nodes considered Max number of nodes considered
""" """
super(DGMGSkeleton, self).__init__() super(DGMGSkeleton, self).__init__()
# Graph configuration # Graph configuration
self.v_max = v_max self.v_max = v_max
def add_node_and_update(self, a=None): def add_node_and_update(self, a=None):
"""Decide if to add a new node. """Decide if to add a new node.
If a new node should be added, update the graph.""" If a new node should be added, update the graph."""
        raise NotImplementedError
def add_edge_or_not(self, a=None): def add_edge_or_not(self, a=None):
"""Decide if a new edge should be added.""" """Decide if a new edge should be added."""
        raise NotImplementedError
def choose_dest_and_update(self, a=None): def choose_dest_and_update(self, a=None):
"""Choose destination and connect it to the latest node. """Choose destination and connect it to the latest node.
Add edges for both directions and update the graph.""" Add edges for both directions and update the graph."""
        raise NotImplementedError
def forward_train(self, actions): def forward_train(self, actions):
"""Forward at training time. It records the probability """Forward at training time. It records the probability
of generating a ground truth graph following the actions.""" of generating a ground truth graph following the actions."""
        raise NotImplementedError
def forward_inference(self): def forward_inference(self):
"""Forward at inference time. """Forward at inference time.
It generates graphs on the fly.""" It generates graphs on the fly."""
        raise NotImplementedError
def forward(self, actions=None): def forward(self, actions=None):
# The graph we will work on # The graph we will work on
self.g = dgl.DGLGraph() self.g = dgl.DGLGraph()
        # If nodes and edges carry features, the features of newly
        # added nodes and edges will be initialized as zero tensors.
self.g.set_n_initializer(dgl.frame.zero_initializer) self.g.set_n_initializer(dgl.frame.zero_initializer)
self.g.set_e_initializer(dgl.frame.zero_initializer) self.g.set_e_initializer(dgl.frame.zero_initializer)
if self.training: if self.training:
return self.forward_train(actions=actions) return self.forward_train(actions=actions)
else: else:
return self.forward_inference() return self.forward_inference()
####################################################################################### #######################################################################################
# Encoding a dynamic graph # Encoding a dynamic graph
# `````````````````````````` # ``````````````````````````
# All the actions generating a graph are sampled from probability # All the actions generating a graph are sampled from probability
# distributions. In order to do that, we must project the structured data, # distributions. In order to do that, we must project the structured data,
# namely the graph, onto a Euclidean space. The challenge is that such a
# process, called *embedding*, needs to be repeated as the graph mutates.
# #
# Graph Embedding # Graph Embedding
# '''''''''''''''''''''''''' # ''''''''''''''''''''''''''
# Let :math:`G=(V,E)` be an arbitrary graph. Each node :math:`v` has an # Let :math:`G=(V,E)` be an arbitrary graph. Each node :math:`v` has an
# embedding vector :math:`\textbf{h}_{v} \in \mathbb{R}^{n}`. Similarly, # embedding vector :math:`\textbf{h}_{v} \in \mathbb{R}^{n}`. Similarly,
# the graph has an embedding vector :math:`\textbf{h}_{G} \in \mathbb{R}^{k}`. # the graph has an embedding vector :math:`\textbf{h}_{G} \in \mathbb{R}^{k}`.
# Typically, :math:`k > n` since a graph contains more information than # Typically, :math:`k > n` since a graph contains more information than
# an individual node. # an individual node.
# #
# The graph embedding is a weighted sum of node embeddings under a linear # The graph embedding is a weighted sum of node embeddings under a linear
# transformation: # transformation:
# #
# .. math:: # .. math::
# #
# \textbf{h}_{G} =\sum_{v\in V}\text{Sigmoid}(g_m(\textbf{h}_{v}))f_{m}(\textbf{h}_{v}),\\ # \textbf{h}_{G} =\sum_{v\in V}\text{Sigmoid}(g_m(\textbf{h}_{v}))f_{m}(\textbf{h}_{v}),\\
# #
# The first term, :math:`\text{Sigmoid}(g_m(\textbf{h}_{v}))`, computes a # The first term, :math:`\text{Sigmoid}(g_m(\textbf{h}_{v}))`, computes a
# gating function and can be thought of as how much the overall graph embedding
# attends to each node. The second term :math:`f_{m}:\mathbb{R}^{n}\rightarrow\mathbb{R}^{k}`
# maps the node embeddings to the space of graph embeddings. # maps the node embeddings to the space of graph embeddings.
# #
# We implement graph embedding as a ``GraphEmbed`` class: # We implement graph embedding as a ``GraphEmbed`` class:
# #
import torch import torch
class GraphEmbed(nn.Module): class GraphEmbed(nn.Module):
def __init__(self, node_hidden_size): def __init__(self, node_hidden_size):
super(GraphEmbed, self).__init__() super(GraphEmbed, self).__init__()
# Setting from the paper # Setting from the paper
self.graph_hidden_size = 2 * node_hidden_size self.graph_hidden_size = 2 * node_hidden_size
# Embed graphs # Embed graphs
self.node_gating = nn.Sequential( self.node_gating = nn.Sequential(
nn.Linear(node_hidden_size, 1), nn.Linear(node_hidden_size, 1),
nn.Sigmoid() nn.Sigmoid()
) )
self.node_to_graph = nn.Linear(node_hidden_size, self.node_to_graph = nn.Linear(node_hidden_size,
self.graph_hidden_size) self.graph_hidden_size)
def forward(self, g): def forward(self, g):
if g.number_of_nodes() == 0: if g.number_of_nodes() == 0:
return torch.zeros(1, self.graph_hidden_size) return torch.zeros(1, self.graph_hidden_size)
else: else:
# Node features are stored as hv in ndata. # Node features are stored as hv in ndata.
hvs = g.ndata['hv'] hvs = g.ndata['hv']
return (self.node_gating(hvs) * return (self.node_gating(hvs) *
self.node_to_graph(hvs)).sum(0, keepdim=True) self.node_to_graph(hvs)).sum(0, keepdim=True)
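# A quick check with an illustrative size: an empty graph is embedded as a zero
# vector of length graph_hidden_size (here 2 * 16 = 32).
graph_embed_demo = GraphEmbed(node_hidden_size=16)
print(graph_embed_demo(dgl.DGLGraph()).shape)  # torch.Size([1, 32])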
####################################################################################### #######################################################################################
# Update node embeddings via graph propagation # Update node embeddings via graph propagation
# '''''''''''''''''''''''''''''''''''''''''''' # ''''''''''''''''''''''''''''''''''''''''''''
# #
# The mechanism of updating node embeddings in DGMG is similar to that for # The mechanism of updating node embeddings in DGMG is similar to that for
# graph convolutional networks. For a node :math:`v` in the graph, its # graph convolutional networks. For a node :math:`v` in the graph, its
# neighbor :math:`u` sends a message to it with # neighbor :math:`u` sends a message to it with
# #
# .. math:: # .. math::
# #
# \textbf{m}_{u\rightarrow v}=\textbf{W}_{m}\text{concat}([\textbf{h}_{v}, \textbf{h}_{u}, \textbf{x}_{u, v}]) + \textbf{b}_{m},\\ # \textbf{m}_{u\rightarrow v}=\textbf{W}_{m}\text{concat}([\textbf{h}_{v}, \textbf{h}_{u}, \textbf{x}_{u, v}]) + \textbf{b}_{m},\\
# #
# where :math:`\textbf{x}_{u,v}` is the embedding of the edge between # where :math:`\textbf{x}_{u,v}` is the embedding of the edge between
# :math:`u` and :math:`v`. # :math:`u` and :math:`v`.
# #
# After receiving messages from all its neighbors, :math:`v` summarizes them # After receiving messages from all its neighbors, :math:`v` summarizes them
# with a node activation vector # with a node activation vector
# #
# .. math:: # .. math::
# #
# \textbf{a}_{v} = \sum_{u: (u, v)\in E}\textbf{m}_{u\rightarrow v}\\ # \textbf{a}_{v} = \sum_{u: (u, v)\in E}\textbf{m}_{u\rightarrow v}\\
# #
# and use this information to update its own feature: # and use this information to update its own feature:
# #
# .. math:: # .. math::
# #
# \textbf{h}'_{v} = \textbf{GRU}(\textbf{h}_{v}, \textbf{a}_{v}).\\ # \textbf{h}'_{v} = \textbf{GRU}(\textbf{h}_{v}, \textbf{a}_{v}).\\
# #
# Performing all the operations above once for all nodes synchronously is # Performing all the operations above once for all nodes synchronously is
# called one round of graph propagation. The more rounds of graph propagation # called one round of graph propagation. The more rounds of graph propagation
# we perform, the farther messages travel throughout the graph.
# #
# With DGL, we implement graph propagation with ``g.update_all``. Note that
# the message notation here can be a bit confusing. While the authors refer # the message notation here can be a bit confusing. While the authors refer
# to :math:`\textbf{m}_{u\rightarrow v}` as messages, our message function # to :math:`\textbf{m}_{u\rightarrow v}` as messages, our message function
# below only passes :math:`\text{concat}([\textbf{h}_{u}, \textbf{x}_{u, v}])`. # below only passes :math:`\text{concat}([\textbf{h}_{u}, \textbf{x}_{u, v}])`.
# The operation :math:`\textbf{W}_{m}\text{concat}([\textbf{h}_{v}, \textbf{h}_{u}, \textbf{x}_{u, v}]) + \textbf{b}_{m}` # The operation :math:`\textbf{W}_{m}\text{concat}([\textbf{h}_{v}, \textbf{h}_{u}, \textbf{x}_{u, v}]) + \textbf{b}_{m}`
# is then performed across all edges at once for efficiency.
# #
from functools import partial from functools import partial
class GraphProp(nn.Module): class GraphProp(nn.Module):
def __init__(self, num_prop_rounds, node_hidden_size): def __init__(self, num_prop_rounds, node_hidden_size):
super(GraphProp, self).__init__() super(GraphProp, self).__init__()
self.num_prop_rounds = num_prop_rounds self.num_prop_rounds = num_prop_rounds
# Setting from the paper # Setting from the paper
self.node_activation_hidden_size = 2 * node_hidden_size self.node_activation_hidden_size = 2 * node_hidden_size
message_funcs = [] message_funcs = []
node_update_funcs = [] node_update_funcs = []
self.reduce_funcs = [] self.reduce_funcs = []
for t in range(num_prop_rounds): for t in range(num_prop_rounds):
# input being [hv, hu, xuv] # input being [hv, hu, xuv]
message_funcs.append(nn.Linear(2 * node_hidden_size + 1, message_funcs.append(nn.Linear(2 * node_hidden_size + 1,
self.node_activation_hidden_size)) self.node_activation_hidden_size))
self.reduce_funcs.append(partial(self.dgmg_reduce, round=t)) self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
node_update_funcs.append( node_update_funcs.append(
nn.GRUCell(self.node_activation_hidden_size, nn.GRUCell(self.node_activation_hidden_size,
node_hidden_size)) node_hidden_size))
self.message_funcs = nn.ModuleList(message_funcs) self.message_funcs = nn.ModuleList(message_funcs)
self.node_update_funcs = nn.ModuleList(node_update_funcs) self.node_update_funcs = nn.ModuleList(node_update_funcs)
def dgmg_msg(self, edges): def dgmg_msg(self, edges):
"""For an edge u->v, return concat([h_u, x_uv])""" """For an edge u->v, return concat([h_u, x_uv])"""
return {'m': torch.cat([edges.src['hv'], return {'m': torch.cat([edges.src['hv'],
edges.data['he']], edges.data['he']],
dim=1)} dim=1)}
def dgmg_reduce(self, nodes, round): def dgmg_reduce(self, nodes, round):
hv_old = nodes.data['hv'] hv_old = nodes.data['hv']
m = nodes.mailbox['m'] m = nodes.mailbox['m']
message = torch.cat([ message = torch.cat([
hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2) hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
node_activation = (self.message_funcs[round](message)).sum(1) node_activation = (self.message_funcs[round](message)).sum(1)
return {'a': node_activation} return {'a': node_activation}
def forward(self, g): def forward(self, g):
if g.number_of_edges() > 0: if g.number_of_edges() > 0:
for t in range(self.num_prop_rounds): for t in range(self.num_prop_rounds):
g.update_all(message_func=self.dgmg_msg, g.update_all(message_func=self.dgmg_msg,
reduce_func=self.reduce_funcs[t]) reduce_func=self.reduce_funcs[t])
g.ndata['hv'] = self.node_update_funcs[t]( g.ndata['hv'] = self.node_update_funcs[t](
g.ndata['a'], g.ndata['hv']) g.ndata['a'], g.ndata['hv'])
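# A small smoke test with illustrative sizes: two rounds of propagation over a
# 3-node cycle with random node features 'hv' and constant edge features 'he'.
prop_demo = GraphProp(num_prop_rounds=2, node_hidden_size=16)
g_demo = dgl.DGLGraph()
g_demo.add_nodes(3)
g_demo.add_edges([0, 1, 1, 2, 2, 0], [1, 0, 2, 1, 0, 2])
g_demo.ndata['hv'] = torch.randn(3, 16)
g_demo.edata['he'] = torch.ones(g_demo.number_of_edges(), 1)
prop_demo(g_demo)
print(g_demo.ndata['hv'].shape)  # torch.Size([3, 16]): embeddings updated in place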
####################################################################################### #######################################################################################
# Actions # Actions
# `````````````````````````` # ``````````````````````````
# All actions are sampled from distributions parameterized using neural nets # All actions are sampled from distributions parameterized using neural nets
# and we introduce them in turn. # and we introduce them in turn.
# #
# Action 1: add nodes # Action 1: add nodes
# '''''''''''''''''''''''''' # ''''''''''''''''''''''''''
# #
# Given the graph embedding vector :math:`\textbf{h}_{G}`, we evaluate # Given the graph embedding vector :math:`\textbf{h}_{G}`, we evaluate
# #
# .. math:: # .. math::
# #
# \text{Sigmoid}(\textbf{W}_{\text{add node}}\textbf{h}_{G}+b_{\text{add node}}),\\ # \text{Sigmoid}(\textbf{W}_{\text{add node}}\textbf{h}_{G}+b_{\text{add node}}),\\
# #
# which is then used to parametrize a Bernoulli distribution for deciding whether # which is then used to parametrize a Bernoulli distribution for deciding whether
# to add a new node. # to add a new node.
# #
# If a new node is to be added, we initialize its feature with # If a new node is to be added, we initialize its feature with
# #
# .. math:: # .. math::
# #
# \textbf{W}_{\text{init}}\text{concat}([\textbf{h}_{\text{init}} , \textbf{h}_{G}])+\textbf{b}_{\text{init}},\\ # \textbf{W}_{\text{init}}\text{concat}([\textbf{h}_{\text{init}} , \textbf{h}_{G}])+\textbf{b}_{\text{init}},\\
# #
# where :math:`\textbf{h}_{\text{init}}` is a learnable embedding module for # where :math:`\textbf{h}_{\text{init}}` is a learnable embedding module for
# untyped nodes. # untyped nodes.
# #
import torch.nn.functional as F import torch.nn.functional as F
from torch.distributions import Bernoulli from torch.distributions import Bernoulli
def bernoulli_action_log_prob(logit, action): def bernoulli_action_log_prob(logit, action):
"""Calculate the log p of an action with respect to a Bernoulli """Calculate the log p of an action with respect to a Bernoulli
distribution. Use logit rather than prob for numerical stability.""" distribution. Use logit rather than prob for numerical stability."""
if action == 0: if action == 0:
return F.logsigmoid(-logit) return F.logsigmoid(-logit)
else: else:
return F.logsigmoid(logit) return F.logsigmoid(logit)
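# Sanity check: the helper agrees with torch.distributions.Bernoulli(logits=...).
demo_logit = torch.tensor([[0.5]])
print(bernoulli_action_log_prob(demo_logit, 1))
print(Bernoulli(logits=demo_logit).log_prob(torch.ones(1, 1)))  # same value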
class AddNode(nn.Module): class AddNode(nn.Module):
def __init__(self, graph_embed_func, node_hidden_size): def __init__(self, graph_embed_func, node_hidden_size):
super(AddNode, self).__init__() super(AddNode, self).__init__()
self.graph_op = {'embed': graph_embed_func} self.graph_op = {'embed': graph_embed_func}
self.stop = 1 self.stop = 1
self.add_node = nn.Linear(graph_embed_func.graph_hidden_size, 1) self.add_node = nn.Linear(graph_embed_func.graph_hidden_size, 1)
        # If a node is to be added, initialize its hv
self.node_type_embed = nn.Embedding(1, node_hidden_size) self.node_type_embed = nn.Embedding(1, node_hidden_size)
self.initialize_hv = nn.Linear(node_hidden_size + \ self.initialize_hv = nn.Linear(node_hidden_size + \
graph_embed_func.graph_hidden_size, graph_embed_func.graph_hidden_size,
node_hidden_size) node_hidden_size)
self.init_node_activation = torch.zeros(1, 2 * node_hidden_size) self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)
def _initialize_node_repr(self, g, node_type, graph_embed): def _initialize_node_repr(self, g, node_type, graph_embed):
"""Whenver a node is added, initialize its representation.""" """Whenver a node is added, initialize its representation."""
num_nodes = g.number_of_nodes() num_nodes = g.number_of_nodes()
hv_init = self.initialize_hv( hv_init = self.initialize_hv(
torch.cat([ torch.cat([
self.node_type_embed(torch.LongTensor([node_type])), self.node_type_embed(torch.LongTensor([node_type])),
graph_embed], dim=1)) graph_embed], dim=1))
g.nodes[num_nodes - 1].data['hv'] = hv_init g.nodes[num_nodes - 1].data['hv'] = hv_init
g.nodes[num_nodes - 1].data['a'] = self.init_node_activation g.nodes[num_nodes - 1].data['a'] = self.init_node_activation
def prepare_training(self): def prepare_training(self):
self.log_prob = [] self.log_prob = []
def forward(self, g, action=None): def forward(self, g, action=None):
graph_embed = self.graph_op['embed'](g) graph_embed = self.graph_op['embed'](g)
logit = self.add_node(graph_embed) logit = self.add_node(graph_embed)
prob = torch.sigmoid(logit) prob = torch.sigmoid(logit)
if not self.training: if not self.training:
action = Bernoulli(prob).sample().item() action = Bernoulli(prob).sample().item()
stop = bool(action == self.stop) stop = bool(action == self.stop)
if not stop: if not stop:
g.add_nodes(1) g.add_nodes(1)
self._initialize_node_repr(g, action, graph_embed) self._initialize_node_repr(g, action, graph_embed)
if self.training: if self.training:
sample_log_prob = bernoulli_action_log_prob(logit, action) sample_log_prob = bernoulli_action_log_prob(logit, action)
self.log_prob.append(sample_log_prob) self.log_prob.append(sample_log_prob)
return stop return stop
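# A quick illustration with an untrained module and illustrative sizes:
# sample one add-node decision on an empty graph.
add_node_demo = AddNode(GraphEmbed(node_hidden_size=16), node_hidden_size=16)
add_node_demo.eval()
g_demo2 = dgl.DGLGraph()
stop_demo = add_node_demo(g_demo2)
print(stop_demo, g_demo2.number_of_nodes())  # e.g. False 1 (or True 0)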
####################################################################################### #######################################################################################
# Action 2: add edges # Action 2: add edges
# '''''''''''''''''''''''''' # ''''''''''''''''''''''''''
# #
# Given the graph embedding vector :math:`\textbf{h}_{G}` and the node # Given the graph embedding vector :math:`\textbf{h}_{G}` and the node
# embedding vector :math:`\textbf{h}_{v}` for the latest node :math:`v`, # embedding vector :math:`\textbf{h}_{v}` for the latest node :math:`v`,
# we evaluate # we evaluate
# #
# .. math:: # .. math::
# #
# \text{Sigmoid}(\textbf{W}_{\text{add edge}}\text{concat}([\textbf{h}_{G}, \textbf{h}_{v}])+b_{\text{add edge}}),\\ # \text{Sigmoid}(\textbf{W}_{\text{add edge}}\text{concat}([\textbf{h}_{G}, \textbf{h}_{v}])+b_{\text{add edge}}),\\
# #
# which is then used to parametrize a Bernoulli distribution for deciding # which is then used to parametrize a Bernoulli distribution for deciding
# whether to add a new edge starting from :math:`v`. # whether to add a new edge starting from :math:`v`.
# #
class AddEdge(nn.Module): class AddEdge(nn.Module):
def __init__(self, graph_embed_func, node_hidden_size): def __init__(self, graph_embed_func, node_hidden_size):
super(AddEdge, self).__init__() super(AddEdge, self).__init__()
self.graph_op = {'embed': graph_embed_func} self.graph_op = {'embed': graph_embed_func}
self.add_edge = nn.Linear(graph_embed_func.graph_hidden_size + \ self.add_edge = nn.Linear(graph_embed_func.graph_hidden_size + \
node_hidden_size, 1) node_hidden_size, 1)
def prepare_training(self): def prepare_training(self):
self.log_prob = [] self.log_prob = []
def forward(self, g, action=None): def forward(self, g, action=None):
graph_embed = self.graph_op['embed'](g) graph_embed = self.graph_op['embed'](g)
src_embed = g.nodes[g.number_of_nodes() - 1].data['hv'] src_embed = g.nodes[g.number_of_nodes() - 1].data['hv']
logit = self.add_edge(torch.cat( logit = self.add_edge(torch.cat(
[graph_embed, src_embed], dim=1)) [graph_embed, src_embed], dim=1))
prob = torch.sigmoid(logit) prob = torch.sigmoid(logit)
if self.training: if self.training:
sample_log_prob = bernoulli_action_log_prob(logit, action) sample_log_prob = bernoulli_action_log_prob(logit, action)
self.log_prob.append(sample_log_prob) self.log_prob.append(sample_log_prob)
else: else:
action = Bernoulli(prob).sample().item() action = Bernoulli(prob).sample().item()
to_add_edge = bool(action == 0) to_add_edge = bool(action == 0)
return to_add_edge return to_add_edge
####################################################################################### #######################################################################################
# Action 3: choosing destination # Action 3: choosing destination
# ''''''''''''''''''''''''''''''''' # '''''''''''''''''''''''''''''''''
# #
# When action 2 returns True, we need to choose a destination for the # When action 2 returns True, we need to choose a destination for the
# latest node :math:`v`. # latest node :math:`v`.
# #
# For each possible destination :math:`u\in\{0, \cdots, v-1\}`, the # For each possible destination :math:`u\in\{0, \cdots, v-1\}`, the
# probability of choosing it is given by # probability of choosing it is given by
# #
# .. math:: # .. math::
# #
# \frac{\text{exp}(\textbf{W}_{\text{dest}}\text{concat}([\textbf{h}_{u}, \textbf{h}_{v}])+\textbf{b}_{\text{dest}})}{\sum_{i=0}^{v-1}\text{exp}(\textbf{W}_{\text{dest}}\text{concat}([\textbf{h}_{i}, \textbf{h}_{v}])+\textbf{b}_{\text{dest}})}\\ # \frac{\text{exp}(\textbf{W}_{\text{dest}}\text{concat}([\textbf{h}_{u}, \textbf{h}_{v}])+\textbf{b}_{\text{dest}})}{\sum_{i=0}^{v-1}\text{exp}(\textbf{W}_{\text{dest}}\text{concat}([\textbf{h}_{i}, \textbf{h}_{v}])+\textbf{b}_{\text{dest}})}\\
# #
from torch.distributions import Categorical from torch.distributions import Categorical
class ChooseDestAndUpdate(nn.Module): class ChooseDestAndUpdate(nn.Module):
def __init__(self, graph_prop_func, node_hidden_size): def __init__(self, graph_prop_func, node_hidden_size):
super(ChooseDestAndUpdate, self).__init__() super(ChooseDestAndUpdate, self).__init__()
self.graph_op = {'prop': graph_prop_func} self.graph_op = {'prop': graph_prop_func}
self.choose_dest = nn.Linear(2 * node_hidden_size, 1) self.choose_dest = nn.Linear(2 * node_hidden_size, 1)
def _initialize_edge_repr(self, g, src_list, dest_list): def _initialize_edge_repr(self, g, src_list, dest_list):
        # For untyped edges, we only store a constant 1 to indicate their existence.
        # For multiple edge types, we could use a one-hot representation
        # or an embedding module.
edge_repr = torch.ones(len(src_list), 1) edge_repr = torch.ones(len(src_list), 1)
g.edges[src_list, dest_list].data['he'] = edge_repr g.edges[src_list, dest_list].data['he'] = edge_repr
def prepare_training(self): def prepare_training(self):
self.log_prob = [] self.log_prob = []
def forward(self, g, dest): def forward(self, g, dest):
src = g.number_of_nodes() - 1 src = g.number_of_nodes() - 1
possible_dests = range(src) possible_dests = range(src)
src_embed_expand = g.nodes[src].data['hv'].expand(src, -1) src_embed_expand = g.nodes[src].data['hv'].expand(src, -1)
possible_dests_embed = g.nodes[possible_dests].data['hv'] possible_dests_embed = g.nodes[possible_dests].data['hv']
dests_scores = self.choose_dest( dests_scores = self.choose_dest(
torch.cat([possible_dests_embed, torch.cat([possible_dests_embed,
src_embed_expand], dim=1)).view(1, -1) src_embed_expand], dim=1)).view(1, -1)
dests_probs = F.softmax(dests_scores, dim=1) dests_probs = F.softmax(dests_scores, dim=1)
if not self.training: if not self.training:
dest = Categorical(dests_probs).sample().item() dest = Categorical(dests_probs).sample().item()
if not g.has_edge_between(src, dest): if not g.has_edge_between(src, dest):
# For undirected graphs, we add edges for both directions # For undirected graphs, we add edges for both directions
# so that we can perform graph propagation. # so that we can perform graph propagation.
src_list = [src, dest] src_list = [src, dest]
dest_list = [dest, src] dest_list = [dest, src]
g.add_edges(src_list, dest_list) g.add_edges(src_list, dest_list)
self._initialize_edge_repr(g, src_list, dest_list) self._initialize_edge_repr(g, src_list, dest_list)
self.graph_op['prop'](g) self.graph_op['prop'](g)
if self.training: if self.training:
if dests_probs.nelement() > 1: if dests_probs.nelement() > 1:
self.log_prob.append( self.log_prob.append(
F.log_softmax(dests_scores, dim=1)[:, dest: dest + 1]) F.log_softmax(dests_scores, dim=1)[:, dest: dest + 1])
####################################################################################### #######################################################################################
# Putting it together # Putting it together
# `````````````````````````` # ``````````````````````````
# #
# We are now ready to have a complete implementation of the model class. # We are now ready to have a complete implementation of the model class.
# #
class DGMG(DGMGSkeleton): class DGMG(DGMGSkeleton):
def __init__(self, v_max, node_hidden_size, def __init__(self, v_max, node_hidden_size,
num_prop_rounds): num_prop_rounds):
super(DGMG, self).__init__(v_max) super(DGMG, self).__init__(v_max)
# Graph embedding module # Graph embedding module
self.graph_embed = GraphEmbed(node_hidden_size) self.graph_embed = GraphEmbed(node_hidden_size)
# Graph propagation module # Graph propagation module
self.graph_prop = GraphProp(num_prop_rounds, self.graph_prop = GraphProp(num_prop_rounds,
node_hidden_size) node_hidden_size)
# Actions # Actions
self.add_node_agent = AddNode( self.add_node_agent = AddNode(
self.graph_embed, node_hidden_size) self.graph_embed, node_hidden_size)
self.add_edge_agent = AddEdge( self.add_edge_agent = AddEdge(
self.graph_embed, node_hidden_size) self.graph_embed, node_hidden_size)
self.choose_dest_agent = ChooseDestAndUpdate( self.choose_dest_agent = ChooseDestAndUpdate(
self.graph_prop, node_hidden_size) self.graph_prop, node_hidden_size)
# Forward functions # Forward functions
self.forward_train = partial(forward_train, self=self) self.forward_train = partial(forward_train, self=self)
self.forward_inference = partial(forward_inference, self=self) self.forward_inference = partial(forward_inference, self=self)
@property @property
def action_step(self): def action_step(self):
old_step_count = self.step_count old_step_count = self.step_count
self.step_count += 1 self.step_count += 1
return old_step_count return old_step_count
def prepare_for_train(self): def prepare_for_train(self):
self.step_count = 0 self.step_count = 0
self.add_node_agent.prepare_training() self.add_node_agent.prepare_training()
self.add_edge_agent.prepare_training() self.add_edge_agent.prepare_training()
self.choose_dest_agent.prepare_training() self.choose_dest_agent.prepare_training()
def add_node_and_update(self, a=None): def add_node_and_update(self, a=None):
"""Decide if to add a new node. """Decide if to add a new node.
If a new node should be added, update the graph.""" If a new node should be added, update the graph."""
return self.add_node_agent(self.g, a) return self.add_node_agent(self.g, a)
def add_edge_or_not(self, a=None): def add_edge_or_not(self, a=None):
"""Decide if a new edge should be added.""" """Decide if a new edge should be added."""
return self.add_edge_agent(self.g, a) return self.add_edge_agent(self.g, a)
def choose_dest_and_update(self, a=None): def choose_dest_and_update(self, a=None):
"""Choose destination and connect it to the latest node. """Choose destination and connect it to the latest node.
Add edges for both directions and update the graph.""" Add edges for both directions and update the graph."""
self.choose_dest_agent(self.g, a) self.choose_dest_agent(self.g, a)
def get_log_prob(self): def get_log_prob(self):
add_node_log_p = torch.cat(self.add_node_agent.log_prob).sum() add_node_log_p = torch.cat(self.add_node_agent.log_prob).sum()
add_edge_log_p = torch.cat(self.add_edge_agent.log_prob).sum() add_edge_log_p = torch.cat(self.add_edge_agent.log_prob).sum()
choose_dest_log_p = torch.cat(self.choose_dest_agent.log_prob).sum() choose_dest_log_p = torch.cat(self.choose_dest_agent.log_prob).sum()
return add_node_log_p + add_edge_log_p + choose_dest_log_p return add_node_log_p + add_edge_log_p + choose_dest_log_p
####################################################################################### #######################################################################################
# Below is an animation where a graph is generated on the fly # Below is an animation where a graph is generated on the fly
# after every 10 batches of training for the first 400 batches. One # after every 10 batches of training for the first 400 batches. One
# can see how our model improves over time and begins generating cycles. # can see how our model improves over time and begins generating cycles.
# #
# .. figure:: https://user-images.githubusercontent.com/19576924/48929291-60fe3880-ef22-11e8-832a-fbe56656559a.gif # .. figure:: https://user-images.githubusercontent.com/19576924/48929291-60fe3880-ef22-11e8-832a-fbe56656559a.gif
# :alt: # :alt:
# #
# For a generative model, we can evaluate its performance by checking the percentage
# of valid graphs among those it generates on the fly.
import torch.utils.model_zoo as model_zoo import torch.utils.model_zoo as model_zoo
# Download a pre-trained model state dict for generating cycles with 10-20 nodes. # Download a pre-trained model state dict for generating cycles with 10-20 nodes.
state_dict = model_zoo.load_url('https://s3.us-east-2.amazonaws.com/dgl.ai/model/dgmg_cycles-5a0c40be.pth') state_dict = model_zoo.load_url('https://s3.us-east-2.amazonaws.com/dgl.ai/model/dgmg_cycles-5a0c40be.pth')
model = DGMG(v_max=20, node_hidden_size=16, num_prop_rounds=2) model = DGMG(v_max=20, node_hidden_size=16, num_prop_rounds=2)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
model.eval() model.eval()
def is_valid(g): def is_valid(g):
# Check if g is a cycle having 10-20 nodes. # Check if g is a cycle having 10-20 nodes.
def _get_previous(i, v_max): def _get_previous(i, v_max):
if i == 0: if i == 0:
return v_max return v_max
else: else:
return i - 1 return i - 1
def _get_next(i, v_max): def _get_next(i, v_max):
if i == v_max: if i == v_max:
return 0 return 0
else: else:
return i + 1 return i + 1
size = g.number_of_nodes() size = g.number_of_nodes()
if size < 10 or size > 20: if size < 10 or size > 20:
return False return False
for node in range(size): for node in range(size):
neighbors = g.successors(node) neighbors = g.successors(node)
if len(neighbors) != 2: if len(neighbors) != 2:
return False return False
if _get_previous(node, size - 1) not in neighbors: if _get_previous(node, size - 1) not in neighbors:
return False return False
if _get_next(node, size - 1) not in neighbors: if _get_next(node, size - 1) not in neighbors:
return False return False
return True return True
num_valid = 0 num_valid = 0
for i in range(100): for i in range(100):
g = model() g = model()
num_valid += is_valid(g) num_valid += is_valid(g)
del model del model
print('Among 100 graphs generated, {}% are valid.'.format(num_valid)) print('Among 100 graphs generated, {}% are valid.'.format(num_valid))
####################################################################################### #######################################################################################
# For the complete implementation, see `dgl DGMG example # For the complete implementation, see `dgl DGMG example
# <https://github.com/jermainewang/dgl/tree/master/examples/pytorch/dgmg>`__. # <https://github.com/jermainewang/dgl/tree/master/examples/pytorch/dgmg>`__.
# #
# Batched Graph Generation # Batched Graph Generation
# --------------------------- # ---------------------------
# #
# Speeding up DGMG is hard since each graph can be generated with a # Speeding up DGMG is hard since each graph can be generated with a
# unique sequence of actions. One way to explore parallelism is to adopt # unique sequence of actions. One way to explore parallelism is to adopt
# asynchronous gradient descent with multiple processes. Each of them # asynchronous gradient descent with multiple processes. Each of them
# works on one graph at a time and the processes are loosely coordinated # works on one graph at a time and the processes are loosely coordinated
# by a parameter server. This is the approach the authors adopted, # by a parameter server. This is the approach the authors adopted,
# and one that we can use as well. # and one that we can use as well.
# #
# DGL explores parallelism in the message-passing framework, on top of # DGL explores parallelism in the message-passing framework, on top of
# the framework-provided tensor operations. The earlier part of this tutorial # the framework-provided tensor operations. The earlier part of this tutorial
# already does so in the message propagation and graph embedding phases, but # already does so in the message propagation and graph embedding phases, but
# only within one graph. For a batch of graphs, a for loop is then needed: # only within one graph. For a batch of graphs, a for loop is then needed:
# #
# :: # ::
# #
# for g in g_list: # for g in g_list:
# self.graph_prop(g) # self.graph_prop(g)
# #
# We can modify the code to work on a batch of graphs at once by replacing # We can modify the code to work on a batch of graphs at once by replacing
# these lines with the following. On the CPU of a Mac machine, we instantly # these lines with the following. On the CPU of a Mac machine, we instantly
# enjoy a 6-7x reduction in runtime for the graph propagation part. # enjoy a 6-7x reduction in runtime for the graph propagation part.
# #
# :: # ::
# #
# bg = dgl.batch(g_list) # bg = dgl.batch(g_list)
# self.graph_prop(bg) # self.graph_prop(bg)
# g_list = dgl.unbatch(bg) # g_list = dgl.unbatch(bg)
# #
# We have already used this trick of calling ``dgl.batch`` in the # We have already used this trick of calling ``dgl.batch`` in the
# `Tree-LSTM tutorial # `Tree-LSTM tutorial
# <http://docs.dgl.ai/tutorials/models/3_tree-lstm.html#sphx-glr-tutorials-models-3-tree-lstm-py>`__, # <http://docs.dgl.ai/tutorials/models/3_tree-lstm.html#sphx-glr-tutorials-models-3-tree-lstm-py>`__,
# and it is worth explaining once more why batching helps. # and it is worth explaining once more why batching helps.
# #
# By batching many small graphs, DGL internally maintains a large *container* # By batching many small graphs, DGL internally maintains a large *container*
# graph (``BatchedDGLGraph``) over which ``update_all`` propels message-passing # graph (``BatchedDGLGraph``) over which ``update_all`` propels message-passing
# on all the edges and nodes. # on all the edges and nodes.
# #
# With ``dgl.batch``, we merge :math:`g_1, \ldots, g_N` into a single giant # With ``dgl.batch``, we merge :math:`g_1, \ldots, g_N` into a single giant
# graph consisting of :math:`N` isolated small graphs. For example, if we # graph consisting of :math:`N` isolated small graphs. For example, if we
# have two graphs with adjacency matrices # have two graphs with adjacency matrices
# #
# :: # ::
# #
# [0, 1] # [0, 1]
# [1, 0] # [1, 0]
# #
# [0, 1, 0] # [0, 1, 0]
# [1, 0, 0] # [1, 0, 0]
# [0, 1, 0] # [0, 1, 0]
# #
# ``dgl.batch`` simply gives a graph whose adjacency matrix is # ``dgl.batch`` simply gives a graph whose adjacency matrix is
# #
# :: # ::
# #
# [0, 1, 0, 0, 0] # [0, 1, 0, 0, 0]
# [1, 0, 0, 0, 0] # [1, 0, 0, 0, 0]
# [0, 1, 0, 0, 0] # [0, 1, 0, 0, 0]
# [1, 0, 0, 0, 0] # [1, 0, 0, 0, 0]
# [0, 1, 0, 0, 0] # [0, 1, 0, 0, 0]
# #
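# As a quick sanity check (a minimal sketch; the graph-construction API and the
# orientation of the returned adjacency matrix may vary across DGL versions), we
# could verify this block-diagonal structure directly:
#
# ::
#
#     import dgl
#
#     g1 = dgl.DGLGraph()
#     g1.add_nodes(2)
#     g1.add_edges([0, 1], [1, 0])             # the 2-node graph above
#
#     g2 = dgl.DGLGraph()
#     g2.add_nodes(3)
#     g2.add_edges([0, 1, 2], [1, 0, 1])       # the 3-node graph above
#
#     bg = dgl.batch([g1, g2])
#     print(bg.number_of_nodes())              # 5 = 2 + 3
#     print(bg.adjacency_matrix().to_dense())  # block-diagonal 5 x 5 matrix
#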
# In DGL, the message function is defined on the edges, so after batching, a # In DGL, the message function is defined on the edges, so after batching, a
# single call to the edge user-defined function (UDF) processes the edges of all graphs at once. # single call to the edge user-defined function (UDF) processes the edges of all graphs at once.
# #
# The reduce UDF (i.e., ``dgmg_reduce``) works on nodes, and each node may # The reduce UDF (i.e., ``dgmg_reduce``) works on nodes, and each node may
# have a different number of incoming edges. Using *degree bucketing*, DGL # have a different number of incoming edges. Using *degree bucketing*, DGL
# internally groups nodes with the same in-degree and calls the reduce UDF once # internally groups nodes with the same in-degree and calls the reduce UDF once
# for each group. Thus, batching also reduces the number of calls to these UDFs. # for each group. Thus, batching also reduces the number of calls to these UDFs.
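#
# Schematically, a reduce UDF receives one *bucket* of nodes at a time, with
# their incoming messages stacked into ``nodes.mailbox``. A sketch (not the
# actual ``dgmg_reduce`` implementation) looks like this:
#
# ::
#
#     def my_reduce(nodes):
#         # nodes.mailbox['m'] has shape (num_nodes_in_bucket, in_degree, feature_dim),
#         # so a single call handles every node that shares the same in-degree.
#         return {'a': nodes.mailbox['m'].sum(dim=1)}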
# #
# Modifying the node/edge features of a ``BatchedDGLGraph`` object does not # Modifying the node/edge features of a ``BatchedDGLGraph`` object does not
# affect the features of the original small graphs, so we need to replace the # affect the features of the original small graphs, so we need to replace the
# old graph list with the unbatched result # old graph list with the unbatched result
# ``g_list = dgl.unbatch(bg)``. # ``g_list = dgl.unbatch(bg)``.
# #
# The complete code for the batched version can also be found in the example. # The complete code for the batched version can also be found in the example.
# On our testbed, we get a roughly 2x speedup compared to the previous implementation. # On our testbed, we get a roughly 2x speedup compared to the previous implementation.
# #
.. _tutorials3-index:
Generative models
------------------------------
* **DGMG** `[paper] <https://arxiv.org/abs/1803.03324>`__ `[tutorial] <models/5_dgmg.html>`__
`[code] <https://github.com/jermainewang/dgl/tree/master/examples/pytorch/dgmg>`__:
this model belongs to the important family that deals with structural
generation. DGMG is interesting because its state-machine approach is the most
general. It is also very challenging because, unlike Tree-LSTM, every sample
has a dynamic, probability-driven structure that is not available before
training. We are able to progressively leverage intra- and inter-graph
parallelism to steadily improve the performance.
* **JTNN** `[paper] <https://arxiv.org/abs/1802.04364>`__ `[code (wip)]`: unlike DGMG, this
paper generates molecular graphs using the framework of a variational
auto-encoder. Perhaps more interesting is its approach to building structures
hierarchically; in the case of molecules, the junction tree serves as the
intermediate scaffolding.
...@@ -4,62 +4,62 @@ ...@@ -4,62 +4,62 @@
Capsule Network Tutorial Capsule Network Tutorial
=========================== ===========================
**Author**: Jinjing Zhou, `Jake **Author**: Jinjing Zhou, `Jake Zhao <https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang, Jinyang Li
Zhao <https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang
It is perhaps a little surprising that some of the more classical models can It is perhaps a little surprising that some of the more classical models
also be described in terms of graphs, offering a different perspective. can also be described in terms of graphs, offering a different
This tutorial describes how this is done for the `capsule network <http://arxiv.org/abs/1710.09829>`__. perspective. This tutorial describes how this can be done for the
`capsule network <http://arxiv.org/abs/1710.09829>`__.
""" """
####################################################################################### #######################################################################################
# Key ideas of Capsule # Key ideas of Capsule
# -------------------- # --------------------
# #
# There are two key ideas that the Capsule model offers. # The Capsule model offers two key ideas.
# #
# **Richer representations** In classic convolutional network, a scalar # **Richer representation** In classic convolutional networks, a scalar
# value represents the activation of a given feature. Instead, a capsule # value represents the activation of a given feature. By contrast, a
# outputs a vector, whose norm represents the probability of a feature, # capsule outputs a vector. The vector's length represents the probability
# and the orientation its properties. # of a feature being present. The vector's orientation represents the
# various properties of the feature (such as pose, deformation, texture,
# etc.).
# #
# .. figure:: https://i.imgur.com/55Ovkdh.png # |image0|
# :alt:
# #
# **Dynamic routing** To generalize max-pooling, there is another # **Dynamic routing** The output of a capsule is preferentially sent to
# interesting proposed by the authors, as a representational more powerful # certain parents in the layer above based on how well the capsule's
# way to construct higher level feature from its low levels. Consider a # prediction agrees with that of a parent. Such dynamic
# capsule :math:`u_i`. The way :math:`u_i` is integrated to the next level # "routing-by-agreement" generalizes the static routing of max-pooling.
# capsules take two steps:
# #
# 1. :math:`u_i` projects differently to different higher level capsules # During training, routing is done iteratively; each iteration adjusts
# via a linear transformation: :math:`\hat{u}_{j|i} = W_{ij}u_i`. # "routing weights" between capsules based on their observed agreements,
# 2. :math:`\hat{u}_{j|i}` routes to the higher level capsules by # in a manner similar to a k-means algorithm or `competitive
# spreading itself with a weighted sum, and the weight is dynamically # learning <https://en.wikipedia.org/wiki/Competitive_learning>`__.
# determined by iteratively modify the and checking against the
# "consistency" between :math:`\hat{u}_{j|i}` and :math:`v_j`, for any
# :math:`v_j`. Note that this is similar to a k-means algorithm or
# `competive
# learning <https://en.wikipedia.org/wiki/Competitive_learning>`__ in
# spirit. At the end of iterations, :math:`v_j` now integrates the
# lower level capsules.
# #
# The full algorithm is the following: |image0| # In this tutorial, we show how Capsule's dynamic routing algorithm can be
# # naturally expressed as a graph algorithm. Our implementation is adapted
# The dynamic routing step can be naturally expressed as a graph # from `Cedric
# algorithm. This is the focus of this tutorial. Our implementation is
# adapted from `Cedric
# Chee <https://github.com/cedrickchee/capsule-net-pytorch>`__, replacing # Chee <https://github.com/cedrickchee/capsule-net-pytorch>`__, replacing
# only the routing layer, and achieving similar speed and accuracy. # only the routing layer. Our version achieves similar speed and accuracy.
# #
# Model Implementation # Model Implementation
# ----------------------------------- # ----------------------
# Step 1: Setup and Graph Initialiation # Step 1: Setup and Graph Initialization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The connectivity between two layers of capsules forms a directed,
# bipartite graph, as shown in the figure below.
#
# |image1|
# #
# The below figure shows the directed bipartitie graph built for capsules # Each node :math:`j` is associated with feature :math:`v_j`,
# network. We denote :math:`b_{ij}`, :math:`\hat{u}_{j|i}` as edge # representing its capsule’s output. Each edge is associated with
# features and :math:`v_j` as node features. |image1| # features :math:`b_{ij}` and :math:`\hat{u}_{j|i}`. :math:`b_{ij}`
# determines routing weights, and :math:`\hat{u}_{j|i}` represents the
# prediction of capsule :math:`i` for :math:`j`.
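#
# Conceptually, this is the complete bipartite graph from the ``in_nodes``
# input capsules to the ``out_nodes`` output capsules. A minimal sketch of
# building just that structure (API details vary across DGL versions, and the
# ``init_graph`` helper below also initializes the node and edge features)
# could be:
#
# ::
#
#     g = dgl.DGLGraph()
#     g.add_nodes(in_nodes + out_nodes)
#     # in this sketch, input capsules are nodes 0..in_nodes-1 and output capsules follow
#     src = [i for i in range(in_nodes) for _ in range(out_nodes)]
#     dst = [in_nodes + j for _ in range(in_nodes) for j in range(out_nodes)]
#     g.add_edges(src, dst)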
# #
# Here's how we set up the graph and initialize node and edge features.
import torch.nn as nn import torch.nn as nn
import torch as th import torch as th
import torch.nn.functional as F import torch.nn.functional as F
...@@ -88,32 +88,33 @@ def init_graph(in_nodes, out_nodes, f_size): ...@@ -88,32 +88,33 @@ def init_graph(in_nodes, out_nodes, f_size):
######################################################################################### #########################################################################################
# Step 2: Define message passing functions # Step 2: Define message passing functions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Recall the following steps, and they are implemented in the class
# ``DGLRoutingLayer`` as the followings:
# #
# 1. Normalize over out edges # This is the pseudo code for Capsule's routing algorithm as given in the
# paper:
#
# |image2|
#
# We implement lines 4-7 of the pseudo code in the class ``DGLRoutingLayer``
# as the following steps (a tensor-level sketch of one full iteration is given
# right after the list):
#
# 1. Calculate coupling coefficients:
# #
# - Softmax over all out-edge of in-capsules # - Coefficients are the softmax over all out-edges of each in-capsule:
# :math:`\textbf{c}_i = \text{softmax}(\textbf{b}_i)`. # :math:`c_{ij} = \text{softmax}_j(b_{ij})`.
# #
# 2. Weighted sum over all in-capsules # 2. Calculate weighted sum over all in-capsules:
# #
# - Out-capsules equals weighted sum of in-capsules # - Output of a capsule is equal to the weighted sum of its in-capsules
# :math:`s_j=\sum_i c_{ij}\hat{u}_{j|i}` # :math:`s_j=\sum_i c_{ij}\hat{u}_{j|i}`
# #
# 3. Squash Operation # 3. Squash outputs:
# #
# - Squashing function is to ensure that short capsule vectors get # - Squash the length of a capsule's output vector to range (0,1), so it can represent the probability (of some feature being present).
# shrunk to almost zero length while the long capsule vectors get
# shrunk to a length slightly below 1. Its norm is expected to
# represents probabilities at some levels.
# - :math:`v_j=\text{squash}(s_j)=\frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}` # - :math:`v_j=\text{squash}(s_j)=\frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}`
# #
# 4. Update weights by agreement # 4. Update weights by the amount of agreement:
# #
# - :math:`\hat{u}_{j|i}\cdot v_j` can be considered as agreement # - The scalar product :math:`\hat{u}_{j|i}\cdot v_j` can be considered as how well capsule :math:`i` agrees with :math:`j`. It is used to update
# between current capsule and updated capsule,
# :math:`b_{ij}=b_{ij}+\hat{u}_{j|i}\cdot v_j` # :math:`b_{ij}=b_{ij}+\hat{u}_{j|i}\cdot v_j`
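#
# As a tensor-level sketch of one full routing iteration (independent of the
# graph-based implementation in ``DGLRoutingLayer`` below), assuming ``b`` has
# shape ``(in_nodes, out_nodes)`` and ``u_hat`` has shape
# ``(in_nodes, out_nodes, f_size)``:
#
# ::
#
#     c = F.softmax(b, dim=1)                        # step 1: coupling coefficients
#     s = (c.unsqueeze(-1) * u_hat).sum(dim=0)       # step 2: weighted sum
#     sq = (s ** 2).sum(dim=-1, keepdim=True)
#     v = sq / (1 + sq) * s / th.sqrt(sq)            # step 3: squash
#     b = b + (u_hat * v.unsqueeze(0)).sum(dim=-1)   # step 4: agreement update
#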
class DGLRoutingLayer(nn.Module): class DGLRoutingLayer(nn.Module):
def __init__(self, in_nodes, out_nodes, f_size): def __init__(self, in_nodes, out_nodes, f_size):
super(DGLRoutingLayer, self).__init__() super(DGLRoutingLayer, self).__init__()
...@@ -172,9 +173,9 @@ u_hat = th.randn(in_nodes * out_nodes, f_size) ...@@ -172,9 +173,9 @@ u_hat = th.randn(in_nodes * out_nodes, f_size)
routing = DGLRoutingLayer(in_nodes, out_nodes, f_size) routing = DGLRoutingLayer(in_nodes, out_nodes, f_size)
############################################################################################################ ############################################################################################################
# We can visualize the behavior by monitoring the entropy of outgoing # We can visualize a capsule network's behavior by monitoring the entropy
# weights, they should start high and then drop, as the assignment # of coupling coefficients. They should start high and then drop, as the
# gradually concentrate: # weights gradually concentrate on fewer edges:
entropy_list = [] entropy_list = []
dist_list = [] dist_list = []
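# As a generic sketch (with assumed tensor names, not the tutorial's exact
# code), the entropy of each in-capsule's coupling distribution can be computed
# from routing logits ``b`` of shape ``(in_nodes, out_nodes)`` as:
#
# ::
#
#     c = F.softmax(b, dim=1)
#     entropy = -(c * th.log(c)).sum(dim=1)   # one entropy value per in-capsule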
...@@ -193,9 +194,9 @@ plt.xlabel("Number of Routing") ...@@ -193,9 +194,9 @@ plt.xlabel("Number of Routing")
plt.xticks(np.arange(len(entropy_list))) plt.xticks(np.arange(len(entropy_list)))
plt.close() plt.close()
############################################################################################################ ############################################################################################################
# |image3|
# #
# .. figure:: https://i.imgur.com/dMvu7p3.png # Alternatively, we can also watch the evolution of histograms:
# :alt:
import seaborn as sns import seaborn as sns
import matplotlib.animation as animation import matplotlib.animation as animation
...@@ -216,8 +217,10 @@ ani = animation.FuncAnimation(fig, dist_animate, frames=len(entropy_list), inter ...@@ -216,8 +217,10 @@ ani = animation.FuncAnimation(fig, dist_animate, frames=len(entropy_list), inter
plt.close() plt.close()
############################################################################################################ ############################################################################################################
# Alternatively, we can also watch the evolution of histograms: |image2| # |image4|
# Or monitor the how lower level capcules gradually attach to one of the higher level ones: #
# Or monitor how the lower level capsules gradually attach to one of the
# higher level ones:
import networkx as nx import networkx as nx
from networkx.algorithms import bipartite from networkx.algorithms import bipartite
...@@ -251,14 +254,16 @@ ani2 = animation.FuncAnimation(fig2, weight_animate, frames=len(dist_list), inte ...@@ -251,14 +254,16 @@ ani2 = animation.FuncAnimation(fig2, weight_animate, frames=len(dist_list), inte
plt.close() plt.close()
############################################################################################################ ############################################################################################################
# |image3| # |image5|
# #
# The full code of this visulization is provided at # The full code of this visualization is provided at
# `link <https://github.com/jermainewang/dgl/blob/master/examples/pytorch/capsule/simple_routing.py>`__; the complete # `link <https://github.com/jermainewang/dgl/blob/master/examples/pytorch/capsule/simple_routing.py>`__; the complete
# code that trains on MNIST is at `link <https://github.com/jermainewang/dgl/tree/tutorial/examples/pytorch/capsule>`__. # code that trains on MNIST is at `link <https://github.com/jermainewang/dgl/tree/tutorial/examples/pytorch/capsule>`__.
# #
# .. |image0| image:: https://i.imgur.com/mv1W9Rv.png # .. |image0| image:: https://i.imgur.com/55Ovkdh.png
# .. |image1| image:: https://i.imgur.com/9tc6GLl.png # .. |image1| image:: https://i.imgur.com/9tc6GLl.png
# .. |image2| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_dist.gif # .. |image2| image:: https://i.imgur.com/mv1W9Rv.png
# .. |image3| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_vis.gif # .. |image3| image:: https://i.imgur.com/dMvu7p3.png
# # .. |image4| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_dist.gif
# .. |image5| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_vis.gif
.. _tutorials4-index:
Old (new) wines in new bottle
-----------------------------
* **Capsule** `[paper] <https://arxiv.org/abs/1710.09829>`__ `[tutorial] <models/2_capsule.html>`__
`[code] <https://github.com/jermainewang/dgl/tree/master/examples/pytorch/capsule>`__: this new
computer vision model has two key ideas -- enhancing the feature representation
in a vector form (instead of a scalar) called *capsule*, and replacing
maxpooling with dynamic routing. The idea of dynamic routing is to integrate a
lower level capsule into one (or several) of the higher level ones with
non-parametric message-passing. We show how the latter can be nicely implemented
with DGL APIs.
* **Transformer** `[paper] <https://arxiv.org/abs/1706.03762>`__ `[tutorial (wip)]` `[code (wip)]` and
**Universal Transformer** `[paper] <https://arxiv.org/abs/1807.03819>`__ `[tutorial (wip)]`
`[code (wip)]`: these
two models replace RNN with several layers of multi-head attention to encode
and discover structures among tokens of a sentence. These attention mechanisms
can similarly be formulated as graph operations with message-passing.
Graph-based Neural Network Models Graph-based Neural Network Models
================================= =================================
We developed DGL with a broad range of applications in mind. Building We developed DGL with a broad range of applications in mind. Building
state-of-the-art models forces us to think hard about the most common and useful APIs, state-of-the-art models forces us to think hard about the most common and useful APIs,
learn the hard lessons, and push the system design. learn the hard lessons, and push the system design.
We have prototyped altogether 10 different models, all of which are ready to run We have prototyped altogether 10 different models, all of which are ready to run
out of the box, and some of them are very new graph-based algorithms. In most of the out of the box, and some of them are very new graph-based algorithms. In most of the
cases, they demonstrate the performance, flexibility, and expressiveness of cases, they demonstrate the performance, flexibility, and expressiveness of
DGL. Where we still fall short, these exercises point to future DGL. Where we still fall short, these exercises point to future
directions. directions.
We categorize the models below, providing links to the original code and We categorize the models below, providing links to the original code and
tutorial when appropriate. As will become apparent, these models stress the use tutorial when appropriate. As will become apparent, these models stress the use
of different DGL APIs. of different DGL APIs.
Graph Neural Network and its variant
------------------------------------
* **GCN** `[paper] <https://arxiv.org/abs/1609.02907>`__ `[tutorial] <models/1_gcn.html>`__
`[code] <https://github.com/jermainewang/dgl/blob/master/examples/pytorch/gcn/gcn.py>`__:
this is the vanilla GCN. The tutorial covers the basic uses of DGL APIs.
* **GAT** `[paper] <https://arxiv.org/abs/1710.10903>`__
`[code] <https://github.com/jermainewang/dgl/blob/master/examples/pytorch/gat/gat.py>`__:
the key extension of GAT w.r.t. the vanilla GCN is deploying multi-head attention
over the neighborhood of a node, thus greatly enhancing the capacity and
expressiveness of the model.
* **R-GCN** `[paper] <https://arxiv.org/abs/1703.06103>`__ `[tutorial] <models/4_rgcn.html>`__
[code (wip)]: the key
difference of R-GCN is that it allows multiple edges between two entities of a graph, and
edges with distinct relationships are encoded differently. This is an
interesting extension of GCN that can have a lot of applications of its own.
* **LGNN** `[paper] <https://arxiv.org/abs/1705.08415>`__ `[tutorial (wip)]` `[code (wip)]`:
this model focuses on community detection by inspecting graph structures. It
uses representations of both the original graph and its line-graph companion. In
addition to demonstrating how an algorithm can harness multiple graphs, our
implementation shows how one can judiciously mix vanilla tensor operations,
sparse-matrix tensor operations, and message passing with DGL.
* **SSE** `[paper] <http://proceedings.mlr.press/v80/dai18a/dai18a.pdf>`__ `[tutorial (wip)]`
`[code] <https://github.com/jermainewang/dgl/blob/master/examples/mxnet/sse/sse_batch.py>`__:
the emphasis here is a *giant* graph that cannot fit comfortably on one GPU
card. SSE is an example that illustrates the co-design of both algorithm and
system: sampling to guarantee asymptotic convergence while lowering the
complexity, and batching across samples for maximum parallelism.
Dealing with many small graphs
------------------------------
* **Tree-LSTM** `[paper] <https://arxiv.org/abs/1503.00075>`__ `[tutorial] <models/3_tree-lstm.html>`__
`[code] <https://github.com/jermainewang/dgl/blob/master/examples/pytorch/tree_lstm/tree_lstm.py>`__:
sentences of natural languages have inherent structures, which are thrown away
by treating them simply as sequences. Tree-LSTM is a powerful model that learns
the representation by leveraging prior syntactic structures (e.g., parse trees).
The challenge in training it well is that simply padding a sentence to the
maximum length no longer works, since trees of different sentences have
different sizes and topologies. DGL solves this problem by throwing the trees
into a bigger "container" graph and using message passing to exploit maximum
parallelism. The key API we use is batching.
Generative models
------------------------------
* **DGMG** `[paper] <https://arxiv.org/abs/1803.03324>`__ `[tutorial] <models/5_dgmg.html>`__
`[code] <https://github.com/jermainewang/dgl/tree/master/examples/pytorch/dgmg>`__:
this model belongs to the important family that deals with structural
generation. DGMG is interesting because its state-machine approach is the most
general. It is also very challenging because, unlike Tree-LSTM, every sample
has a dynamic, probability-driven structure that is not available before
training. We are able to progressively leverage intra- and inter-graph
parallelism to steadily improve the performance.
* **JTNN** `[paper] <https://arxiv.org/abs/1802.04364>`__ `[code (wip)]`: unlike DGMG, this
paper generates molecular graphs using the framework of a variational
auto-encoder. Perhaps more interesting is its approach to building structures
hierarchically; in the case of molecules, the junction tree serves as the
intermediate scaffolding.
Old (new) wines in new bottle
-----------------------------
* **Capsule** `[paper] <https://arxiv.org/abs/1710.09829>`__ `[tutorial] <models/2_capsule.html>`__
`[code] <https://github.com/jermainewang/dgl/tree/master/examples/pytorch/capsule>`__: this new
computer vision model has two key ideas -- enhancing the feature representation
in a vector form (instead of a scalar) called *capsule*, and replacing
maxpooling with dynamic routing. The idea of dynamic routing is to integrate a
lower level capsule into one (or several) of the higher level ones with
non-parametric message-passing. We show how the latter can be nicely implemented
with DGL APIs.
* **Transformer** `[paper] <https://arxiv.org/abs/1706.03762>`__ `[tutorial (wip)]` `[code (wip)]` and
**Universal Transformer** `[paper] <https://arxiv.org/abs/1807.03819>`__ `[tutorial (wip)]`
`[code (wip)]`: these
two models replace RNN with several layers of multi-head attention to encode
and discover structures among tokens of a sentence. These attention mechanisms
can similarly be formulated as graph operations with message-passing.