.. _guide-training-graph-classification:

5.4 Graph Classification
----------------------------------

:ref:`(中文版) <guide_cn-training-graph-classification>`

Instead of one big graph, sometimes the data comes in the form of multiple
graphs, for example a list of different types of communities of people. By
representing the friendships among people in the same community as a graph,
one gets a list of graphs to classify. In this scenario, a graph
classification model can help identify the type of each community, i.e.,
classify each graph based on its structure and overall information.

Overview
~~~~~~~~

The major difference between graph classification and node
classification or link prediction is that the prediction result
characterizes a property of the entire input graph. One can perform
message passing over nodes/edges just as in the previous tasks, but one also
needs to retrieve a graph-level representation.

The graph classification pipeline proceeds as follows:

.. figure:: https://data.dgl.ai/tutorial/batch/graph_classifier.png
   :alt: Graph Classification Process

   Graph Classification Process

From left to right, the common practice is:

-  Prepare a batch of graphs
-  Perform message passing on the batched graphs to update node/edge features
-  Aggregate node/edge features into graph-level representations
-  Classify graphs based on graph-level representations

Batch of Graphs
^^^^^^^^^^^^^^^

A graph classification task usually trains on many graphs, and it would be
very inefficient to use only one graph at a time when training the model.
Borrowing the idea of mini-batch training from common deep learning
practice, one can build a batch of multiple graphs and send them together
for one training iteration.

In DGL, one can build a single batched graph from a list of graphs. This
batched graph can simply be used as a single large graph in which each
original small graph forms a separate component, with no edges between
different graphs.

.. figure:: https://data.dgl.ai/tutorial/batch/batch.png
   :alt: Batched Graph

   Batched Graph

The following example calls :func:`dgl.batch` on a list of graphs.
A batched graph is a single graph, but it also carries information
about the list of graphs it was built from.

.. code:: python

    import dgl
    import torch as th

    g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
    g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))

    bg = dgl.batch([g1, g2])
    bg
    # Graph(num_nodes=7, num_edges=7,
    #       ndata_schemes={}
    #       edata_schemes={})
    bg.batch_size
    # 2
    bg.batch_num_nodes()
    # tensor([4, 3])
    bg.batch_num_edges()
    # tensor([3, 4])
    bg.edges()
    # (tensor([0, 1, 2, 4, 4, 4, 5]), tensor([1, 2, 3, 4, 5, 6, 4]))

Please note that most DGL transformation functions will discard the batch information.
To maintain this information, call :func:`dgl.DGLGraph.set_batch_num_nodes`
and :func:`dgl.DGLGraph.set_batch_num_edges` on the transformed graph.
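
A minimal sketch of restoring the batch information follows. It uses
:func:`dgl.reverse` only because that transform keeps every node and edge ID.
As a result, the per-graph counts of the original batched graph still describe
the transformed one. For a transform that adds, removes or reorders nodes or
edges, the counts would have to be recomputed rather than copied.

.. code:: python

    import dgl
    import torch

    g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
    g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))
    bg = dgl.batch([g1, g2])

    # dgl.reverse keeps node and edge IDs, so the per-graph node and
    # edge counts of bg still describe the reversed graph.
    rbg = dgl.reverse(bg)

    # Copy the batch information over explicitly, whether or not the
    # transform already preserved it.
    rbg.set_batch_num_nodes(bg.batch_num_nodes())
    rbg.set_batch_num_edges(bg.batch_num_edges())

    # Unbatching now recovers the two reversed graphs.
    rg1, rg2 = dgl.unbatch(rbg)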

Graph Readout
^^^^^^^^^^^^^

Every graph in the data may have its own structure, as well as its own
node and edge features. To make a single prediction, one usually
aggregates and summarizes this possibly abundant information. This
type of operation is named *readout*. Common readout operations include
summation, average, maximum or minimum over all node or edge features.

Given a graph :math:`g`, one can define the average node feature readout as

.. math:: h_g = \frac{1}{|\mathcal{V}|}\sum_{v\in \mathcal{V}}h_v

where :math:`h_g` is the representation of :math:`g`, :math:`\mathcal{V}` is
the set of nodes in :math:`g`, and :math:`h_v` is the feature of node :math:`v`.
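
To make the formula concrete, the following plain PyTorch sketch computes
:math:`h_g` for several graphs whose node features are stacked in one tensor,
graph by graph. The helper ``average_readout`` and its ``batch_num_nodes``
argument are hypothetical names for illustration. DGL's own readout
functions handle this bookkeeping internally.

.. code:: python

    import torch

    def average_readout(h, batch_num_nodes):
        """Average node features per graph (hypothetical helper).

        h: (N, D) node features of all graphs, stacked graph by graph.
        batch_num_nodes: (B,) number of nodes in each graph.
        Returns a (B, D) tensor whose b-th row is h_g of graph b.
        """
        num_graphs = batch_num_nodes.shape[0]
        # Graph index of every node, e.g. [0, 0, 1, 1, 1] for counts [2, 3].
        graph_ids = torch.repeat_interleave(torch.arange(num_graphs), batch_num_nodes)
        # Sum h_v over the nodes of each graph ...
        sums = torch.zeros(num_graphs, h.shape[1], dtype=h.dtype).index_add_(0, graph_ids, h)
        # ... then divide by |V| of that graph.
        return sums / batch_num_nodes.unsqueeze(1).to(h.dtype)

    h = torch.tensor([[1.], [2.], [1.], [2.], [3.]])
    counts = torch.tensor([2, 3])
    hg = average_readout(h, counts)  # [[1.5], [2.0]]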

DGL provides built-in support for common readout operations. For example,
:func:`dgl.readout_nodes` with ``op='mean'`` implements the above readout
operation, and :func:`dgl.mean_nodes` is a shortcut for it.
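
The sketch below applies :func:`dgl.readout_nodes` with several ``op`` values
to a batch of two small graphs with 2-dimensional node features. The feature
values are arbitrary. Each result has one row per graph in the batch.

.. code:: python

    import dgl
    import torch

    g1 = dgl.graph(([0, 1], [1, 0]))
    g1.ndata['h'] = torch.tensor([[1., 4.], [3., 8.]])
    g2 = dgl.graph(([0, 1], [1, 2]))
    g2.ndata['h'] = torch.tensor([[1., 0.], [2., 3.], [6., 3.]])
    bg = dgl.batch([g1, g2])

    # Each call returns a (batch_size, D) tensor, one row per graph.
    h_sum = dgl.readout_nodes(bg, 'h', op='sum')    # [[4., 12.], [9., 6.]]
    h_mean = dgl.readout_nodes(bg, 'h', op='mean')  # [[2., 6.], [3., 2.]]
    h_max = dgl.readout_nodes(bg, 'h', op='max')    # [[3., 8.], [6., 3.]]
    # dgl.mean_nodes is the shortcut for op='mean'.
    h_mean2 = dgl.mean_nodes(bg, 'h')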

Once :math:`h_g` is available, one can pass it through an MLP layer to obtain
the classification output.
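
As an illustration, a two-layer MLP head in plain PyTorch might look like the
following. The sizes are hypothetical, and a random tensor stands in for the
:math:`(B, D)` readout output. The model defined later in this section uses a
single ``nn.Linear`` instead.

.. code:: python

    import torch
    import torch.nn as nn

    batch_size, hidden_dim, n_classes = 4, 16, 3  # hypothetical sizes
    hg = torch.randn(batch_size, hidden_dim)  # stands in for a (B, D) readout

    # A two-layer MLP mapping each graph representation to class logits.
    mlp = nn.Sequential(
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, n_classes),
    )
    logits = mlp(hg)             # shape (B, n_classes)
    pred = logits.argmax(dim=1)  # predicted class of each graph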

Writing a Neural Network Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The input to the model is the batched graph with node and edge features.

Computation on a Batched Graph
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

First, different graphs in a batch are entirely separated, i.e., there are no
edges between any two graphs. Thanks to this property, all message passing
functions produce the same results on a batched graph as they would on each
graph individually.
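
A quick way to see this property is to compare message passing on a batched
graph with message passing on each graph separately. The sketch below uses
DGL's built-in ``fn.copy_u`` and ``fn.sum`` to sum neighbor features. The
helper ``sum_neighbors`` is a hypothetical name for illustration.

.. code:: python

    import dgl
    import dgl.function as fn
    import torch

    def sum_neighbors(g, h):
        # Each node receives the sum of its in-neighbors' features
        # (zero for nodes without incoming edges).
        with g.local_scope():
            g.ndata['h'] = h
            g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_new'))
            return g.ndata['h_new']

    g1 = dgl.graph(([0, 1], [1, 0]))
    g2 = dgl.graph(([0, 1], [1, 2]))
    h1 = torch.tensor([[1.], [2.]])
    h2 = torch.tensor([[1.], [2.], [3.]])

    bg = dgl.batch([g1, g2])
    out_batched = sum_neighbors(bg, torch.cat([h1, h2]))
    out_separate = torch.cat([sum_neighbors(g1, h1), sum_neighbors(g2, h2)])
    # Both are [[2.], [1.], [0.], [1.], [2.]]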

Second, the readout function on a batched graph is conducted over
each graph separately. Assuming the batch size is :math:`B` and the
feature to be aggregated has dimension :math:`D`, the shape of the
readout result will be :math:`(B, D)`.

.. code:: python

    import dgl
    import torch

    g1 = dgl.graph(([0, 1], [1, 0]))
    g1.ndata['h'] = torch.tensor([1., 2.])
    g2 = dgl.graph(([0, 1], [1, 2]))
    g2.ndata['h'] = torch.tensor([1., 2., 3.])
    
    dgl.readout_nodes(g1, 'h')
    # tensor([3.])  # 1 + 2
    
    bg = dgl.batch([g1, g2])
    dgl.readout_nodes(bg, 'h')
    # tensor([3., 6.])  # [1 + 2, 1 + 2 + 3]

Finally, each node/edge feature in a batched graph is obtained by
concatenating the corresponding features from all graphs in order.

.. code:: python

    bg.ndata['h']
    # tensor([1., 2., 1., 2., 3.])
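
Because of this ordering, the features of each graph can be recovered by
splitting along the first dimension with the per-graph node counts.
:func:`dgl.unbatch` performs this split and also restores each graph's
structure. A minimal sketch, rebuilding the two graphs from the readout
example above:

.. code:: python

    import dgl
    import torch

    g1 = dgl.graph(([0, 1], [1, 0]))
    g1.ndata['h'] = torch.tensor([1., 2.])
    g2 = dgl.graph(([0, 1], [1, 2]))
    g2.ndata['h'] = torch.tensor([1., 2., 3.])
    bg = dgl.batch([g1, g2])

    # Split the concatenated features by the number of nodes per graph.
    per_graph = torch.split(bg.ndata['h'], bg.batch_num_nodes().tolist())
    # dgl.unbatch returns the individual graphs with their features.
    ug1, ug2 = dgl.unbatch(bg)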

Model Definition
^^^^^^^^^^^^^^^^

With the above computation rules in mind, one can define a model as follows.

.. code:: python

    import dgl
    import dgl.nn.pytorch as dglnn
    import torch.nn as nn
    import torch.nn.functional as F

    class Classifier(nn.Module):
        def __init__(self, in_dim, hidden_dim, n_classes):
            super(Classifier, self).__init__()
            self.conv1 = dglnn.GraphConv(in_dim, hidden_dim)
            self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim)
            self.classify = nn.Linear(hidden_dim, n_classes)

        def forward(self, g, h):
            # Apply graph convolution and activation.
            h = F.relu(self.conv1(g, h))
            h = F.relu(self.conv2(g, h))
            with g.local_scope():
                g.ndata['h'] = h
                # Calculate graph representation by average readout.
                hg = dgl.mean_nodes(g, 'h')
                return self.classify(hg)

Training Loop
~~~~~~~~~~~~~

Data Loading
^^^^^^^^^^^^

Once the model is defined, one can start training. Since graph
classification deals with lots of relatively small graphs instead of a big
single one, one can train efficiently on stochastic mini-batches
of graphs, without the need to design sophisticated graph sampling
algorithms.

Assume that one has a graph classification dataset as introduced in
:ref:`guide-data-pipeline`.

.. code:: python

    import dgl.data
    dataset = dgl.data.GINDataset('MUTAG', self_loop=False)

Each item in the graph classification dataset is a pair of a graph and
its label. One can speed up the data loading process by using
:class:`~dgl.dataloading.GraphDataLoader` to iterate over the dataset of
graphs in mini-batches.

.. code:: python

    from dgl.dataloading import GraphDataLoader
    dataloader = GraphDataLoader(
        dataset,
        batch_size=1024,
        drop_last=False,
        shuffle=True)
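
To try the data loading step without downloading MUTAG, a plain Python list of
``(graph, label)`` pairs can stand in for the dataset. The sketch assumes that
each label is a scalar tensor. The toy dataset below is hypothetical, and
:func:`dgl.rand_graph` only supplies random structures.

.. code:: python

    import dgl
    import torch
    from dgl.dataloading import GraphDataLoader

    # Hypothetical toy dataset: a list of (graph, label) pairs where every
    # graph carries 7-dimensional node features under 'attr'.
    toy_dataset = []
    for i in range(5):
        g = dgl.rand_graph(4 + i, 8)
        g.ndata['attr'] = torch.randn(g.num_nodes(), 7)
        toy_dataset.append((g, torch.tensor(i % 2)))

    toy_loader = GraphDataLoader(toy_dataset, batch_size=2, shuffle=True, drop_last=False)

    batch_sizes = []
    for batched_graph, labels in toy_loader:
        # Each mini-batch arrives as one batched graph plus a 1-D label
        # tensor with one entry per graph.
        assert batched_graph.batch_size == labels.shape[0]
        batch_sizes.append(batched_graph.batch_size)

With 5 graphs and ``batch_size=2``, the last mini-batch holds a single graph
because ``drop_last=False``.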

The training loop then simply iterates over the dataloader and
updates the model.

.. code:: python

    import torch
    import torch.nn.functional as F

    # Only an example, 7 is the input feature size
    model = Classifier(7, 20, 5)
    opt = torch.optim.Adam(model.parameters())
    for epoch in range(20):
        for batched_graph, labels in dataloader:
            feats = batched_graph.ndata['attr']
            logits = model(batched_graph, feats)
            loss = F.cross_entropy(logits, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

For an end-to-end example of graph classification, see
`DGL's GIN example <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gin>`__.
The training loop is inside the function ``train`` in
`main.py <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gin/main.py>`__.
The model implementation is inside
`gin.py <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gin/gin.py>`__,
which includes more components, such as
:class:`dgl.nn.pytorch.GINConv` (also available in MXNet and TensorFlow)
as the graph convolution layer, batch normalization, etc.
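
As a rough sketch of swapping in GIN, the snippet below applies
:class:`dgl.nn.pytorch.GINConv` to a batch of random graphs, with a single
``nn.Linear`` as its ``apply_func``, and then sums the node outputs per graph.
The layer sizes are hypothetical. The linked example uses more components,
such as batch normalization.

.. code:: python

    import dgl
    import dgl.nn.pytorch as dglnn
    import torch
    import torch.nn as nn

    # Two random graphs batched together: 5 + 4 = 9 nodes in total.
    bg = dgl.batch([dgl.rand_graph(5, 10), dgl.rand_graph(4, 8)])
    feats = torch.randn(bg.num_nodes(), 7)

    # GINConv sums neighbor features, adds (1 + eps) times the node's own
    # feature, and passes the result through apply_func (one linear layer here).
    conv = dglnn.GINConv(nn.Linear(7, 16), 'sum')
    h = conv(bg, feats)  # shape (9, 16)

    with bg.local_scope():
        bg.ndata['h'] = h
        hg = dgl.sum_nodes(bg, 'h')  # shape (2, 16), one row per graph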

Heterogeneous graph
~~~~~~~~~~~~~~~~~~~

Graph classification with heterogeneous graphs is a little different
from that with homogeneous graphs. In addition to graph convolution modules
compatible with heterogeneous graphs, one also needs to aggregate over the nodes of
different types in the readout function.

The following example sums up the average node representation of each node
type to obtain the graph representation.

.. code:: python

    import dgl
    import dgl.nn.pytorch as dglnn
    import torch.nn as nn
    import torch.nn.functional as F

    class RGCN(nn.Module):
        def __init__(self, in_feats, hid_feats, out_feats, rel_names):
            super().__init__()
            self.conv1 = dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(in_feats, hid_feats)
                for rel in rel_names}, aggregate='sum')
            self.conv2 = dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(hid_feats, out_feats)
                for rel in rel_names}, aggregate='sum')

        def forward(self, graph, inputs):
            # inputs is a dict mapping each node type to its node features
            h = self.conv1(graph, inputs)
            h = {k: F.relu(v) for k, v in h.items()}
            h = self.conv2(graph, h)
            return h

    class HeteroClassifier(nn.Module):
        def __init__(self, in_dim, hidden_dim, n_classes, rel_names):
            super().__init__()
            self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)
            self.classify = nn.Linear(hidden_dim, n_classes)

        def forward(self, g):
            h = g.ndata['feat']
            h = self.rgcn(g, h)
            with g.local_scope():
                g.ndata['h'] = h
                # Calculate graph representation by average readout
                # per node type, then sum over node types.
                hg = 0
                for ntype in g.ntypes:
                    hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)
                return self.classify(hg)
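
The readout part of ``HeteroClassifier.forward`` can be checked in isolation.
The sketch below assumes a hypothetical schema with node types ``user`` and
``game`` and a single relation ``plays``. It sets a 1-dimensional feature
``'h'`` directly instead of computing it with ``RGCN``, then sums the per-type
averages over a batch of two heterographs.

.. code:: python

    import dgl
    import torch

    # Hypothetical schema: users play games.
    g1 = dgl.heterograph({
        ('user', 'plays', 'game'): (torch.tensor([0, 1]), torch.tensor([0, 0]))})
    g2 = dgl.heterograph({
        ('user', 'plays', 'game'): (torch.tensor([0, 1, 2]), torch.tensor([0, 1, 1]))})
    g1.nodes['user'].data['h'] = torch.tensor([[1.], [3.]])
    g1.nodes['game'].data['h'] = torch.tensor([[10.]])
    g2.nodes['user'].data['h'] = torch.tensor([[2.], [4.], [6.]])
    g2.nodes['game'].data['h'] = torch.tensor([[0.], [20.]])

    bg = dgl.batch([g1, g2])
    # Average over the nodes of each type, then sum across types.
    hg = 0
    for ntype in bg.ntypes:
        hg = hg + dgl.mean_nodes(bg, 'h', ntype=ntype)
    # g1: 2 + 10 = 12, g2: 4 + 10 = 14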

The rest of the code is the same as for homogeneous graphs.

.. code:: python

    # etypes is the list of edge types as strings.
    model = HeteroClassifier(10, 20, 5, etypes)
    opt = torch.optim.Adam(model.parameters())
    for epoch in range(20):
        for batched_graph, labels in dataloader:
            logits = model(batched_graph)
            loss = F.cross_entropy(logits, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()