Unverified Commit 7e0107c3 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files
parent 7a6d6668
......@@ -3,10 +3,10 @@
5.4 Graph Classification
----------------------------------
Instead of a big single graph, sometimes we might have the data in the
Instead of a big single graph, sometimes one might have the data in the
form of multiple graphs, for example a list of different types of
communities of people. By characterizing the friendships among people in
the same community by a graph, we get a list of graphs to classify. In
communities of people. By characterizing the friendship among people in
the same community by a graph, one can get a list of graphs to classify. In
this scenario, a graph classification model could help identify the type
of the community, i.e. to classify each graph based on the structure and
overall information.
......@@ -16,11 +16,11 @@ Overview
The major difference between graph classification and node
classification or link prediction is that the prediction result
characterize the property of the entire input graph. We perform the
characterizes the property of the entire input graph. One can perform the
message passing over nodes/edges just like the previous tasks, but also
try to retrieve a graph-level representation.
needs to retrieve a graph-level representation.
The graph classification proceeds as follows:
The graph classification pipeline proceeds as follows:
.. figure:: https://data.dgl.ai/tutorial/batch/graph_classifier.png
:alt: Graph Classification Process
......@@ -29,23 +29,23 @@ The graph classification proceeds as follows:
From left to right, the common practice is:
- Prepare graphs in to a batch of graphs
- Message passing on the batched graphs to update node/edge features
- Aggregate node/edge features into a graph-level representation
- Classification head for the task
- Prepare a batch of graphs
- Perform message passing on the batched graphs to update node/edge features
- Aggregate node/edge features into graph-level representations
- Classify graphs based on graph-level representations
Batch of Graphs
^^^^^^^^^^^^^^^
Usually a graph classification task trains on a lot of graphs, and it
will be very inefficient if we use only one graph at a time when
will be very inefficient to use only one graph at a time when
training the model. Borrowing the idea of mini-batch training from
common deep learning practice, we can build a batch of multiple graphs
common deep learning practice, one can build a batch of multiple graphs
and send them together for one training iteration.
In DGL, we can build a single batched graph of a list of graphs. This
batched graph can be simply used as a single large graph, with separated
components representing the corresponding original small graphs.
In DGL, one can build a single batched graph from a list of graphs. This
batched graph can be simply used as a single large graph, with connected
components corresponding to the original small graphs.
.. figure:: https://data.dgl.ai/tutorial/batch/batch.png
:alt: Batched Graph
......@@ -56,45 +56,46 @@ Graph Readout
^^^^^^^^^^^^^
Every graph in the data may have its unique structure, as well as its
node and edge features. In order to make a single prediction, we usually
aggregate and summarize over the possibly abundant information. This
type of operation is named *Readout*. Common aggregations include
node and edge features. In order to make a single prediction, one usually
aggregates and summarizes over the possibly abundant information. This
type of operation is named *readout*. Common readout operations include
summation, average, maximum or minimum over all node or edge features.
Given a graph :math:`g`, we can define the average readout aggregation
as
Given a graph :math:`g`, one can define the average node feature readout as
.. math:: h_g = \frac{1}{|\mathcal{V}|}\sum_{v\in \mathcal{V}}h_v
In DGL the corresponding function call is :func:`dgl.readout_nodes`.
where :math:`h_g` is the representation of :math:`g`, :math:`\mathcal{V}` is
the set of nodes in :math:`g`, :math:`h_v` is the feature of node :math:`v`.
Once :math:`h_g` is available, we can pass it through an MLP layer for
DGL provides built-in support for common readout operations. For example,
:func:`dgl.readout_nodes` implements the above readout operation.
Once :math:`h_g` is available, one can pass it through an MLP layer for
classification output.
Writing neural network model
Writing Neural Network Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The input to the model is the batched graph with node and edge features.
One thing to note is the node and edge features in the batched graph
have no batch dimension. A little special care should be put in the
model:
Computation on a batched graph
Computation on a Batched Graph
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Next, we discuss the computational properties of a batched graph.
First, different graphs in a batch are entirely separated, i.e. no edge
connecting two graphs. With this nice property, all message passing
First, different graphs in a batch are entirely separated, i.e. no edges
between any two graphs. With this nice property, all message passing
functions still have the same results.
Second, the readout function on a batched graph will be conducted over
each graph separately. Assume the batch size is :math:`B` and the
each graph separately. Assuming the batch size is :math:`B` and the
feature to be aggregated has dimension :math:`D`, the shape of the
readout result will be :math:`(B, D)`.
.. code:: python
import dgl
import torch
g1 = dgl.graph(([0, 1], [1, 0]))
g1.ndata['h'] = torch.tensor([1., 2.])
g2 = dgl.graph(([0, 1], [1, 2]))
......@@ -107,23 +108,24 @@ readout result will be :math:`(B, D)`.
dgl.readout_nodes(bg, 'h')
# tensor([3., 6.]) # [1 + 2, 1 + 2 + 3]
Finally, each node/edge feature tensor on a batched graph is in the
format of concatenating the corresponding feature tensor from all
graphs.
Finally, each node/edge feature in a batched graph is obtained by
concatenating the corresponding features from all graphs in order.
.. code:: python
bg.ndata['h']
# tensor([1., 2., 1., 2., 3.])
Model definition
Model Definition
^^^^^^^^^^^^^^^^
Being aware of the above computation rules, we can define a very simple
model.
Being aware of the above computation rules, one can define a model as follows.
.. code:: python
import dgl.nn.pytorch as dglnn
import torch.nn as nn
class Classifier(nn.Module):
def __init__(self, in_dim, hidden_dim, n_classes):
super(Classifier, self).__init__()
......@@ -131,7 +133,7 @@ model.
self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim)
self.classify = nn.Linear(hidden_dim, n_classes)
def forward(self, g, feat):
def forward(self, g, h):
# Apply graph convolution and activation.
h = F.relu(self.conv1(g, h))
h = F.relu(self.conv2(g, h))
......@@ -141,19 +143,19 @@ model.
hg = dgl.mean_nodes(g, 'h')
return self.classify(hg)
Training loop
Training Loop
~~~~~~~~~~~~~
Data Loading
^^^^^^^^^^^^
Once the models defined, we can start training. Since graph
classification deals with lots of relative small graphs instead of a big
single one, we usually can train efficiently on stochastic mini-batches
Once the model is defined, one can start training. Since graph
classification deals with lots of relatively small graphs instead of a big
single one, one can train efficiently on stochastic mini-batches
of graphs, without the need to design sophisticated graph sampling
algorithms.
Assuming that we have a graph classification dataset as introduced in
Assuming that one have a graph classification dataset as introduced in
:ref:`guide-data-pipeline`.
.. code:: python
......@@ -162,7 +164,7 @@ Assuming that we have a graph classification dataset as introduced in
dataset = dgl.data.GINDataset('MUTAG', False)
Each item in the graph classification dataset is a pair of a graph and
its label. We can speed up the data loading process by taking advantage
its label. One can speed up the data loading process by taking advantage
of the DataLoader, by customizing the collate function to batch the
graphs:
......@@ -175,7 +177,7 @@ graphs:
return batched_graph, batched_labels
Then one can create a DataLoader that iterates over the dataset of
graphs in minibatches.
graphs in mini-batches.
.. code:: python
......@@ -195,20 +197,23 @@ updating the model.
.. code:: python
model = Classifier(10, 20, 5)
import torch.nn.functional as F
# Only an example, 7 is the input feature size
model = Classifier(7, 20, 5)
opt = torch.optim.Adam(model.parameters())
for epoch in range(20):
for batched_graph, labels in dataloader:
feats = batched_graph.ndata['feats']
feats = batched_graph.ndata['attr'].float()
logits = model(batched_graph, feats)
loss = F.cross_entropy(logits, labels)
opt.zero_grad()
loss.backward()
opt.step()
DGL implements
`GIN <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gin>`__
as an example of graph classification. The training loop is inside the
For an end-to-end example of graph classification, see
`DGL's GIN example <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gin>`__.
The training loop is inside the
function ``train`` in
`main.py <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gin/main.py>`__.
The model implementation is inside
......@@ -221,8 +226,8 @@ Heterogeneous graph
~~~~~~~~~~~~~~~~~~~
Graph classification with heterogeneous graphs is a little different
from that with homogeneous graphs. Except that you need heterogeneous
graph convolution modules, yoyu also need to aggregate over the nodes of
from that with homogeneous graphs. In addition to graph convolution modules
compatible with heterogeneous graphs, one also needs to aggregate over the nodes of
different types in the readout function.
The following shows an example of summing up the average of node
......@@ -242,7 +247,7 @@ representations for each node type.
for rel in rel_names}, aggregate='sum')
def forward(self, graph, inputs):
# inputs are features of nodes
# inputs is features of nodes
h = self.conv1(graph, inputs)
h = {k: F.relu(v) for k, v in h.items()}
h = self.conv2(graph, h)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment