.. _guide-training-graph-classification:

5.4 Graph Classification
----------------------------------

:ref:`(Chinese version) <guide_cn-training-graph-classification>`

Instead of a single big graph, sometimes one might have the data in the
form of multiple graphs, for example a list of different types of
communities of people. By characterizing the friendship among people in
the same community with a graph, one can get a list of graphs to classify.
In this scenario, a graph classification model could help identify the
type of the community, i.e. to classify each graph based on its structure
and overall information.

Overview
~~~~~~~~

The major difference between graph classification and node
classification or link prediction is that the prediction result
characterizes a property of the entire input graph. One can perform
message passing over nodes/edges just as in the previous tasks, but also
needs to retrieve a graph-level representation.

The graph classification pipeline proceeds as follows:

.. figure:: https://data.dgl.ai/tutorial/batch/graph_classifier.png
   :alt: Graph Classification Process

   Graph Classification Process

From left to right, the common practice is:

-  Prepare a batch of graphs
-  Perform message passing on the batched graphs to update node/edge features
-  Aggregate node/edge features into graph-level representations
-  Classify graphs based on graph-level representations

Batch of Graphs
^^^^^^^^^^^^^^^

Usually a graph classification task trains on a large number of graphs,
and it would be very inefficient to use only one graph at a time when
training the model. Borrowing the idea of mini-batch training from
common deep learning practice, one can build a batch of multiple graphs
and send them together for one training iteration.

In DGL, one can build a single batched graph from a list of graphs. This
batched graph can simply be used as a single large graph, whose connected
components correspond to the original small graphs.

.. figure:: https://data.dgl.ai/tutorial/batch/batch.png
   :alt: Batched Graph

   Batched Graph

The following example calls :func:`dgl.batch` on a list of graphs. A
batched graph is a single graph that also carries information about the
original list.

.. code:: python

    import dgl
    import torch as th

    g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
    g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))

    bg = dgl.batch([g1, g2])
    bg
    # Graph(num_nodes=7, num_edges=7,
    #       ndata_schemes={}
    #       edata_schemes={})
    bg.batch_size
    # 2
    bg.batch_num_nodes()
    # tensor([4, 3])
    bg.batch_num_edges()
    # tensor([3, 4])
    bg.edges()
    # (tensor([0, 1, 2, 4, 4, 4, 5]), tensor([1, 2, 3, 4, 5, 6, 4]))
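
Conversely, one can split a batched graph back into the original list of
graphs with :func:`dgl.unbatch` (shown here as a quick check):

.. code:: python

    g1_, g2_ = dgl.unbatch(bg)
    g1_.num_nodes(), g2_.num_nodes()
    # (4, 3)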

Graph Readout
^^^^^^^^^^^^^

Every graph in the data may have its own unique structure, as well as its
node and edge features. In order to make a single prediction, one usually
aggregates and summarizes over the possibly abundant information. This
type of operation is named *readout*. Common readout operations include
taking the summation, average, maximum or minimum over all node or edge
features.

Given a graph :math:`g`, one can define the average node feature readout as

.. math:: h_g = \frac{1}{|\mathcal{V}|}\sum_{v\in \mathcal{V}}h_v

where :math:`h_g` is the representation of :math:`g`, :math:`\mathcal{V}` is
the set of nodes in :math:`g`, and :math:`h_v` is the feature of node
:math:`v`.

DGL provides built-in support for common readout operations. For example,
:func:`dgl.readout_nodes` implements the above readout operation.

Once :math:`h_g` is available, one can pass it through an MLP layer for
classification output.
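
For instance, the average readout above corresponds to the following call
(a minimal sketch; ``op='mean'`` selects mean aggregation in
:func:`dgl.readout_nodes`):

.. code:: python

    import dgl
    import torch

    g = dgl.graph(([0, 1], [1, 2]))            # a small graph with 3 nodes
    g.ndata['h'] = torch.tensor([1., 2., 3.])
    dgl.readout_nodes(g, 'h', op='mean')
    # tensor([2.])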

Writing Neural Network Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The input to the model is the batched graph with node and edge features.

Computation on a Batched Graph
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

First, different graphs in a batch are entirely separated, i.e. there are
no edges between any two graphs. With this nice property, all message
passing functions still produce the same results as they would on the
individual graphs.

Second, the readout function on a batched graph will be conducted over
each graph separately. Assuming the batch size is :math:`B` and the
feature to be aggregated has dimension :math:`D`, the shape of the
readout result will be :math:`(B, D)`.

.. code:: python

    import dgl
    import torch

    g1 = dgl.graph(([0, 1], [1, 0]))
    g1.ndata['h'] = torch.tensor([1., 2.])
    g2 = dgl.graph(([0, 1], [1, 2]))
    g2.ndata['h'] = torch.tensor([1., 2., 3.])
    
    dgl.readout_nodes(g1, 'h')
    # tensor([3.])  # 1 + 2
    
    bg = dgl.batch([g1, g2])
    dgl.readout_nodes(bg, 'h')
    # tensor([3., 6.])  # [1 + 2, 1 + 2 + 3]
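
With vector-valued features of dimension :math:`D`, the readout indeed
returns a :math:`(B, D)` tensor (a quick check reusing the graphs above):

.. code:: python

    g1.ndata['x'] = torch.randn(2, 5)   # g1 has 2 nodes
    g2.ndata['x'] = torch.randn(3, 5)   # g2 has 3 nodes
    bg = dgl.batch([g1, g2])
    dgl.readout_nodes(bg, 'x').shape
    # torch.Size([2, 5])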

Finally, each node/edge feature in a batched graph is obtained by
concatenating the corresponding features from all graphs in order.

.. code:: python

    bg.ndata['h']
    # tensor([1., 2., 1., 2., 3.])

Model Definition
^^^^^^^^^^^^^^^^

Being aware of the above computation rules, one can define a model as
follows.

.. code:: python

    import dgl
    import dgl.nn.pytorch as dglnn
    import torch.nn as nn
    import torch.nn.functional as F

    class Classifier(nn.Module):
        def __init__(self, in_dim, hidden_dim, n_classes):
            super(Classifier, self).__init__()
            self.conv1 = dglnn.GraphConv(in_dim, hidden_dim)
            self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim)
            self.classify = nn.Linear(hidden_dim, n_classes)
    
        def forward(self, g, h):
            # Apply graph convolution and activation.
            h = F.relu(self.conv1(g, h))
            h = F.relu(self.conv2(g, h))
            with g.local_scope():
                g.ndata['h'] = h
                # Calculate graph representation by average readout.
                hg = dgl.mean_nodes(g, 'h')
                return self.classify(hg)
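
As a quick smoke test, one can run a forward pass on a small batch (a
minimal sketch; the feature size and random features are arbitrary, and
the graphs are made bidirected so that no node has zero in-degree):

.. code:: python

    import torch

    g1 = dgl.graph(([0, 1], [1, 0]))
    g2 = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1]))
    bg = dgl.batch([g1, g2])
    feats = torch.randn(bg.num_nodes(), 10)

    model = Classifier(in_dim=10, hidden_dim=20, n_classes=5)
    logits = model(bg, feats)
    logits.shape
    # torch.Size([2, 5]), one row of class scores per graph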

Training Loop
~~~~~~~~~~~~~

Data Loading
^^^^^^^^^^^^

Once the model is defined, one can start training. Since graph
classification deals with lots of relatively small graphs instead of a
single big one, one can train efficiently on stochastic mini-batches of
graphs, without the need to design sophisticated graph sampling
algorithms.

Assume that one has a graph classification dataset as introduced in
:ref:`guide-data-pipeline`.

.. code:: python

    import dgl.data
    # The second argument specifies whether to add self loops.
    dataset = dgl.data.GINDataset('MUTAG', False)
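
One can peek at a sample to see what the dataset provides (a sketch;
``GINDataset`` stores node features under ``ndata['attr']``):

.. code:: python

    graph, label = dataset[0]
    graph.ndata['attr'].shape   # node features, one row per node
    label                       # the graph's class label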

Each item in the graph classification dataset is a pair of a graph and
its label. One can speed up the data loading process by taking advantage
of the DataLoader, customizing the collate function to batch the
graphs:

.. code:: python

    def collate(samples):
        # samples is a list of (graph, label) pairs
        graphs, labels = map(list, zip(*samples))
        batched_graph = dgl.batch(graphs)
        batched_labels = torch.tensor(labels)
        return batched_graph, batched_labels

Then one can create a DataLoader that iterates over the dataset of
graphs in mini-batches.

.. code:: python

    from torch.utils.data import DataLoader
    dataloader = DataLoader(
        dataset,
        batch_size=1024,
        collate_fn=collate,
        drop_last=False,
        shuffle=True)
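
As a quick check, one can fetch a single mini-batch from the dataloader
(a sketch; the exact shapes depend on the dataset and batch size):

.. code:: python

    batched_graph, labels = next(iter(dataloader))
    print(batched_graph.batch_size, labels.shape)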

The training loop then simply involves iterating over the dataloader and
updating the model.

.. code:: python

    import torch
    import torch.nn.functional as F

    # Only an example, 7 is the input feature size
    model = Classifier(7, 20, 5)
    opt = torch.optim.Adam(model.parameters())
    for epoch in range(20):
        for batched_graph, labels in dataloader:
            feats = batched_graph.ndata['attr'].float()
            logits = model(batched_graph, feats)
            loss = F.cross_entropy(logits, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

For an end-to-end example of graph classification, see
`DGL's GIN example <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gin>`__.
The training loop is inside the function ``train`` in
`main.py <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gin/main.py>`__.
The model implementation is inside
`gin.py <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gin/gin.py>`__
with more components such as using
:class:`dgl.nn.pytorch.GINConv` (also available in MXNet and TensorFlow)
as the graph convolution layer, batch normalization, etc.

Heterogeneous Graph
~~~~~~~~~~~~~~~~~~~

Graph classification with heterogeneous graphs is a little different
from that with homogeneous graphs. In addition to graph convolution modules
compatible with heterogeneous graphs, one also needs to aggregate over the
nodes of different types in the readout function.

The following example computes the average of the node representations
for each node type and then sums them up.

.. code:: python

    class RGCN(nn.Module):
        def __init__(self, in_feats, hid_feats, out_feats, rel_names):
            super().__init__()
    
            self.conv1 = dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(in_feats, hid_feats)
                for rel in rel_names}, aggregate='sum')
            self.conv2 = dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(hid_feats, out_feats)
                for rel in rel_names}, aggregate='sum')
    
        def forward(self, graph, inputs):
            # inputs is a dict of node features keyed by node type
            h = self.conv1(graph, inputs)
            h = {k: F.relu(v) for k, v in h.items()}
            h = self.conv2(graph, h)
            return h
    
    class HeteroClassifier(nn.Module):
        def __init__(self, in_dim, hidden_dim, n_classes, rel_names):
            super().__init__()

            self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)
            self.classify = nn.Linear(hidden_dim, n_classes)
    
        def forward(self, g):
            h = g.ndata['feat']
            h = self.rgcn(g, h)
            with g.local_scope():
                g.ndata['h'] = h
                # Calculate graph representation by average readout.
                hg = 0
                for ntype in g.ntypes:
                    hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)
                return self.classify(hg)

The rest of the code is not different from that for homogeneous graphs.

.. code:: python

    # etypes is the list of edge types as strings.
    model = HeteroClassifier(10, 20, 5, etypes)
    opt = torch.optim.Adam(model.parameters())
    for epoch in range(20):
        for batched_graph, labels in dataloader:
            logits = model(batched_graph)
            loss = F.cross_entropy(logits, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()