Unverified Commit 90d86fcb authored by zhjwy9343, committed by GitHub

Master refactor split chapter2n3 (#2215)



* [Feature] Add full graph training with dgl built-in dataset.

* [Bug] fix model to cuda.

* [Feature] Add test loss and accuracy

* [Fix] Add random

* [Bug] Fix batch norm error

* [Doc] Test with CN in Sphinx

* [Doc] Remove the test CN docs.

* [Feature] Add input embedding layer

* [Doc] fill readme with new performance results

* [Doc] Add Chinese User Guide, graph and 1.5

* [Doc] Refactor and split chapter 4

* [Fix] Remove CompGCN example codes

* [Doc] Add chapter 2 refactor and split

* [Fix] code format of savenload

* [Doc] Split chapter 3

* [Doc] Add introduction phrase of chapter 2

* [Doc] Add introduction phrase of chapter 3

* Fix

* Update chapter 2

* Update chapter 3

* Update chapter 4

Co-authored-by: mufeili <mufeili1996@gmail.com>
parent 49e96970
.. _guide-data-pipeline-dataset:
4.1 DGLDataset class
--------------------
:class:`~dgl.data.DGLDataset` is the base class for processing, loading and saving
graph datasets defined in :ref:`apidata`. It implements the basic pipeline
for processing graph data. The following flow chart shows how the
pipeline works.
.. figure:: https://data.dgl.ai/asset/image/userguide_data_flow.png
    :align: center

    Flow chart for the graph data input pipeline defined in class DGLDataset.

To process a graph dataset located on a remote server or local disk, one can
define a class, say ``MyDataset``, inheriting from :class:`dgl.data.DGLDataset`. The
template of ``MyDataset`` is as follows.

.. code::

    from dgl.data import DGLDataset

    class MyDataset(DGLDataset):
        """ Template for customizing graph datasets in DGL.

        Parameters
        ----------
        url : str
            URL to download the raw dataset
        raw_dir : str
            Specifying the directory that will store the
            downloaded data or the directory that
            already stores the input data.
            Default: ~/.dgl/
        save_dir : str
            Directory to save the processed dataset.
            Default: the value of `raw_dir`
        force_reload : bool
            Whether to reload the dataset. Default: False
        verbose : bool
            Whether to print out progress information
        """
        def __init__(self,
                     url=None,
                     raw_dir=None,
                     save_dir=None,
                     force_reload=False,
                     verbose=False):
            super(MyDataset, self).__init__(name='dataset_name',
                                            url=url,
                                            raw_dir=raw_dir,
                                            save_dir=save_dir,
                                            force_reload=force_reload,
                                            verbose=verbose)

        def download(self):
            # download raw data to local disk
            pass

        def process(self):
            # process raw data to graphs, labels, splitting masks
            pass

        def __getitem__(self, idx):
            # get one example by index
            pass

        def __len__(self):
            # number of data examples
            pass

        def save(self):
            # save processed data to directory `self.save_path`
            pass

        def load(self):
            # load processed data from directory `self.save_path`
            pass

        def has_cache(self):
            # check whether there are processed data in `self.save_path`
            pass

The :class:`~dgl.data.DGLDataset` class has abstract functions ``process()``,
``__getitem__(idx)`` and ``__len__()`` that must be implemented in the
subclass. DGL recommends implementing saving and loading as well,
since they can save significant processing time for large datasets, and
there are several APIs making it easy (see :ref:`guide-data-pipeline-savenload`).

Note that the purpose of :class:`~dgl.data.DGLDataset` is to provide a standard and
convenient way to load graph data. One can store graphs, features,
labels, masks, and basic information about the dataset, such as the number of
classes, number of labels, etc. Operations such as sampling, partitioning
or feature normalization are done outside of the :class:`~dgl.data.DGLDataset`
subclass.
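As a quick illustration of the resulting interface, the sketch below
(assuming the ``MyDataset`` template above has been fully implemented;
the directory path is hypothetical) loads and indexes the dataset like
any other DGL dataset:

.. code::

    # hypothetical usage of the template above
    dataset = MyDataset(raw_dir='/tmp/my_raw_data')
    print(len(dataset))    # number of examples
    example = dataset[0]   # one example, e.g. a DGLGraph or a (graph, label) pair
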
The rest of this chapter shows the best practices to implement the
functions in the pipeline.
.. _guide-data-pipeline-download:
4.2 Download raw data (optional)
--------------------------------
If a dataset is already on local disk, make sure it is in directory
``raw_dir``. If one wants to run the code anywhere without bothering to
download and move the data to the right directory, one can do so
automatically by implementing the function ``download()``.

If the dataset is a zip file, make ``MyDataset`` inherit from the
:class:`dgl.data.DGLBuiltinDataset` class, which handles zip file extraction automatically. Otherwise,
one needs to implement ``download()`` as in :class:`~dgl.data.QM7bDataset`:
.. code::

    import os
    from dgl.data.utils import download

    def download(self):
        # path to store the file
        file_path = os.path.join(self.raw_dir, self.name + '.mat')
        # download file
        download(self.url, path=file_path)

The above code downloads a .mat file to directory ``self.raw_dir``. If
the file is a .gz, .tar, .tar.gz or .tgz file, use the :func:`~dgl.data.utils.extract_archive`
function to extract it. The following code shows how to download a .gz file
in :class:`~dgl.data.BitcoinOTCDataset`:
.. code::

    import os
    from dgl.data.utils import download, check_sha1

    def download(self):
        # path to store the file
        # make sure to use the same suffix as the original file name's
        gz_file_path = os.path.join(self.raw_dir, self.name + '.csv.gz')
        # download file
        download(self.url, path=gz_file_path)
        # check SHA-1
        if not check_sha1(gz_file_path, self._sha1_str):
            raise UserWarning('File {} is downloaded but the content hash does not match. '
                              'The repo may be outdated or the download may be incomplete. '
                              'Otherwise you can create an issue for it.'.format(self.name + '.csv.gz'))
        # extract file to directory `self.name` under `self.raw_dir`
        self._extract_gz(gz_file_path, self.raw_path)

The above code will extract the file into directory ``self.name`` under
``self.raw_dir``. If the class inherits from :class:`dgl.data.DGLBuiltinDataset`
to handle zip files, it will extract the file into directory ``self.name``
as well.

Optionally, one can check the SHA-1 string of the downloaded file as the
example above does, in case the file on the remote server changes some day.
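To obtain the expected SHA-1 string in the first place, one can hash a
known-good copy of the file once and hard-code the result. A minimal sketch
using Python's standard ``hashlib`` (the file name is hypothetical):

.. code::

    import hashlib

    def sha1_of_file(path, chunk_size=64 * 1024):
        # stream the file so large downloads need not fit in memory
        h = hashlib.sha1()
        with open(path, 'rb') as f:
            for chunk in iter(lambda: f.read(chunk_size), b''):
                h.update(chunk)
        return h.hexdigest()

    print(sha1_of_file('soc-sign-bitcoinotc.csv.gz'))
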
.. _guide-data-pipeline-loadogb:
4.5 Loading OGB datasets using ``ogb`` package
----------------------------------------------
`Open Graph Benchmark (OGB) <https://ogb.stanford.edu/docs/home/>`__ is
a collection of benchmark datasets. The official OGB package
`ogb <https://github.com/snap-stanford/ogb>`__ provides APIs for
downloading and processing OGB datasets into :class:`dgl.data.DGLGraph` objects. This section
introduces their basic usage.

First install the ogb package using pip:

.. code::

    pip install ogb

The following code shows how to load datasets for *Graph Property
Prediction* tasks.
.. code::

    # Load Graph Property Prediction datasets in OGB
    import dgl
    import torch
    from ogb.graphproppred import DglGraphPropPredDataset
    from torch.utils.data import DataLoader

    def _collate_fn(batch):
        # batch is a list of (graph, label) tuples
        graphs = [e[0] for e in batch]
        g = dgl.batch(graphs)
        labels = [e[1] for e in batch]
        labels = torch.stack(labels, 0)
        return g, labels

    # load dataset
    dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
    split_idx = dataset.get_idx_split()

    # dataloaders
    train_loader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=_collate_fn)
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=_collate_fn)

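With the loaders in place, each iteration yields a batched graph and a stacked
label tensor. The following sketch only inspects one mini-batch; no model is
assumed:

.. code::

    # peek at one mini-batch produced by _collate_fn above
    g, labels = next(iter(train_loader))
    print(g.batch_size)     # number of graphs merged by dgl.batch
    print(g.num_nodes())    # total number of nodes across the batch
    print(labels.shape)     # one label row per graph
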
Loading *Node Property Prediction* datasets is similar, but note that
there is only one graph object in this kind of dataset.
.. code::

    # Load Node Property Prediction datasets in OGB
    from ogb.nodeproppred import DglNodePropPredDataset

    dataset = DglNodePropPredDataset(name='ogbn-proteins')
    split_idx = dataset.get_idx_split()

    # there is only one graph in Node Property Prediction datasets
    g, labels = dataset[0]

    # get split labels
    train_label = dataset.labels[split_idx['train']]
    valid_label = dataset.labels[split_idx['valid']]
    test_label = dataset.labels[split_idx['test']]

*Link Property Prediction* datasets also contain one graph per dataset:
.. code::

    # Load Link Property Prediction datasets in OGB
    from ogb.linkproppred import DglLinkPropPredDataset

    dataset = DglLinkPropPredDataset(name='ogbl-ppa')
    split_edge = dataset.get_edge_split()

    graph = dataset[0]
    print(split_edge['train'].keys())
    print(split_edge['valid'].keys())
    print(split_edge['test'].keys())

.. _guide-data-pipeline-process:
4.3 Process data
----------------
One can implement the data processing code in the ``process()`` function, which
assumes that the raw data is already located in ``self.raw_dir``. There
are typically three types of tasks in machine learning on graphs: graph
classification, node classification, and link prediction. This section shows
how to process datasets related to these tasks.

This section focuses on the standard way to process graphs, features and masks.
It uses builtin datasets as examples and skips the implementations
for building graphs from files, but adds links to the detailed
implementations. Please refer to :ref:`guide-graph-external` for a
complete guide on how to build graphs from external sources.
Processing Graph Classification datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Graph classification datasets are almost the same as most datasets in
typical machine learning tasks, where mini-batch training is used. So one can
process the raw data into a list of :class:`dgl.DGLGraph` objects and a list of
label tensors. In addition, if the raw data has been split into
several files, one can add a parameter ``split`` to load a specific part of
the data.

Take :class:`~dgl.data.QM7bDataset` as an example:
.. code::

    from dgl.data import DGLDataset

    class QM7bDataset(DGLDataset):
        _url = 'http://deepchem.io.s3-website-us-west-1.amazonaws.com/' \
               'datasets/qm7b.mat'
        _sha1_str = '4102c744bb9d6fd7b40ac67a300e49cd87e28392'

        def __init__(self, raw_dir=None, force_reload=False, verbose=False):
            super(QM7bDataset, self).__init__(name='qm7b',
                                              url=self._url,
                                              raw_dir=raw_dir,
                                              force_reload=force_reload,
                                              verbose=verbose)

        def process(self):
            mat_path = self.raw_path + '.mat'
            # process data to a list of graphs and a list of labels
            self.graphs, self.label = self._load_graph(mat_path)

        def __getitem__(self, idx):
            """ Get graph and label by index

            Parameters
            ----------
            idx : int
                Item index

            Returns
            -------
            (dgl.DGLGraph, Tensor)
            """
            return self.graphs[idx], self.label[idx]

        def __len__(self):
            """Number of graphs in the dataset"""
            return len(self.graphs)

In ``process()``, the raw data is processed into a list of graphs and a
list of labels. One must implement ``__getitem__(idx)`` and ``__len__()``
for iteration. DGL recommends making ``__getitem__(idx)`` return a
tuple ``(graph, label)`` as above. Please check the `QM7bDataset source
code <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/qm7b.html#QM7bDataset>`__
for the details of ``self._load_graph()`` and ``__getitem__``.

One can also add properties to the class to indicate useful
information about the dataset. In :class:`~dgl.data.QM7bDataset`, one can add a property
``num_labels`` to indicate the total number of prediction tasks in this
multi-task dataset:
.. code::

    @property
    def num_labels(self):
        """Number of labels for each graph, i.e. number of prediction tasks."""
        return 14

After all this coding, one can finally use :class:`~dgl.data.QM7bDataset` as
follows:

.. code::

    import dgl
    import torch
    from torch.utils.data import DataLoader

    # load data
    dataset = QM7bDataset()
    num_labels = dataset.num_labels

    # create collate_fn
    def _collate_fn(batch):
        # batch is a list of (graph, label) tuples
        graphs, labels = map(list, zip(*batch))
        g = dgl.batch(graphs)
        labels = torch.stack(labels, 0)
        return g, labels

    # create dataloaders
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=_collate_fn)

    # training
    for epoch in range(100):
        for g, labels in dataloader:
            # your training code here
            pass

A complete guide for training graph classification models can be found
in :ref:`guide-training-graph-classification`.
For more examples of graph classification datasets, please refer to DGL's builtin graph classification
datasets:
* :ref:`gindataset`
* :ref:`minigcdataset`
* :ref:`qm7bdata`
* :ref:`tudata`
Processing Node Classification datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Different from graph classification, node classification is typically done on
a single graph. As such, splits of the dataset are on the nodes of the
graph. DGL recommends using node masks to specify the splits. This section uses the
builtin dataset `CitationGraphDataset <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/citation_graph.html#CitationGraphDataset>`__ as an example:
.. code::

    import dgl
    import torch

    from dgl.data import DGLBuiltinDataset
    from dgl.data.utils import _get_dgl_url, generate_mask_tensor

    class CitationGraphDataset(DGLBuiltinDataset):
        _urls = {
            'cora_v2' : 'dataset/cora_v2.zip',
            'citeseer' : 'dataset/citeseer.zip',
            'pubmed' : 'dataset/pubmed.zip',
        }

        def __init__(self, name, raw_dir=None, force_reload=False, verbose=True):
            assert name.lower() in ['cora', 'citeseer', 'pubmed']
            if name.lower() == 'cora':
                name = 'cora_v2'
            url = _get_dgl_url(self._urls[name])
            super(CitationGraphDataset, self).__init__(name,
                                                       url=url,
                                                       raw_dir=raw_dir,
                                                       force_reload=force_reload,
                                                       verbose=verbose)

        def process(self):
            # Skip some processing code
            # === data processing skipped ===
            # (graph, masks, labels, features and onehot_labels below
            #  come from the skipped code)

            # build graph
            g = dgl.graph(graph)

            # splitting masks
            g.ndata['train_mask'] = generate_mask_tensor(train_mask)
            g.ndata['val_mask'] = generate_mask_tensor(val_mask)
            g.ndata['test_mask'] = generate_mask_tensor(test_mask)

            # node labels
            g.ndata['label'] = torch.tensor(labels)

            # node features
            g.ndata['feat'] = torch.tensor(_preprocess_features(features),
                                           dtype=F.data_type_dict['float32'])
            self._num_labels = onehot_labels.shape[1]
            self._labels = labels
            self._g = g

        def __getitem__(self, idx):
            assert idx == 0, "This dataset has only one graph"
            return self._g

        def __len__(self):
            return 1

For brevity, this section skips some code in ``process()`` to highlight the key
part for processing node classification datasets: splitting masks. Node
features and node labels are stored in ``g.ndata``. For the detailed
implementation, please refer to the `CitationGraphDataset source
code <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/citation_graph.html#CitationGraphDataset>`__.

Note that the implementations of ``__getitem__(idx)`` and
``__len__()`` are changed as well, since there is often only one graph
for node classification tasks. The masks are bool tensors in PyTorch
and TensorFlow, and float tensors in MXNet.
This section uses a subclass of ``CitationGraphDataset``, :class:`dgl.data.CiteseerGraphDataset`,
to show its usage:

.. code::

    # load data
    dataset = CiteseerGraphDataset(raw_dir='')
    graph = dataset[0]

    # get split masks
    train_mask = graph.ndata['train_mask']
    val_mask = graph.ndata['val_mask']
    test_mask = graph.ndata['test_mask']

    # get node features
    feats = graph.ndata['feat']

    # get labels
    labels = graph.ndata['label']

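The masks then select rows when computing losses and metrics, so the whole
graph passes through the model while only the relevant split contributes. A
hedged sketch where random ``logits`` stand in for the output of a real node
classifier:

.. code::

    import torch
    import torch.nn.functional as F

    # `logits` is a placeholder for the output of some node classifier
    logits = torch.randn(graph.num_nodes(), dataset.num_classes)
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    val_acc = (logits[val_mask].argmax(1) == labels[val_mask]).float().mean()
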
A complete guide for training node classification models can be found in
:ref:`guide-training-node-classification`.
For more examples of node classification datasets, please refer to DGL's
builtin datasets:
* :ref:`citationdata`
* :ref:`corafulldata`
* :ref:`amazoncobuydata`
* :ref:`coauthordata`
* :ref:`karateclubdata`
* :ref:`ppidata`
* :ref:`redditdata`
* :ref:`sbmdata`
* :ref:`sstdata`
* :ref:`rdfdata`
Processing Link Prediction datasets
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The processing of link prediction datasets is similar to that of node
classification; there is often only one graph in the dataset.

This section uses the builtin dataset
`KnowledgeGraphDataset <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/knowledge_graph.html#KnowledgeGraphDataset>`__
as an example, and still skips the detailed data processing code to
highlight the key part for processing link prediction datasets:
.. code::

    # Example for creating Link Prediction datasets
    class KnowledgeGraphDataset(DGLBuiltinDataset):
        def __init__(self, name, reverse=True, raw_dir=None, force_reload=False, verbose=True):
            self._name = name
            self.reverse = reverse
            url = _get_dgl_url('dataset/') + '{}.tgz'.format(name)
            super(KnowledgeGraphDataset, self).__init__(name,
                                                        url=url,
                                                        raw_dir=raw_dir,
                                                        force_reload=force_reload,
                                                        verbose=verbose)

        def process(self):
            # Skip some processing code
            # === data processing skipped ===

            # splitting masks
            g.edata['train_mask'] = train_mask
            g.edata['val_mask'] = val_mask
            g.edata['test_mask'] = test_mask

            # edge types
            g.edata['etype'] = etype

            # node types
            g.ndata['ntype'] = ntype
            self._g = g

        def __getitem__(self, idx):
            assert idx == 0, "This dataset has only one graph"
            return self._g

        def __len__(self):
            return 1

As shown in the code, it adds the splitting masks into the ``edata`` field of the
graph. Check the `KnowledgeGraphDataset source
code <https://docs.dgl.ai/en/0.5.x/_modules/dgl/data/knowledge_graph.html#KnowledgeGraphDataset>`__
to see the complete code. The following code uses a subclass of ``KnowledgeGraphDataset``,
:class:`dgl.data.FB15k237Dataset`, to show its usage:
.. code::

    import torch
    from dgl.data import FB15k237Dataset

    # load data
    dataset = FB15k237Dataset()
    graph = dataset[0]

    # get training mask
    train_mask = graph.edata['train_mask']
    train_idx = torch.nonzero(train_mask).squeeze()
    # get the end nodes of the training edges
    src, dst = graph.find_edges(train_idx)

    # get edge types in the training set
    rel = graph.edata['etype'][train_idx]

A complete guide for training link prediction models can be found in
:ref:`guide-training-link-prediction`.
For more examples of link prediction datasets, please refer to DGL's
builtin datasets:
* :ref:`kgdata`
* :ref:`bitcoinotcdata`
.. _guide-data-pipeline-savenload:
4.4 Save and load data
----------------------
DGL recommends implementing saving and loading functions to cache the
processed data on local disk. This saves a lot of data processing time
in most cases. DGL provides four functions to make things simple:

- :func:`dgl.save_graphs` and :func:`dgl.load_graphs`: save/load DGLGraph objects and labels to/from local disk.
- :func:`dgl.data.utils.save_info` and :func:`dgl.data.utils.load_info`: save/load useful information of the dataset (a python ``dict`` object) to/from local disk.
The following example shows how to save and load a list of graphs and
dataset information.
.. code::

    import os
    from dgl import save_graphs, load_graphs
    from dgl.data.utils import makedirs, save_info, load_info

    def save(self):
        # save graphs and labels
        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
        save_graphs(graph_path, self.graphs, {'labels': self.labels})
        # save other information in python dict
        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
        save_info(info_path, {'num_classes': self.num_classes})

    def load(self):
        # load processed data from directory `self.save_path`
        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
        self.graphs, label_dict = load_graphs(graph_path)
        self.labels = label_dict['labels']
        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
        self.num_classes = load_info(info_path)['num_classes']

    def has_cache(self):
        # check whether there are processed data in `self.save_path`
        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
        info_path = os.path.join(self.save_path, self.mode + '_info.pkl')
        return os.path.exists(graph_path) and os.path.exists(info_path)

Note that there are cases when it is not suitable to save the processed data. For
example, in the builtin dataset :class:`~dgl.data.GDELTDataset`,
the processed data is quite large, so it is more effective to process
each data example in ``__getitem__(idx)``.
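In that situation, ``process()`` can keep only lightweight per-example metadata
and defer the expensive work. A hedged sketch of the idea
(``self._raw_examples`` and ``self._build_graph`` are hypothetical helpers, not
DGL APIs):

.. code::

    def __getitem__(self, idx):
        # build each example on demand instead of materializing
        # the whole processed dataset up front
        raw_example = self._raw_examples[idx]    # cheap metadata kept by process()
        g = self._build_graph(raw_example)       # hypothetical helper
        return g, self._labels[idx]
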
.. _guide-message-passing-api:
2.1 Built-in Functions and Message Passing APIs
-----------------------------------------------
In DGL, a **message function** takes a single argument ``edges``,
which is an :class:`~dgl.udf.EdgeBatch` instance. During message passing,
DGL generates it internally to represent a batch of edges. It has three
members ``src``, ``dst`` and ``data`` to access features of source nodes,
destination nodes, and edges, respectively.

A **reduce function** takes a single argument ``nodes``, which is a
:class:`~dgl.udf.NodeBatch` instance. During message passing,
DGL generates it internally to represent a batch of nodes. It has a member
``mailbox`` to access the messages received by the nodes in the batch.
Some of the most common reduce operations include ``sum``, ``max``, ``min``, etc.

An **update function** takes a single argument ``nodes`` as described above.
This function operates on the aggregation result from the reduce function, typically
combining it with a node's original feature at the last step and saving the result
as a node feature.
DGL has implemented commonly used message functions and reduce functions
as **built-in** in the namespace ``dgl.function``. In general, DGL
suggests using built-in functions **whenever possible** since they are
heavily optimized and automatically handle dimension broadcasting.

If your message passing functions cannot be implemented with built-ins,
you can implement user-defined message/reduce functions (aka **UDFs**).

Built-in message functions can be unary or binary. DGL supports ``copy``
for unary. For binary functions, DGL supports ``add``, ``sub``, ``mul``, ``div``
and ``dot``. The naming convention for message built-in functions is that ``u``
represents ``src`` nodes, ``v`` represents ``dst`` nodes, and ``e`` represents ``edges``.
The parameters for those functions are strings indicating the input and output field names for
the corresponding nodes and edges. The list of supported built-in functions
can be found in :ref:`api-built-in`. For example, to add the ``hu`` feature from src
nodes and the ``hv`` feature from dst nodes and then save the result on the edge
in the ``he`` field, one can use the built-in function ``dgl.function.u_add_v('hu', 'hv', 'he')``.
This is equivalent to the message UDF:
.. code::

    def message_func(edges):
        return {'he': edges.src['hu'] + edges.dst['hv']}

Built-in reduce functions support the operations ``sum``, ``max``, ``min``
and ``mean``. Reduce functions usually have two parameters, one for the
field name in ``mailbox`` and one for the field name in the node features; both
are strings. For example, ``dgl.function.sum('m', 'h')`` is equivalent
to the reduce UDF that sums up the messages ``m``:
.. code::

    import torch

    def reduce_func(nodes):
        return {'h': torch.sum(nodes.mailbox['m'], dim=1)}

It is also possible to invoke only edge-wise computation with :meth:`~dgl.DGLGraph.apply_edges`,
without invoking message passing. :meth:`~dgl.DGLGraph.apply_edges` takes a message function
as a parameter and by default updates the features of all edges. For example:
.. code::

    import dgl.function as fn

    graph.apply_edges(fn.u_add_v('el', 'er', 'e'))

For message passing, :meth:`~dgl.DGLGraph.update_all` is a high-level
API that merges message generation, message aggregation and node update
in a single call, which leaves room for optimization as a whole.

The parameters for :meth:`~dgl.DGLGraph.update_all` are a message function, a
reduce function and an update function. One can also call the update function
outside of ``update_all`` and not specify it when invoking :meth:`~dgl.DGLGraph.update_all`.
DGL recommends this approach, since the update function can usually be
written as pure tensor operations to make the code concise. For
example:
.. code::

    def update_all_example(graph):
        # store the result in graph.ndata['ft']
        graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                         fn.sum('m', 'ft'))
        # Call the update function outside of update_all
        final_ft = graph.ndata['ft'] * 2
        return final_ft

This call generates the messages ``m`` by multiplying the src node features
``ft`` and the edge features ``a``, sums up the messages ``m`` to update the node
features ``ft``, and finally multiplies ``ft`` by 2 to get the result
``final_ft``. After the call, DGL cleans up the intermediate messages ``m``.
The math formula for the above function is:

.. math::  {final\_ft}_i = 2 * \sum_{j\in\mathcal{N}(i)} ({ft}_j * a_{ij})
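To make the pieces above concrete, here is a small self-contained sketch that
runs ``update_all_example`` on a toy graph; the graph structure and feature
sizes are made up for illustration:

.. code::

    import dgl
    import dgl.function as fn
    import torch

    def update_all_example(graph):
        graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                         fn.sum('m', 'ft'))
        return graph.ndata['ft'] * 2

    # a toy graph with 3 nodes and 4 directed edges
    g = dgl.graph(([0, 0, 1, 2], [1, 2, 2, 0]))
    g.ndata['ft'] = torch.ones(3, 2)    # node features
    g.edata['a'] = torch.rand(4, 1)     # one scalar weight per edge

    final_ft = update_all_example(g)
    print(final_ft.shape)    # torch.Size([3, 2])
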
.. _guide-message-passing-edge:
2.4 Apply Edge Weight In Message Passing
----------------------------------------
A commonly seen practice in GNN modeling is to apply edge weights on the
messages before message aggregation, for example, in
`GAT <https://arxiv.org/pdf/1710.10903.pdf>`__ and some `GCN
variants <https://arxiv.org/abs/2004.00445>`__. In DGL, the way to
handle this is:

- Save the weight as an edge feature.
- Multiply the edge feature by the src node feature in the message function.

For example:
.. code::

    import dgl.function as fn

    graph.edata['a'] = affinity
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))

The example above uses ``affinity`` as the edge weight. The edge weight should
usually be a scalar.
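A minimal end-to-end version of this pattern, with a made-up graph and a
random ``affinity`` vector standing in for real edge weights:

.. code::

    import dgl
    import dgl.function as fn
    import torch

    g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    g.ndata['ft'] = torch.randn(3, 4)
    affinity = torch.rand(g.num_edges())    # one scalar weight per edge

    # unsqueeze so the scalar broadcasts over the feature dimension
    g.edata['a'] = affinity.unsqueeze(-1)
    g.update_all(fn.u_mul_e('ft', 'a', 'm'),
                 fn.sum('m', 'ft'))
    print(g.ndata['ft'].shape)    # torch.Size([3, 4])
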
.. _guide-message-passing-efficient:
2.2 Writing Efficient Message Passing Code
------------------------------------------
DGL optimizes memory consumption and computing speed for message
passing. The optimization includes:
- Merge multiple kernels into a single one: This is achieved by using
  :meth:`~dgl.DGLGraph.update_all` to call multiple built-in functions
  at once. (Speed optimization)
- Parallelism on nodes and edges: DGL abstracts edge-wise computation
  :meth:`~dgl.DGLGraph.apply_edges` as a generalized sampled dense-dense
  matrix multiplication (**gSDDMM**) operation and parallelizes the computing
  across edges. Likewise, DGL abstracts node-wise computation
  :meth:`~dgl.DGLGraph.update_all` as a generalized sparse-dense matrix
  multiplication (**gSPMM**) operation and parallelizes the computing across
  nodes. (Speed optimization)
- Avoid unnecessary memory copy from nodes to edges: To generate a
  message that requires the features of the source and destination nodes,
  one option is to copy the source and destination node features to
  that edge. For some graphs, the number of edges is much larger than
  the number of nodes. This copy can be costly. DGL's built-in message
  functions avoid this memory copy by indexing into the node features
  with entry indices. (Memory and speed optimization)
- Avoid materializing feature vectors on edges: the complete message
  passing process includes message generation, message aggregation and
  node update. In an :meth:`~dgl.DGLGraph.update_all` call, the message function
  and reduce function are merged into one kernel if those functions are
  built-in. There is no message materialization on edges. (Memory
  optimization)
According to the above, a common practice to leverage those
optimizations is to construct one's own message passing functionality as
a combination of :meth:`~dgl.DGLGraph.update_all` calls with built-in
functions as parameters.
For some cases like
:class:`~dgl.nn.pytorch.conv.GATConv`,
where it is necessary to save messages on the edges, one needs to call
:meth:`~dgl.DGLGraph.apply_edges` with built-in functions. Sometimes the
messages on the edges can be high dimensional, which is memory consuming.
DGL recommends keeping the dimension of edge features as low as possible.
Here's an example of how to achieve this by splitting operations on the
edges into operations on the nodes. The approach does the following: concatenate the ``src``
feature and ``dst`` feature, then apply a linear layer, i.e.
:math:`W\times (u || v)`. The ``src`` and ``dst`` feature dimension is
high, while the linear layer output dimension is low. A straightforward
implementation would be like:
.. code::

    import torch
    import torch.nn as nn

    linear = nn.Parameter(torch.FloatTensor(size=(1, node_feat_dim * 2)))

    def concat_message_function(edges):
        return {'cat_feat': torch.cat([edges.src['feat'], edges.dst['feat']], dim=-1)}

    g.apply_edges(concat_message_function)
    g.edata['out'] = g.edata['cat_feat'] * linear

The suggested implementation splits the linear operation into two,
one applies on ``src`` feature, the other applies on ``dst`` feature.
It then adds the output of the linear operations on the edges at the final stage,
i.e. performing :math:`W_l\times u + W_r \times v`. This is because
:math:`W \times (u||v) = W_l \times u + W_r \times v`, where :math:`W_l`
and :math:`W_r` are the left and the right half of the matrix :math:`W`,
respectively:
.. code::

    import dgl.function as fn
    import torch
    import torch.nn as nn

    linear_src = nn.Parameter(torch.FloatTensor(size=(1, node_feat_dim)))
    linear_dst = nn.Parameter(torch.FloatTensor(size=(1, node_feat_dim)))
    out_src = g.ndata['feat'] * linear_src
    out_dst = g.ndata['feat'] * linear_dst
    g.srcdata.update({'out_src': out_src})
    g.dstdata.update({'out_dst': out_dst})
    g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))

The above two implementations are mathematically equivalent. The latter
one is more efficient because it does not need to save ``feat_src`` and
``feat_dst`` on edges, which would be memory-inefficient. Plus, the addition can
be optimized with DGL's built-in function ``u_add_v``, which further
speeds up computation and reduces the memory footprint.
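One way to see the equivalence is to run both versions on a toy graph with
explicit matrix multiplication, which makes the identity
:math:`W \times (u||v) = W_l \times u + W_r \times v` exact (the snippets above
use broadcasted elementwise weights for brevity), and compare the outputs:

.. code::

    import dgl
    import dgl.function as fn
    import torch

    node_feat_dim, out_dim = 4, 2
    g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    g.ndata['feat'] = torch.randn(3, node_feat_dim)

    # a full weight matrix for (u || v) @ W, split into left/right halves
    W = torch.randn(2 * node_feat_dim, out_dim)
    W_left, W_right = W[:node_feat_dim], W[node_feat_dim:]

    # version 1: materialize the concatenated feature on every edge
    def concat_message_function(edges):
        return {'cat_feat': torch.cat([edges.src['feat'], edges.dst['feat']], dim=-1)}
    g.apply_edges(concat_message_function)
    out1 = g.edata['cat_feat'] @ W

    # version 2: transform on nodes first, add on edges
    g.ndata['out_src'] = g.ndata['feat'] @ W_left
    g.ndata['out_dst'] = g.ndata['feat'] @ W_right
    g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))
    out2 = g.edata['out']

    assert torch.allclose(out1, out2, atol=1e-6)
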
.. _guide-message-passing-heterograph:
2.5 Message Passing on Heterogeneous Graph
------------------------------------------
Heterogeneous graphs (:ref:`guide-graph-heterogeneous`), or
heterographs for short, are graphs that contain different types of nodes
and edges. The different types of nodes and edges tend to have different
types of attributes that are designed to capture the characteristics of
each node and edge type. Within the context of graph neural networks,
depending on their complexity, certain node and edge types might need to
be modeled with representations that have a different number of
dimensions.
The message passing on heterographs can be split into two parts:

1. Message computation and aggregation for each relation :math:`r`.
2. Reduction that merges the aggregation results from all relations for each node type.
DGL’s interface to call message passing on heterographs is
:meth:`~dgl.DGLGraph.multi_update_all`.
:meth:`~dgl.DGLGraph.multi_update_all` takes a dictionary containing
the parameters for :meth:`~dgl.DGLGraph.update_all` within each relation
using relation as the key, and a string representing the cross type reducer.
The reducer can be one of ``sum``, ``min``, ``max``, ``mean``, ``stack``.
Here’s an example:
.. code::

    import dgl.function as fn

    # inside the forward function of an NN module;
    # `feat_dict` maps each node type to its input features
    funcs = {}
    for c_etype in G.canonical_etypes:
        srctype, etype, dsttype = c_etype
        Wh = self.weight[etype](feat_dict[srctype])
        # Save it in graph for message passing
        G.nodes[srctype].data['Wh_%s' % etype] = Wh
        # Specify per-relation message passing functions: (message_func, reduce_func).
        # Note that the results are saved to the same destination feature 'h', which
        # hints the type-wise reducer for aggregation.
        funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
    # Trigger message passing of multiple types.
    G.multi_update_all(funcs, 'sum')
    # return the updated node feature dictionary
    return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}
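The snippet above is a fragment of a module's ``forward()``. For completeness,
here is a hedged sketch of a full relation-wise layer built around it (the
class and its names are illustrative, not a DGL API):

.. code::

    import dgl.function as fn
    import torch.nn as nn

    class RelGraphLayer(nn.Module):
        """Hypothetical per-relation layer wrapping the snippet above."""
        def __init__(self, in_feats, out_feats, etypes):
            super().__init__()
            # one linear transform per edge type
            self.weight = nn.ModuleDict(
                {etype: nn.Linear(in_feats, out_feats) for etype in etypes})

        def forward(self, G, feat_dict):
            funcs = {}
            for srctype, etype, dsttype in G.canonical_etypes:
                Wh = self.weight[etype](feat_dict[srctype])
                G.nodes[srctype].data['Wh_%s' % etype] = Wh
                funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
            G.multi_update_all(funcs, 'sum')
            return {ntype: G.nodes[ntype].data['h'] for ntype in G.ntypes}
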
.. _guide-message-passing-part:
2.3 Apply Message Passing On Part Of The Graph
----------------------------------------------
If one only wants to update part of the nodes in the graph, the practice
is to create a subgraph by providing the IDs of the nodes to
include in the update, then call :meth:`~dgl.DGLGraph.update_all` on the
subgraph. For example:

.. code::

    nid = [0, 2, 3, 6, 7, 9]
    sg = g.subgraph(nid)
    sg.update_all(message_func, reduce_func, apply_node_func)

This is a common usage in mini-batch training. Check :ref:`guide-minibatch` for more detailed
usages.
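Note that the subgraph holds its own copy of the node features, so results
computed on ``sg`` are not automatically visible on ``g``. A hedged sketch of
writing them back using the parent node IDs stored in ``sg.ndata[dgl.NID]``
(assuming the reduce function produced a feature field ``'h'``):

.. code::

    import dgl
    import torch

    # allocate the field on the full graph first (it only exists on sg)
    g.ndata['h'] = torch.zeros(g.num_nodes(), sg.ndata['h'].shape[1])
    orig_nids = sg.ndata[dgl.NID]            # subgraph -> parent node IDs
    g.ndata['h'][orig_nids] = sg.ndata['h']
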
.. _guide-message-passing:

Chapter 2: Message Passing
==========================

Message Passing Paradigm
------------------------

...with the features of its incident nodes; :math:`\psi` is an update function
defined on each node to update the node feature
by aggregating its incoming messages using the **reduce function**
:math:`\rho`.

Roadmap
-------

This chapter introduces DGL's message passing APIs, and how to efficiently use them on both nodes and edges.
The last section of it explains how to implement message passing on heterogeneous graphs.

* :ref:`guide-message-passing-api`
* :ref:`guide-message-passing-efficient`
* :ref:`guide-message-passing-part`
* :ref:`guide-message-passing-edge`
* :ref:`guide-message-passing-heterograph`

.. toctree::
    :maxdepth: 1
    :hidden:
    :glob:

    message-api
    message-efficient
    message-part
    message-edge
    message-heterograph
.. _guide-nn-construction:
3.1 DGL NN Module Construction Function
---------------------------------------
The construction function performs the following steps:
1. Set options.
2. Register learnable parameters or submodules.
3. Reset parameters.
.. code::

    import torch.nn as nn

    from dgl.utils import expand_as_pair

    class SAGEConv(nn.Module):
        def __init__(self,
                     in_feats,
                     out_feats,
                     aggregator_type,
                     bias=True,
                     norm=None,
                     activation=None):
            super(SAGEConv, self).__init__()
            self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
            self._out_feats = out_feats
            self._aggre_type = aggregator_type
            self.norm = norm
            self.activation = activation

In the construction function, one first needs to set the data dimensions. For
a general PyTorch module, the dimensions are usually the input dimension,
output dimension and hidden dimensions. For graph neural networks, the input
dimension can be split into a source node dimension and a destination node
dimension.
Besides data dimensions, a typical option for graph neural network is
aggregation type (``self._aggre_type``). Aggregation type determines how
messages on different edges are aggregated for a certain destination
node. Commonly used aggregation types include ``mean``, ``sum``,
``max``, ``min``. Some modules may apply more complicated aggregation
like an ``lstm``.
``norm`` here is a callable function for feature normalization. In the
SAGEConv paper, such normalization can be L2 normalization:
:math:`h_v = h_v / \lVert h_v \rVert_2`.
.. code::

    # aggregator type: mean, max_pool, lstm, gcn
    if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
        raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
    if aggregator_type == 'max_pool':
        self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
    if aggregator_type == 'lstm':
        self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
    if aggregator_type in ['mean', 'max_pool', 'lstm']:
        self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
    self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
    self.reset_parameters()

The next step is to register parameters and submodules. In SAGEConv, submodules vary
according to the aggregation type. Those modules are pure PyTorch NN
modules like ``nn.Linear``, ``nn.LSTM``, etc. At the end of the construction
function, weight initialization is applied by calling
``reset_parameters()``.
.. code::

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'max_pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

.. _guide-nn-forward:
3.2 DGL NN Module Forward Function
----------------------------------
In an NN module, the ``forward()`` function does the actual message passing and
computation. Compared with a PyTorch NN module, which usually takes
tensors as its parameters, a DGL NN module takes an additional parameter
:class:`dgl.DGLGraph`. The
workload of the ``forward()`` function can be split into three parts:

- Graph checking and graph type specification.
- Message passing.
- Feature update.

The rest of this section takes a deep dive into the ``forward()`` function of the SAGEConv example.
Graph checking and graph type specification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code::

    def forward(self, graph, feat):
        with graph.local_scope():
            # Specify graph type then expand input feature according to graph type
            feat_src, feat_dst = expand_as_pair(feat, graph)

``forward()`` needs to handle many corner cases on the input that can
lead to invalid values in computing and message passing. One typical check in conv modules
like :class:`~dgl.nn.pytorch.conv.GraphConv` is to verify that the input graph has no 0-in-degree nodes.
When a node has 0 in-degree, its ``mailbox`` will be empty and the reduce function will produce
all-zero values, which may cause a silent regression in model performance. However, in the
:class:`~dgl.nn.pytorch.conv.SAGEConv` module, the aggregated representation is concatenated
with the original node feature, so the output of ``forward()`` will not be all-zero and no
such check is needed.
A DGL NN module should be reusable across different types of graph input,
including homogeneous graphs, heterogeneous
graphs (:ref:`guide-graph-heterogeneous`), and subgraph
blocks (:ref:`guide-minibatch`).
The math formulas for SAGEConv are:

.. math::  h_{\mathcal{N}(dst)}^{(l+1)} = \mathrm{aggregate}
        \left(\{h_{src}^{l}, \forall src \in \mathcal{N}(dst) \}\right)

.. math::  h_{dst}^{(l+1)} = \sigma \left(W \cdot \mathrm{concat}
        (h_{dst}^{l}, h_{\mathcal{N}(dst)}^{(l+1)}) + b \right)

.. math::  h_{dst}^{(l+1)} = \mathrm{norm}(h_{dst}^{(l+1)})
One needs to specify the source node feature ``feat_src`` and destination
node feature ``feat_dst`` according to the graph type.
:meth:`~dgl.utils.expand_as_pair` is a function that specifies the graph
type and expands ``feat`` into ``feat_src`` and ``feat_dst``.
The detail of this function is shown below.
.. code::

    # ``F`` is DGL's backend (``dgl.backend``); ``Mapping`` comes from ``collections.abc``
    def expand_as_pair(input_, g=None):
        if isinstance(input_, tuple):
            # Bipartite graph case
            return input_
        elif g is not None and g.is_block:
            # Subgraph block case
            if isinstance(input_, Mapping):
                input_dst = {
                    k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                    for k, v in input_.items()}
            else:
                input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
            return input_, input_dst
        else:
            # Homogeneous graph case
            return input_, input_

For homogeneous whole graph training, source nodes and destination nodes
are the same: all the nodes in the graph.

For the heterogeneous case, the graph can be split into several bipartite
graphs, one for each relation. The relations are represented as
``(src_type, edge_type, dst_type)``. When it identifies that the input feature
``feat`` is a tuple, it will treat the graph as bipartite. The first
element in the tuple is the source node feature and the second
element is the destination node feature.

In mini-batch training, the computing is applied on a subgraph sampled
from a bunch of destination nodes. The subgraph is called a *block*
in DGL. After message passing, only those destination nodes
will be updated, since they have the same neighborhood as
in the original full graph. In the block creation phase,
``dst nodes`` are at the front of the node list, so one can find the
``feat_dst`` by the index ``[0:g.number_of_dst_nodes()]``.

After determining ``feat_src`` and ``feat_dst``, the computing for the
above three graph types is the same.
Message passing and reducing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code::

    import dgl.function as fn
    import torch.nn.functional as F
    from dgl.utils import check_eq_shape

    h_self = feat_dst
    if self._aggre_type == 'mean':
        graph.srcdata['h'] = feat_src
        graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
        h_neigh = graph.dstdata['neigh']
    elif self._aggre_type == 'gcn':
        check_eq_shape(feat)
        graph.srcdata['h'] = feat_src
        graph.dstdata['h'] = feat_dst     # same as above if homogeneous
        graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
        # divide in_degrees
        degs = graph.in_degrees().to(feat_dst)
        h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
    elif self._aggre_type == 'max_pool':
        graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
        graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
        h_neigh = graph.dstdata['neigh']
    else:
        # the lstm aggregator is omitted here for brevity
        raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

    # GraphSAGE GCN does not require fc_self.
    if self._aggre_type == 'gcn':
        rst = self.fc_neigh(h_neigh)
    else:
        rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

The code above performs the actual message passing and reducing. This part
of the code varies from module to module. Note that all the message passing in
the above code is implemented using the :meth:`~dgl.DGLGraph.update_all` API and
**built-in** message/reduce functions to fully utilize DGL's performance
optimizations, as described in :ref:`guide-message-passing-efficient`.
Update feature after reducing for output
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code::

    # activation
    if self.activation is not None:
        rst = self.activation(rst)
    # normalization
    if self.norm is not None:
        rst = self.norm(rst)
    return rst

The last part of the ``forward()`` function is to update the features after
the reduce function. Common update operations are applying the
activation function and normalization according to the options set in the
object construction phase.
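Once the three parts are in place, the module is used like any other PyTorch
layer. A small sketch with made-up sizes, assuming the ``SAGEConv`` defined in
this chapter (or ``dgl.nn.SAGEConv``) is available:

.. code::

    import dgl
    import torch

    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    feat = torch.randn(4, 10)

    conv = SAGEConv(in_feats=10, out_feats=5, aggregator_type='mean')
    out = conv(g, feat)
    print(out.shape)    # torch.Size([4, 5])
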
.. _guide-nn-heterograph:
3.3 Heterogeneous GraphConv Module
----------------------------------
:class:`~dgl.nn.pytorch.HeteroGraphConv`
is a module-level encapsulation to run DGL NN modules on heterogeneous
graphs. The implementation logic is the same as the message passing level API
:meth:`~dgl.DGLGraph.multi_update_all`:

- DGL NN module within each relation :math:`r`.
- Reduction that merges the results on the same node type from multiple
  relations.

This can be formulated as:

.. math::  h_{dst}^{(l+1)} = \underset{r\in\mathcal{R}, r_{dst}=dst}{AGG} (f_r(g_r, h_{r_{src}}^l, h_{r_{dst}}^l))

where :math:`f_r` is the NN module for each relation :math:`r`, and
:math:`AGG` is the aggregation function.
HeteroGraphConv implementation logic
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code::

    import torch.nn as nn

    class HeteroGraphConv(nn.Module):
        def __init__(self, mods, aggregate='sum'):
            super(HeteroGraphConv, self).__init__()
            self.mods = nn.ModuleDict(mods)
            if isinstance(aggregate, str):
                # An internal function to get common aggregation functions
                self.agg_fn = get_aggregate_fn(aggregate)
            else:
                self.agg_fn = aggregate

The heterograph convolution takes a dictionary ``mods`` that maps each
relation to an nn module and sets the function that aggregates results on
the same node type from multiple relations.
.. code::

    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
        if mod_args is None:
            mod_args = {}
        if mod_kwargs is None:
            mod_kwargs = {}
        outputs = {nty : [] for nty in g.dsttypes}

Besides the input graph and input tensors, the ``forward()`` function takes
two additional dictionary parameters ``mod_args`` and ``mod_kwargs``.
These two dictionaries have the same keys as ``self.mods``. They are
used as customized parameters when calling the corresponding NN
modules in ``self.mods`` for different types of relations.

An output dictionary is created to hold the output tensors for each
destination type ``nty``. Note that the value for each ``nty`` is a
list, indicating a single node type may get multiple outputs if more
than one relation has ``nty`` as the destination type. ``HeteroGraphConv``
will perform a further aggregation on the lists.
.. code::

    if g.is_block:
        src_inputs = inputs
        dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
    else:
        src_inputs = dst_inputs = inputs

    for stype, etype, dtype in g.canonical_etypes:
        rel_graph = g[stype, etype, dtype]
        if rel_graph.num_edges() == 0:
            continue
        if stype not in src_inputs or dtype not in dst_inputs:
            continue
        dstdata = self.mods[etype](
            rel_graph,
            (src_inputs[stype], dst_inputs[dtype]),
            *mod_args.get(etype, ()),
            **mod_kwargs.get(etype, {}))
        outputs[dtype].append(dstdata)

The input ``g`` can be a heterogeneous graph or a subgraph block from a
heterogeneous graph. As in an ordinary NN module, the ``forward()``
function needs to handle different input graph types separately.

Each relation is represented as a ``canonical_etype``, which is
``(stype, etype, dtype)``. Using ``canonical_etype`` as the key, one can
extract a bipartite graph ``rel_graph``. For a bipartite graph, the
input feature will be organized as a tuple
``(src_inputs[stype], dst_inputs[dtype])``. The NN module for each
relation is called and the output is saved. To avoid unnecessary calls,
relations with no edges or no nodes of the src type are skipped.
.. code::

    rsts = {}
    for nty, alist in outputs.items():
        if len(alist) != 0:
            rsts[nty] = self.agg_fn(alist, nty)

Finally, the results on the same destination node type from multiple
relations are aggregated using the ``self.agg_fn`` function. Examples can
be found in the API Doc for :class:`~dgl.nn.pytorch.HeteroGraphConv`.
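In practice, one rarely re-implements this class; it ships with DGL and is
instantiated with a dictionary of per-relation modules. A short usage sketch
(the graph schema and sizes are made up):

.. code::

    import dgl
    import dgl.nn.pytorch as dglnn
    import torch

    g = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 1], [1, 2]),
        ('user', 'plays', 'game'): ([0, 2], [0, 1])})

    conv = dglnn.HeteroGraphConv({
        'follows': dglnn.GraphConv(5, 8, allow_zero_in_degree=True),
        'plays': dglnn.GraphConv(5, 8, allow_zero_in_degree=True)},
        aggregate='sum')

    inputs = {'user': torch.randn(3, 5), 'game': torch.randn(2, 5)}
    outputs = conv(g, inputs)
    print(outputs['user'].shape, outputs['game'].shape)    # (3, 8) (2, 8)
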
.. _guide-nn:

Chapter 3: Building GNN Modules
===============================

DGL NN module consists of building blocks for GNN models. An NN module inherits
from `Pytorch's NN Module <https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/module.html>`__,
`MXNet Gluon's NN Block <http://mxnet.incubator.apache.org/versions/1.6/api/python/docs/api/gluon/nn/index.html>`__
and `TensorFlow's Keras Layer <https://www.tensorflow.org/api_docs/python/tf/keras/layers>`__,
depending on the DNN framework backend in use. In a DGL NN module, the
parameter registration in the construction function and the tensor operations
in the forward function are the same as in the backend framework. In this way,
DGL code can be seamlessly integrated into the backend
...
DGL has integrated many commonly used modules ... and
:ref:`apinn-pytorch-util`. We welcome your contribution!

This chapter takes :class:`~dgl.nn.pytorch.conv.SAGEConv` with the Pytorch
backend as an example to introduce how to build a custom DGL NN module.

Roadmap
-------

* :ref:`guide-nn-construction`
* :ref:`guide-nn-forward`
* :ref:`guide-nn-heterograph`

.. toctree::
   :maxdepth: 1
   :hidden:
   :glob:

   nn-construction
   nn-forward
   nn-heterograph

DGL NN Module Construction Function
-----------------------------------

The construction function does the following:

1. Set options.
2. Register learnable parameters or submodules.
3. Reset parameters.

.. code::

    import torch as th
from torch import nn
from torch.nn import init
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair, check_eq_shape
class SAGEConv(nn.Module):
def __init__(self,
in_feats,
out_feats,
aggregator_type,
bias=True,
norm=None,
activation=None):
super(SAGEConv, self).__init__()
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
self._aggre_type = aggregator_type
self.norm = norm
self.activation = activation
In the construction function, one first needs to set the data dimensions. For
a general Pytorch module, the dimensions are usually the input dimension, the
output dimension and the hidden dimensions. For graph neural networks, the
input dimension can be split into a source node dimension and a destination
node dimension.
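Concretely, ``expand_as_pair()`` accepts either a single dimension or a
``(src, dst)`` pair:

.. code::

    from dgl.utils import expand_as_pair

    expand_as_pair(10)        # -> (10, 10): shared source/destination dimension
    expand_as_pair((10, 20))  # -> (10, 20): separate source/destination dimensions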
Besides data dimensions, a typical option for a graph neural network is the
aggregation type (``self._aggre_type``). The aggregation type determines how
messages on different edges are aggregated for a certain destination
node. Commonly used aggregation types include ``mean``, ``sum``,
``max``, and ``min``. Some modules may apply more complicated aggregation
such as an ``lstm``.
``norm`` here is a callable function for feature normalization. In the
GraphSAGE paper, such normalization can be the l2 norm:
:math:`h_v = h_v / \lVert h_v \rVert_2`.
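For example, a callable implementing the l2 normalization above can be passed
in directly (a sketch; any callable mapping a tensor to a same-shaped tensor
works):

.. code::

    import torch as th

    def l2_norm(h):
        # Normalize each node feature vector to unit l2 norm.
        return h / th.clamp(h.norm(dim=-1, keepdim=True), min=1e-12)

    # e.g. SAGEConv(10, 16, 'mean', norm=l2_norm)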
.. code::
# aggregator type: mean, max_pool, lstm, gcn
if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
if aggregator_type == 'max_pool':
self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':
self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type in ['mean', 'max_pool', 'lstm']:
self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
self.reset_parameters()
This code registers parameters and submodules. In SAGEConv, the submodules
vary according to the aggregation type. These modules are pure Pytorch nn
modules such as ``nn.Linear`` and ``nn.LSTM``. At the end of the construction
function, weight initialization is applied by calling
``reset_parameters()``.
.. code::
def reset_parameters(self):
"""Reinitialize learnable parameters."""
gain = nn.init.calculate_gain('relu')
if self._aggre_type == 'max_pool':
nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
if self._aggre_type == 'lstm':
self.lstm.reset_parameters()
if self._aggre_type != 'gcn':
nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
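With the construction function complete, instantiating the module looks like
any other Pytorch module. A sketch (the dimensions are arbitrary):

.. code::

    import torch.nn.functional as F

    conv = SAGEConv(in_feats=10, out_feats=16, aggregator_type='mean',
                    activation=F.relu)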
DGL NN Module Forward Function
----------------------------------
In an NN module, the ``forward()`` function does the actual message passing
and computation. Compared with a Pytorch NN module, which usually takes
tensors as its parameters, a DGL NN module additionally takes a
:class:`dgl.DGLGraph` as a parameter. The workload of the ``forward()``
function can be split into three parts:

- Graph checking and graph type specification.
- Message passing and reducing.
- Updating features after reducing for the output.

Let's dive into the ``forward()`` function of the SAGEConv example.
Graph checking and graph type specification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code::
def forward(self, graph, feat):
with graph.local_scope():
# Specify graph type then expand input feature according to graph type
feat_src, feat_dst = expand_as_pair(feat, graph)
``forward()`` needs to handle many corner cases on the input that can
lead to invalid values in computing and message passing. One typical check in
conv modules like :class:`~dgl.nn.pytorch.conv.GraphConv` is to verify that
the input graph has no 0-in-degree nodes. When a node has 0 in-degree, its
``mailbox`` will be empty and the reduce function will produce all-zero
values, which may cause a silent regression in model performance. However,
the :class:`~dgl.nn.pytorch.conv.SAGEConv` module concatenates the aggregated
representation with the original node feature, so the output of ``forward()``
will not be all-zero, and no such check is needed in this case.
A DGL NN module should be reusable across different types of graph input,
including: homogeneous graphs, heterogeneous
graphs (:ref:`guide-graph-heterogeneous`), and subgraph
blocks (:ref:`guide-minibatch`).
The math formulas for SAGEConv are:

.. math::  h_{\mathcal{N}(dst)}^{(l+1)} = \mathrm{aggregate}
        \left(\{h_{src}^{l}, \forall src \in \mathcal{N}(dst) \}\right)

.. math:: h_{dst}^{(l+1)} = \sigma \left(W \cdot \mathrm{concat}
        (h_{dst}^{l}, h_{\mathcal{N}(dst)}^{(l+1)}) + b \right)

.. math:: h_{dst}^{(l+1)} = \mathrm{norm}(h_{dst}^{(l+1)})
We need to specify the source node feature ``feat_src`` and the destination
node feature ``feat_dst`` according to the graph type. The function that
specifies the graph type and expands ``feat`` into ``feat_src`` and
``feat_dst`` is ``expand_as_pair()``, whose details are shown below.
.. code::
def expand_as_pair(input_, g=None):
if isinstance(input_, tuple):
# Bipartite graph case
return input_
elif g is not None and g.is_block:
# Subgraph block case
if isinstance(input_, Mapping):
input_dst = {
k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
for k, v in input_.items()}
else:
input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
return input_, input_dst
else:
# Homograph case
return input_, input_
For homogeneous full graph training, source nodes and destination nodes
are the same: all the nodes in the graph.

For the heterogeneous case, the graph can be split into several bipartite
graphs, one for each relation. The relations are represented as
``(src_type, edge_type, dst_type)``. When the input feature ``feat`` is
identified as a tuple, the graph is treated as bipartite. The first
element in the tuple is the source node feature and the second
element is the destination node feature.

In mini-batch training, the computation is applied on a subgraph sampled
for a given batch of destination nodes. The subgraph is called a
``block`` in DGL. After message passing, only those destination nodes
are updated, since they have the same neighborhood as
in the original full graph. During block creation,
destination nodes are placed at the front of the node list, so
``feat_dst`` can be found by the index ``[0:g.number_of_dst_nodes()]``.
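A runnable sketch of the block case (toy sizes; ``dgl.in_subgraph`` stands in
for a real neighborhood sampler here):

.. code::

    import dgl
    import torch as th
    from dgl.utils import expand_as_pair

    g = dgl.rand_graph(5, 8)                  # a random toy graph
    seeds = th.tensor([0, 1])                 # destination (seed) nodes
    frontier = dgl.in_subgraph(g, seeds)      # edges into the seeds
    block = dgl.to_block(frontier, seeds)
    feat = th.randn(block.number_of_src_nodes(), 10)
    feat_src, feat_dst = expand_as_pair(feat, block)
    # feat_dst == feat[:block.number_of_dst_nodes()]: seed nodes come first.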
After determining ``feat_src`` and ``feat_dst``, the computation for the
three graph types above is the same.
Message passing and reducing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code::
        h_self = feat_dst
        if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
            check_eq_shape(feat)
            graph.srcdata['h'] = feat_src
            graph.dstdata['h'] = feat_dst     # same as above if homogeneous
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
            # divide by in-degrees (+1 to account for the self feature)
            degs = graph.in_degrees().to(feat_dst)
            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'max_pool':
            graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
            graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        # (the 'lstm' branch is omitted here for brevity)
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

        # GraphSAGE GCN does not require fc_self.
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
This code does the actual message passing and reduce computation, and it is
the part that varies from module to module. Note that all the message passing
in the above code is implemented using the ``update_all()`` API and
built-in message/reduce functions to fully utilize DGL's performance
optimizations, as described in :ref:`guide-message-passing`.
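For instance, the ``mean`` aggregation above could also be written with a
user-defined reduce function. The two calls below are equivalent (assuming
``graph.srcdata['h']`` has been set), but the built-in version is preferred
because DGL can fuse and parallelize it:

.. code::

    # Built-in message/reduce functions (preferred):
    graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))

    # Equivalent user-defined reduce function (slower):
    graph.update_all(fn.copy_u('h', 'm'),
                     lambda nodes: {'neigh': nodes.mailbox['m'].mean(dim=1)})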
Update feature after reducing for output
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code::
# activation
if self.activation is not None:
rst = self.activation(rst)
# normalization
if self.norm is not None:
rst = self.norm(rst)
return rst
The last part of the ``forward()`` function updates the features after
the reduce function. Common update operations apply the activation
function and normalization according to the options set in the
object construction phase.
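Putting it all together, the finished module can be used like any other
Pytorch layer. A sketch on a random toy graph:

.. code::

    import dgl
    import torch as th

    g = dgl.rand_graph(30, 100)      # 30 nodes, 100 random edges
    feat = th.randn(30, 10)
    conv = SAGEConv(10, 16, 'mean')
    rst = conv(g, feat)              # shape: (30, 16)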