Commit e262e110 authored by Tianjun Xiao, committed by GitHub

[Doc] Message passing and NN User Guide (#1991)



* add message passing user guide rst

* add nn user guide

Co-authored-by: Mufei Li <mufeili1996@gmail.com>
parent 71f4230a
@@ -7,6 +7,7 @@ User Guide
    preface
    graph
    message
+   nn
    data
    training
    minibatch
Message Passing
===============
Message Passing Paradigm
------------------------
Let :math:`x_v\in\mathbb{R}^{d_1}` be the feature for node :math:`v`,
and :math:`w_{e}\in\mathbb{R}^{d_2}` be the feature for edge :math:`e`
from node :math:`u` to node :math:`v`. The **message passing paradigm**
defines the following edge-wise and node-wise computation at step
:math:`t+1`:

.. math:: \text{Edge-wise: } m_{e}^{(t+1)} = \phi \left( x_v^{(t)}, x_u^{(t)}, w_{e}^{(t)} \right) , ({u}, {v},{e}) \in \mathcal{E}.

.. math:: \text{Node-wise: } x_v^{(t+1)} = \psi \left(x_v^{(t)}, \rho\left(\left\lbrace m_{e}^{(t+1)} : ({u}, {v},{e}) \in \mathcal{E} \right\rbrace \right) \right).
In the above equations, :math:`\phi` is a **message function**
defined on each edge to generate a message by combining the edge feature
with the features of its incident nodes; :math:`\psi` is an
**update function** defined on each node to update the node feature
by aggregating its incoming messages using the **reduce function**
:math:`\rho`.
Built-in Functions and Message Passing APIs
-------------------------------------------
In DGL, a **message function** takes a single argument ``edges``,
which has three members ``src``, ``dst`` and ``data``, for accessing the
features of the source nodes, destination nodes, and edges, respectively.

A **reduce function** takes a single argument ``nodes``. A node can
access its ``mailbox`` to collect the messages its neighbors have sent to it
through edges. Some of the most common reduce operations are ``sum``,
``max``, ``min``, etc.

An **update function** takes a single argument ``nodes``. This function
operates on the aggregation result from the reduce function, typically
combining it with a node's feature from the last step, and saves the output
as a node feature.
DGL has implemented commonly used message functions and reduce functions
as **built-in** in the namespace ``dgl.function``. In general, we
suggest using built-in functions **whenever possible** since they are
heavily optimized and automatically handle dimension broadcasting.
If your message passing functions cannot be implemented with built-ins,
you can implement user-defined message/reduce function (aka. **UDF**).
Built-in message functions can be unary or binary. For unary, we currently
support ``copy``. For binary, we support ``add``, ``sub``, ``mul``,
``div`` and ``dot``. In the naming convention for message
built-in functions, ``u`` represents ``src`` nodes, ``v`` represents
``dst`` nodes, and ``e`` represents ``edges``. The parameters for those
functions are strings indicating the input and output field names for
the corresponding nodes and edges. Here is the
`list <https://docs.dgl.ai/api/python/function.html#>`__ of supported
built-in functions. For example, to add the ``hu`` feature from the src
nodes and the ``hv`` feature from the dst nodes, then save the result on
the edge at the ``he`` field, we can use the built-in function
``dgl.function.u_add_v('hu', 'hv', 'he')``, which is equivalent to the
message UDF below:
.. code::

    def message_func(edges):
        return {'he': edges.src['hu'] + edges.dst['hv']}
Built-in reduce functions support the operations ``sum``, ``max``, ``min``,
``prod`` and ``mean``. Reduce functions usually take two parameters, one
for the field name in ``mailbox``, one for the field name at the destination,
both strings. For example, ``dgl.function.sum('m', 'h')`` is equivalent
to the reduce UDF below that sums up the message ``m``:
.. code::

    import torch

    def reduce_func(nodes):
        return {'h': torch.sum(nodes.mailbox['m'], dim=1)}
In DGL, the interface for edge-wise computation is
`apply_edges() <https://docs.dgl.ai/generated/dgl.DGLGraph.apply_edges.html>`__.
The parameters of ``apply_edges`` are a message function and a valid set
of edges (see
`send() <https://docs.dgl.ai/en/0.4.x/generated/dgl.DGLGraph.send.html#dgl.DGLGraph.send>`_
for valid edge specifications; by default, all edges are updated). For
example:
.. code::

    import dgl.function as fn
    graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
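To make this concrete, below is a minimal self-contained sketch; the toy
graph, the feature names ``el``/``er``, and the feature dimension are
invented for illustration:

.. code::

    import dgl
    import dgl.function as fn
    import torch

    # A 3-node cycle graph: 0 -> 1 -> 2 -> 0.
    graph = dgl.graph(([0, 1, 2], [1, 2, 0]))
    graph.ndata['el'] = torch.randn(3, 4)  # feature read from the src end
    graph.ndata['er'] = torch.randn(3, 4)  # feature read from the dst end
    graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
    print(graph.edata['e'].shape)          # torch.Size([3, 4]), one row per edge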
The interface for node-wise computation is
`update_all() <https://docs.dgl.ai/generated/dgl.DGLGraph.update_all.html>`__.
The parameters of ``update_all`` are a message function, a
reduce function and an update function. The update function can instead
be called outside of ``update_all`` by leaving the third parameter
empty. This is the suggested style, since the update function can usually
be written as pure tensor operations, keeping the code concise. For
example:
.. code::

    def update_all_example(graph):
        # store the result in graph.ndata['ft']
        graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                         fn.sum('m', 'ft'))
        # call the update function outside of update_all
        final_ft = graph.ndata['ft'] * 2
        return final_ft
This call generates the message ``m`` by multiplying the src node feature
``ft`` with the edge feature ``a``, sums up the messages ``m`` to update
the node feature ``ft``, and finally multiplies ``ft`` by 2 to get the
result ``final_ft``. After the call, the intermediate message ``m`` is
cleaned up. The math formula for the above function is:

.. math:: {final\_ft}_i = 2 * \sum_{j\in\mathcal{N}(i)} ({ft}_j * a_{ij})
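As a quick sanity check, ``update_all_example`` can be run on a toy graph
like the one below; the graph shape and feature sizes are arbitrary, and
the trailing edge-weight dimension of 1 is our own choice so that ``a``
broadcasts against ``ft``:

.. code::

    import dgl
    import dgl.function as fn
    import torch

    graph = dgl.graph(([0, 1, 2], [1, 2, 0]))
    graph.ndata['ft'] = torch.randn(3, 4)
    graph.edata['a'] = torch.rand(3, 1)      # one scalar weight per edge
    print(update_all_example(graph).shape)   # torch.Size([3, 4])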
``update_all`` is a high-level API that merges message generation,
message reduction and node update in a single call, which leaves room
for optimizations, as explained below.
Notes
-----
Performance Optimization Notes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
DGL optimizes both memory consumption and computing speed for message
passing. The optimizations include:
- Merging multiple kernels into a single one: This is achieved by using
  ``update_all`` to call multiple built-in functions at once.
  (Speed optimization)
- Parallelism on nodes and edges: DGL abstracts edge-wise computation
  ``apply_edges`` as a generalized sampled dense-dense matrix
  multiplication (**gSDDMM**) operation and parallelizes the computation
  across edges. Likewise, DGL abstracts node-wise computation
  ``update_all`` as a generalized sparse-dense matrix multiplication
  (**gSPMM**) operation and parallelizes the computation across nodes.
  (Speed optimization)
- Avoiding unnecessary memory copies onto edges: To generate a message
  that requires features from the source and destination nodes, one
  option is to copy the source and destination node features onto the
  edge. For some graphs, the number of edges is much larger than the
  number of nodes, so this copy can be costly. DGL's built-in message
  functions avoid this memory copy by gathering node features with an
  entry index. (Memory and speed optimization)
- Avoiding materializing feature vectors on edges: The complete message
  passing process includes message generation, message reduction and
  node update. In an ``update_all`` call, the message function and reduce
  function are merged into one kernel if both are built-in, so no
  messages are materialized on edges. (Memory optimization)
Accordingly, a common practice to leverage those optimizations is to
construct your own message passing functionality as a combination of
``update_all`` calls with built-in functions as parameters.
For some cases like
`GAT <https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/gatconv.py>`__,
where we have to save messages on the edges, we need to call
``apply_edges`` with built-in functions. Sometimes the messages on
the edges can be high dimensional, which is memory consuming. We suggest
keeping the edge feature dimension as low as possible.

Here's an example of how to achieve this by splitting an operation on the
edges into operations on the nodes. The goal is the following: concatenate
the ``src`` feature and ``dst`` feature, then apply a linear layer, i.e.
:math:`W\times (u || v)`. The ``src`` and ``dst`` feature dimensions are
high, while the linear layer's output dimension is low. A straightforward
implementation would look like:
.. code::

    import torch
    import torch.nn as nn

    # node_feat_dim and out_dim are assumed to be defined elsewhere;
    # out_dim is the low output dimension of the linear layer.
    linear = nn.Parameter(torch.FloatTensor(size=(out_dim, node_feat_dim * 2)))

    def concat_message_function(edges):
        # cat_feat = u || v on each edge
        return {'cat_feat': torch.cat([edges.src['feat'], edges.dst['feat']], dim=1)}

    g.apply_edges(concat_message_function)
    g.edata['out'] = g.edata['cat_feat'] @ linear.t()   # W x (u || v)
The suggested implementation splits the linear operation into two, one
applied to the ``src`` feature, the other to the ``dst`` feature, and
adds the outputs of the linear operations on the edges at the final stage,
i.e. it performs :math:`W_l\times u + W_r \times v`. This works because
:math:`W \times (u||v) = W_l \times u + W_r \times v`, where :math:`W_l`
and :math:`W_r` are the left and the right halves of the matrix :math:`W`,
respectively:
.. code::

    linear_src = nn.Parameter(torch.FloatTensor(size=(out_dim, node_feat_dim)))
    linear_dst = nn.Parameter(torch.FloatTensor(size=(out_dim, node_feat_dim)))
    # Apply each half of the linear layer on the nodes (cheap),
    # not on the edges.
    out_src = g.ndata['feat'] @ linear_src.t()
    out_dst = g.ndata['feat'] @ linear_dst.t()
    g.srcdata.update({'out_src': out_src})
    g.dstdata.update({'out_dst': out_dst})
    g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))
The above two implementations are mathematically equivalent. The latter
one is much more efficient because it does not need to save ``feat_src``
and ``feat_dst`` on the edges, which is not memory-efficient. Plus, the
addition can be optimized with DGL's built-in function ``u_add_v``, which
further speeds up computation and reduces the memory footprint.
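If in doubt, the equivalence can be verified numerically. Below is a
hedged sketch with an invented toy graph, feature dimension 8 and output
dimension 2:

.. code::

    import dgl
    import dgl.function as fn
    import torch

    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    feat = torch.randn(4, 8)
    W = torch.randn(2, 16)            # W = [W_l | W_r]
    W_l, W_r = W[:, :8], W[:, 8:]

    g.ndata['feat'] = feat
    # Straightforward version: materializes a (num_edges, 16) tensor on edges.
    g.apply_edges(lambda edges: {'cat_feat': torch.cat(
        [edges.src['feat'], edges.dst['feat']], dim=1)})
    out1 = g.edata['cat_feat'] @ W.t()

    # Split version: only a (num_edges, 2) tensor touches the edges.
    g.ndata['out_src'] = feat @ W_l.t()
    g.ndata['out_dst'] = feat @ W_r.t()
    g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))
    out2 = g.edata['out']

    assert torch.allclose(out1, out2, atol=1e-5)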
Apply Message Passing On Part Of The Graph
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
If we only want to update part of the nodes in the graph, the practice is
to create a subgraph from the IDs of the nodes we want to include in the
update, then call ``update_all`` on the subgraph. For example:
.. code::

    nid = [0, 2, 3, 6, 7, 9]
    sg = g.subgraph(nid)
    sg.update_all(message_func, reduce_func, apply_node_func)
This is a common usage in mini-batch training. Check the `mini-batch
training <https://docs.dgl.ai/generated/guide/minibatch.html>`__ user guide for more
details.
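A minimal, hedged sketch of the pattern above, with built-in functions
standing in for the placeholder UDFs and an invented toy graph; note
that writing the results back to the parent graph goes through the
induced node IDs stored in ``sg.ndata[dgl.NID]``:

.. code::

    import dgl
    import dgl.function as fn
    import torch

    g = dgl.graph(([0, 2, 3, 6, 7], [2, 3, 6, 7, 9]), num_nodes=10)
    g.ndata['h'] = torch.randn(10, 4)

    nid = [0, 2, 3, 6, 7, 9]
    sg = g.subgraph(nid)              # node features are copied into sg
    sg.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
    # Write the updated features back to the parent graph if needed.
    g.ndata['h'][sg.ndata[dgl.NID]] = sg.ndata['h']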
Apply Edge Weight In Message Passing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
A commonly seen practice in GNN modeling is to apply edge weights to the
messages before message aggregation, for example in
`GAT <https://arxiv.org/pdf/1710.10903.pdf>`__ and some `GCN
variants <https://arxiv.org/abs/2004.00445>`__. In DGL, the way to
handle this is:

- Save the weight as an edge feature.
- Multiply the edge feature with the src node feature in the message function.

For example:
.. code::

    graph.edata['a'] = affinity
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
Here ``affinity`` serves as the edge weight. The edge weight should
usually be a scalar per edge.
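A runnable version of this pattern might look like the following; the toy
graph and the random ``affinity`` are invented for illustration, and the
trailing dimension of 1 lets the scalar weight broadcast over the feature
dimension:

.. code::

    import dgl
    import dgl.function as fn
    import torch

    graph = dgl.graph(([0, 1, 2], [1, 2, 0]))
    graph.ndata['ft'] = torch.randn(3, 8)
    affinity = torch.rand(graph.number_of_edges(), 1)  # one scalar per edge
    graph.edata['a'] = affinity
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))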
Message Passing on Heterogeneous Graphs
---------------------------------------
`Heterogeneous
graphs <https://docs.dgl.ai/tutorials/basics/5_hetero.html>`__, or
heterographs for short, are graphs that contain different types of nodes
and edges. The different types of nodes and edges tend to have different
types of attributes that are designed to capture the characteristics of
each node and edge type. Within the context of graph neural networks,
depending on their complexity, certain node and edge types might need to
be modeled with representations that have a different number of
dimensions.
Message passing on heterographs can be split into two parts:

1. Message computation and aggregation within each relation :math:`r`.
2. Reduction that merges the results on the same node type from multiple
   relations.
DGL's interface for message passing on heterographs is
`multi_update_all() <https://docs.dgl.ai/generated/dgl.DGLHeteroGraph.multi_update_all.html>`__.
``multi_update_all`` takes a dictionary containing the parameters for
``update_all`` within each relation, keyed by relation, and a string
representing the cross-type reducer. The reducer can be one of
``sum``, ``min``, ``max``, ``mean`` and ``stack``. Here's an example:
.. code::

    import dgl.function as fn

    # Inside the forward() of an nn.Module holding one weight per relation.
    funcs = {}
    for c_etype in G.canonical_etypes:
        srctype, etype, dsttype = c_etype
        Wh = self.weight[etype](feat_dict[srctype])
        # Save it in graph for message passing
        G.nodes[srctype].data['Wh_%s' % etype] = Wh
        # Specify per-relation message passing functions: (message_func, reduce_func).
        # Note that the results are saved to the same destination feature 'h', which
        # hints the cross-type reducer for aggregation.
        funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
    # Trigger message passing of multiple types.
    G.multi_update_all(funcs, 'sum')
    # Return the updated node feature dictionary.
    return {ntype: G.nodes[ntype].data['h'] for ntype in G.ntypes}
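A minimal self-contained sketch of ``multi_update_all`` with built-in
functions only; the toy heterograph and feature sizes are invented:

.. code::

    import dgl
    import dgl.function as fn
    import torch

    G = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 1], [1, 2]),
        ('user', 'plays', 'game'): ([0, 2], [0, 1])})
    G.nodes['user'].data['x'] = torch.randn(3, 4)
    funcs = {
        'follows': (fn.copy_u('x', 'm'), fn.mean('m', 'h')),
        'plays': (fn.copy_u('x', 'm'), fn.mean('m', 'h'))}
    G.multi_update_all(funcs, 'sum')
    print(G.nodes['game'].data['h'].shape)  # torch.Size([2, 4])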
Build DGL NN Module
===================
A DGL NN module is the building block for your GNN model. It inherits
from `PyTorch's NN Module <https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/module.html>`__, `MXNet Gluon's NN Block <http://mxnet.incubator.apache.org/versions/1.6/api/python/docs/api/gluon/nn/index.html>`__ or `TensorFlow's Keras
Layer <https://www.tensorflow.org/api_docs/python/tf/keras/layers>`__, depending on the DNN framework backend in use. In a DGL NN
module, the parameter registration in the construction function and the
tensor operations in the forward function are the same as in the backend
framework. In this way, DGL code can be seamlessly integrated into
backend framework code. The major difference lies in the message passing
operations that are unique to DGL.
DGL has integrated many commonly used
`Sparse_GraphConvs <https://docs.dgl.ai/api/python/nn.pytorch.html#module-dgl.nn.pytorch.conv>`__,
`Dense_GraphConvs <https://docs.dgl.ai/api/python/nn.pytorch.html#dense-conv-layers>`__,
`Graph_Poolings <https://docs.dgl.ai/api/python/nn.pytorch.html#module-dgl.nn.pytorch.glob>`__,
and
`Utility <https://docs.dgl.ai/api/python/nn.pytorch.html#utility-modules>`__
NN modules. We welcome your contribution!
In this section, we use
`dgl.nn.conv.SAGEConv <https://github.com/sneakerkg/dgl/blob/nn_doc_refactor/python/dgl/nn/pytorch/conv/sageconv.py>`__
with the PyTorch backend as an example to introduce how to build your own
DGL NN module.
DGL NN Module Construction Function
-----------------------------------
The construction function does the following:

1. Set options.
2. Register learnable parameters or submodules.
3. Reset parameters.
.. code::

    import torch as th
    from torch import nn
    from torch.nn import init

    from .... import function as fn
    from ....base import DGLError
    from ....utils import expand_as_pair, check_eq_shape

    class SAGEConv(nn.Module):
        def __init__(self,
                     in_feats,
                     out_feats,
                     aggregator_type,
                     bias=True,
                     norm=None,
                     activation=None,
                     allow_zero_in_degree=False):
            super(SAGEConv, self).__init__()

            self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
            self._out_feats = out_feats
            self._aggre_type = aggregator_type
            self.norm = norm
            self.activation = activation
            self._allow_zero_in_degree = allow_zero_in_degree
In the construction function, we first need to set the data dimensions.
For a general PyTorch module, the dimensions are usually the input
dimension, output dimension and hidden dimensions. For graph neural
networks, the input dimension can be split into a source node dimension
and a destination node dimension.
Besides data dimensions, a typical option for a graph neural network is
the aggregation type (``self._aggre_type``). The aggregation type
determines how messages on different edges are aggregated for a given
destination node. Commonly used aggregation types include ``mean``,
``sum``, ``max`` and ``min``. Some modules may apply more complicated
aggregation such as an ``lstm``.

``norm`` here is a callable for feature normalization. In the GraphSAGE
paper, such normalization can be the l2 norm:
:math:`h_v = h_v / \lVert h_v \rVert_2`.
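For instance, a user-supplied l2 normalization callable might look like
the sketch below; the ``l2_norm`` helper and the epsilon guard are our own
additions for illustration:

.. code::

    import torch

    def l2_norm(h):
        # h_v = h_v / ||h_v||_2, with an epsilon to avoid division by zero.
        return h / (h.norm(dim=-1, keepdim=True) + 1e-12)

    conv = SAGEConv(in_feats=10, out_feats=5, aggregator_type='mean',
                    norm=l2_norm)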
.. code::

    # aggregator type: mean, max_pool, lstm, gcn
    if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
        raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
    if aggregator_type == 'max_pool':
        self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
    if aggregator_type == 'lstm':
        self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
    if aggregator_type in ['mean', 'max_pool', 'lstm']:
        self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
    self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
    self.reset_parameters()
Next, register parameters and submodules. In SAGEConv, the submodules
vary according to the aggregation type. Those modules are pure PyTorch
nn modules such as ``nn.Linear`` and ``nn.LSTM``. At the end of the
construction function, weight initialization is applied by calling
``reset_parameters()``.
.. code::

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'max_pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
DGL NN Module Forward Function
------------------------------
In an NN module, the ``forward()`` function does the actual message
passing and computation. Compared with a PyTorch NN module, which
usually takes tensors as its parameters, a DGL NN module takes an
additional parameter, a
`DGLGraph <https://docs.dgl.ai/api/python/graph.html>`__. The workload
of the ``forward()`` function can be split into three parts:

- Graph checking and graph type specification.
- Message passing and reducing.
- Updating features after reducing, for output.

Let's dive into the ``forward()`` function of the SAGEConv example.
Graph checking and graph type specification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code::

    def forward(self, graph, feat):
        with graph.local_scope():
            # Graph checking
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError('There are 0-in-degree nodes in the graph, '
                                   'output for those nodes will be invalid. '
                                   'This is harmful for some applications, '
                                   'causing silent performance regression. '
                                   'Adding self-loop on the input graph by calling '
                                   '`g = dgl.add_self_loop(g)` will resolve the issue. '
                                   'Setting ``allow_zero_in_degree`` to be `True` '
                                   'when constructing this module will suppress the '
                                   'check and let the code run.')

            # Specify graph type then expand input feature according to graph type
            feat_src, feat_dst = expand_as_pair(feat, graph)
**This part of the code is usually shared by all NN modules.**
``forward()`` needs to handle many corner cases in the input that can
lead to invalid values during computation and message passing. The above
example handles the case where there are 0-in-degree nodes in the input
graph.

When a node has 0 in-degree, its ``mailbox`` will be empty and the
reduce function will not produce valid values. For example, if the
reduce function is ``max``, the output for the 0-in-degree nodes
will be ``-inf``.
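For example, adding self-loops before calling the module avoids the
problem; a minimal sketch:

.. code::

    import dgl

    g = dgl.graph(([0, 1], [1, 2]))  # node 0 has no incoming edges
    g = dgl.add_self_loop(g)         # now every node receives at least one message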
A DGL NN module should be reusable across different types of graph input,
including: homogeneous graphs, `heterogeneous
graphs <https://docs.dgl.ai/tutorials/basics/5_hetero.html>`__ and `subgraph
blocks <https://docs.dgl.ai/guide/minibatch.html>`__.
The math formulas for SAGEConv are:

.. math::

   h_{\mathcal{N}(dst)}^{(l+1)} = \mathrm{aggregate}
   \left(\{h_{src}^{l}, \forall src \in \mathcal{N}(dst) \}\right)

.. math::

   h_{dst}^{(l+1)} = \sigma \left(W \cdot \mathrm{concat}
   (h_{dst}^{l}, h_{\mathcal{N}(dst)}^{(l+1)}) + b \right)

.. math::

   h_{dst}^{(l+1)} = \mathrm{norm}(h_{dst}^{(l+1)})
We need to specify the source node feature ``feat_src`` and the
destination node feature ``feat_dst`` according to the graph type. The
function that determines the graph type and expands ``feat`` into
``feat_src`` and ``feat_dst`` is
`expand_as_pair() <https://github.com/dmlc/dgl/blob/master/python/dgl/utils/internal.py#L553>`__.
Its details are shown below.
.. code::

    def expand_as_pair(input_, g=None):
        if isinstance(input_, tuple):
            # Bipartite graph case
            return input_
        elif g is not None and g.is_block:
            # Subgraph block case
            if isinstance(input_, Mapping):
                input_dst = {
                    k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                    for k, v in input_.items()}
            else:
                input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
            return input_, input_dst
        else:
            # Homograph case
            return input_, input_
For homogeneous whole-graph training, source nodes and destination nodes
are the same: all the nodes in the graph.

For the heterogeneous case, the graph can be split into several bipartite
graphs, one for each relation. A relation is represented as
``(src_type, edge_type, dst_type)``. When the input feature ``feat`` is
a tuple, the graph is treated as bipartite. The first element of the
tuple is the source node feature and the second element is the
destination node feature.

In mini-batch training, the computation is applied on a subgraph sampled
from a given set of destination nodes. The subgraph is called a *block*
in DGL. After message passing, only those destination nodes will be
updated, since they have the same neighborhood as in the original full
graph. In the block creation phase, ``dst nodes`` are at the front of
the node list, so we can find the ``feat_dst`` with the index
``[0:g.number_of_dst_nodes()]``.
After determining ``feat_src`` and ``feat_dst``, the computation for the
above three graph types is the same.
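As a quick illustration of two of the cases (toy inputs invented here;
``expand_as_pair`` is importable from ``dgl.utils``):

.. code::

    import dgl
    import torch
    from dgl.utils import expand_as_pair

    g = dgl.graph(([0, 1], [1, 2]))
    feat = torch.randn(3, 4)
    # Homograph: the same tensor serves as both source and destination feature.
    feat_src, feat_dst = expand_as_pair(feat, g)
    assert feat_src is feat_dst
    # Bipartite case: a (src_feat, dst_feat) tuple is passed through unchanged.
    feat_src, feat_dst = expand_as_pair((torch.randn(3, 4), torch.randn(2, 4)))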
Message passing and reducing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code::

    h_self = feat_dst   # retain the destination nodes' own features
    if self._aggre_type == 'mean':
        graph.srcdata['h'] = feat_src
        graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
        h_neigh = graph.dstdata['neigh']
    elif self._aggre_type == 'gcn':
        check_eq_shape(feat)
        graph.srcdata['h'] = feat_src
        graph.dstdata['h'] = feat_dst     # same as above if homogeneous
        graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
        # divide by in-degrees
        degs = graph.in_degrees().to(feat_dst)
        h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
    elif self._aggre_type == 'max_pool':
        graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
        graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
        h_neigh = graph.dstdata['neigh']
    else:
        raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

    # GraphSAGE GCN does not require fc_self.
    if self._aggre_type == 'gcn':
        rst = self.fc_neigh(h_neigh)
    else:
        rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
This code does the actual message passing and reduce computation. This
part varies module by module. Note that all the message passing in the
above code is implemented with the ``update_all()`` API and built-in
message/reduce functions to fully utilize DGL's performance
optimizations, as described in the `Message Passing User Guide
Section <https://docs.dgl.ai/guide/message.html>`__.
Update feature after reducing for output
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code::

    # activation
    if self.activation is not None:
        rst = self.activation(rst)
    # normalization
    if self.norm is not None:
        rst = self.norm(rst)
    return rst
The last part of the ``forward()`` function updates the features after
the reduce function. Common update operations are applying an activation
function and normalization, according to the options set in the object
construction phase.
Heterogeneous GraphConv Module
------------------------------
`HeteroGraphConv <https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/hetero.py>`__
is a module-level encapsulation for running DGL NN modules on
heterogeneous graphs. The implementation logic is the same as that of
the message passing API ``multi_update_all()``:

- Run a DGL NN module within each relation :math:`r`.
- Reduce by merging the results on the same node type from multiple
  relations.
This can be formulated as:

.. math:: h_{dst}^{(l+1)} = \underset{r\in\mathcal{R}, r_{dst}=dst}{AGG} (f_r(g_r, h_{r_{src}}^l, h_{r_{dst}}^l))

where :math:`f_r` is the NN module for each relation :math:`r` and
:math:`AGG` is the aggregation function.
HeteroGraphConv implementation logic
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. code::

    class HeteroGraphConv(nn.Module):
        def __init__(self, mods, aggregate='sum'):
            super(HeteroGraphConv, self).__init__()
            self.mods = nn.ModuleDict(mods)
            if isinstance(aggregate, str):
                self.agg_fn = get_aggregate_fn(aggregate)
            else:
                self.agg_fn = aggregate
The heterograph convolution takes a dictionary ``mods`` that maps each
relation to an NN module, and sets the function that aggregates results
on the same node type from multiple relations.
.. code::

    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
        if mod_args is None:
            mod_args = {}
        if mod_kwargs is None:
            mod_kwargs = {}
        outputs = {nty: [] for nty in g.dsttypes}
Besides the input graph and input tensors, the ``forward()`` function
takes two additional dictionary parameters, ``mod_args`` and
``mod_kwargs``. These two dictionaries have the same keys as
``self.mods``. They are used as customized parameters when calling the
corresponding NN modules in ``self.mods`` for the different relations.

An output dictionary is created to hold the output tensors for each
destination type ``nty``. Note that the value for each ``nty`` is a
list, since a single node type may get multiple outputs if more than one
relation has ``nty`` as its destination type. The list entries are held
for further aggregation.
.. code::

    if g.is_block:
        src_inputs = inputs
        dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
    else:
        src_inputs = dst_inputs = inputs

    for stype, etype, dtype in g.canonical_etypes:
        rel_graph = g[stype, etype, dtype]
        if rel_graph.number_of_edges() == 0:
            continue
        if stype not in src_inputs or dtype not in dst_inputs:
            continue
        dstdata = self.mods[etype](
            rel_graph,
            (src_inputs[stype], dst_inputs[dtype]),
            *mod_args.get(etype, ()),
            **mod_kwargs.get(etype, {}))
        outputs[dtype].append(dstdata)
The input ``g`` can be a heterogeneous graph or a subgraph block from a
heterogeneous graph. As in an ordinary NN module, the ``forward()``
function needs to handle the different input graph types separately.

Each relation is represented as a ``canonical_etype``, which is the
triple ``(stype, etype, dtype)``. Using a ``canonical_etype`` as the
key, we can extract a bipartite graph ``rel_graph``. For a bipartite
graph, the input feature is organized as a tuple
``(src_inputs[stype], dst_inputs[dtype])``. The NN module for each
relation is called and the output saved. To avoid unnecessary calls,
relations with no edges, or with no input for their source node type,
are skipped.
.. code::

    rsts = {}
    for nty, alist in outputs.items():
        if len(alist) != 0:
            rsts[nty] = self.agg_fn(alist, nty)
Finally, the results on the same destination node type from multiple
relations are aggregated using the ``self.agg_fn`` function.
HeteroGraphConv example usage
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Create a heterograph
^^^^^^^^^^^^^^^^^^^^
.. code::

    >>> import dgl
    >>> g = dgl.heterograph({
    ...     ('user', 'follows', 'user') : edges1,
    ...     ('user', 'plays', 'game') : edges2,
    ...     ('store', 'sells', 'game') : edges3})
This heterograph has three types of relations and nodes.
Create a HeteroGraphConv module
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. code::

    >>> import dgl.nn.pytorch as dglnn
    >>> conv = dglnn.HeteroGraphConv({
    ...     'follows' : dglnn.GraphConv(...),
    ...     'plays' : dglnn.GraphConv(...),
    ...     'sells' : dglnn.SAGEConv(...)},
    ...     aggregate='sum')
This module applies different convolution modules to different
relations. Note that the modules for ``'follows'`` and ``'plays'`` do
not share weights. The ``aggregate`` parameter indicates how results are
aggregated if multiple relations have the same destination node types.
Call forward with different inputs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Case 1: Call forward with some ``'user'`` features. This computes new
features for both ``'user'`` and ``'game'`` nodes.
.. code::

    >>> import torch as th
    >>> h1 = {'user' : th.randn((g.number_of_nodes('user'), 5))}
    >>> h2 = conv(g, h1)
    >>> print(h2.keys())
    dict_keys(['user', 'game'])
Case 2: Call forward with both ``'user'`` and ``'store'`` features.
.. code::

    >>> f1 = {'user' : ..., 'store' : ...}
    >>> f2 = conv(g, f1)
    >>> print(f2.keys())
    dict_keys(['user', 'game'])
Because both the ``'plays'`` and ``'sells'`` relations will update the
``'game'`` features, their results are aggregated by the specified
method (i.e., summation here).
Case 3: Call forward with a pair of inputs.
.. code::

    >>> x_src = {'user' : ..., 'store' : ...}
    >>> x_dst = {'user' : ..., 'game' : ...}
    >>> y_dst = conv(g, (x_src, x_dst))
    >>> print(y_dst.keys())
    dict_keys(['user', 'game'])
Each submodule will also be invoked with a pair of inputs.
@@ -40,7 +40,7 @@ Getting Started
 * :doc:`End-to-end model tutorials<tutorials/models/index>` for learning DGL by popular models on graphs.
 ..
    Follow the :doc:`instructions<install/index>` to install DGL.
    :doc:`DGL at a glance<tutorials/basics/1_first>` is the most common place to get started with.
    It offers a broad experience of using DGL for deep learning on graph data.

@@ -89,6 +89,7 @@ Getting Started
    guide/preface
    guide/graph
    guide/message
+   guide/nn
    guide/data
    guide/training
    guide/minibatch