"git@developer.sourcefind.cn:change/sglang.git" did not exist on "2a02185c5f9be353fb493fc3548552ec5a5aafad"
5_graph_classification.py 7.46 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
Training a GNN for Graph Classification
=======================================

By the end of this tutorial, you will be able to

-  Load a DGL-provided graph classification dataset.
-  Understand what *readout* function does.
-  Understand how to create and use a minibatch of graphs.
-  Build a GNN-based graph classification model.
-  Train and evaluate the model on a DGL-provided dataset.

(Time estimate: 18 minutes)
"""

16
import os
17

18
os.environ["DGLBACKEND"] = "pytorch"
19
20
import dgl
import dgl.data
21
22
23
import torch
import torch.nn as nn
import torch.nn.functional as F
24
25
26
27

######################################################################
# Overview of Graph Classification with GNN
# -----------------------------------------
#
# Graph classification or regression requires a model to predict certain
# graph-level properties of a single graph given its node and edge
# features.  Molecular property prediction is one particular application.
#
# This tutorial shows how to train a graph classification model for a
# small dataset from the paper `How Powerful Are Graph Neural
# Networks <https://arxiv.org/abs/1810.00826>`__.
#
# Loading Data
# ------------
#

# Generate a synthetic dataset with 10000 graphs, ranging from 10 to 500 nodes.
43
dataset = dgl.data.GINDataset("PROTEINS", self_loop=True)
44
45
46
47
48
49
50


######################################################################
# The dataset is a set of graphs, each with node features and a single
# label. One can see the node feature dimensionality and the number of
# possible graph categories of ``GINDataset`` objects in ``dim_nfeats``
# and ``gclasses`` attributes.
#
print("Node feature dimensionality:", dataset.dim_nfeats)
print("Number of graph categories:", dataset.gclasses)
55
56


57
58
from dgl.dataloading import GraphDataLoader

59
60
61
######################################################################
# Defining Data Loader
# --------------------
#
# A graph classification dataset usually contains two types of elements: a
# set of graphs, and their graph-level labels. Similar to an image
# classification task, when the dataset is large enough, we need to train
# with mini-batches. When you train a model for image classification or
# language modeling, you will use a ``DataLoader`` to iterate over the
# dataset. In DGL, you can use the ``GraphDataLoader``.
#
# You can also use various dataset samplers provided in
# `torch.utils.data.sampler <https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler>`__.
#
# For example, this tutorial creates a training ``GraphDataLoader`` and
# test ``GraphDataLoader``, using ``SubsetRandomSampler`` to tell PyTorch
# to sample from only a subset of the dataset.
#

from torch.utils.data.sampler import SubsetRandomSampler

num_examples = len(dataset)
num_train = int(num_examples * 0.8)

train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

train_dataloader = GraphDataLoader(
86
87
    dataset, sampler=train_sampler, batch_size=5, drop_last=False
)
88
test_dataloader = GraphDataLoader(
89
90
    dataset, sampler=test_sampler, batch_size=5, drop_last=False
)
91
92
93
94
95


######################################################################
# You can try to iterate over the created ``GraphDataLoader`` and see what it
# gives:
#

it = iter(train_dataloader)
batch = next(it)
print(batch)


######################################################################
# As each element in ``dataset`` has a graph and a label, the
# ``GraphDataLoader`` will return two objects for each iteration. The
# first element is the batched graph, and the second element is simply a
# label vector representing the category of each graph in the mini-batch.
# Next, we'll talk about the batched graph.
#
# A Batched Graph in DGL
# ----------------------
#
# In each mini-batch, the sampled graphs are combined into a single bigger
# batched graph via ``dgl.batch``. The single bigger batched graph merges
# all original graphs as separately connected components, with the node
# and edge features concatenated. This bigger graph is also a ``DGLGraph``
# instance (so you can
# still treat it as a normal ``DGLGraph`` object as in
# `here <2_dglgraph.ipynb>`__). It however contains the information
# necessary for recovering the original graphs, such as the number of
# nodes and edges of each graph element.
#

batched_graph, labels = batch
125
126
127
128
129
130
131
132
print(
    "Number of nodes for each graph element in the batch:",
    batched_graph.batch_num_nodes(),
)
print(
    "Number of edges for each graph element in the batch:",
    batched_graph.batch_num_edges(),
)
133
134
135

# Recover the original graph elements from the minibatch
graphs = dgl.unbatch(batched_graph)
136
print("The original graphs in the minibatch:")
137
138
139
140
141
142
print(graphs)


######################################################################
# Define Model
# ------------
#
# This tutorial will build a two-layer `Graph Convolutional Network
# (GCN) <http://tkipf.github.io/graph-convolutional-networks/>`__. Each of
# its layers computes new node representations by aggregating neighbor
# information. If you have gone through the
# :doc:`introduction <1_introduction>`, you will notice two
# differences:
#
# -  Since the task is to predict a single category for the *entire graph*
#    instead of for every node, you will need to aggregate the
#    representations of all the nodes and potentially the edges to form a
#    graph-level representation. Such a process is more commonly referred
#    to as a *readout*. A simple choice is to average the node features of
#    a graph with ``dgl.mean_nodes()``.
#
# -  The input graph to the model will be a batched graph yielded by the
#    ``GraphDataLoader``. The readout functions provided by DGL can handle
#    batched graphs so that they will return one representation for each
#    minibatch element.
#

from dgl.nn import GraphConv

166

167
168
169
170
171
class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)
172

173
174
175
176
    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
177
178
        g.ndata["h"] = h
        return dgl.mean_nodes(g, "h")


######################################################################
# Training Loop
# -------------
#
# The training loop iterates over the training set with the
# ``GraphDataLoader`` object and computes the gradients, just like
# image classification or language modeling.
#

# Create the model with given dimensions
model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(20):
    for batched_graph, labels in train_dataloader:
196
        pred = model(batched_graph, batched_graph.ndata["attr"].float())
197
198
199
200
201
202
203
204
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

num_correct = 0
num_tests = 0
for batched_graph, labels in test_dataloader:
205
    pred = model(batched_graph, batched_graph.ndata["attr"].float())
206
207
208
    num_correct += (pred.argmax(1) == labels).sum().item()
    num_tests += len(labels)

209
print("Test accuracy:", num_correct / num_tests)


######################################################################
# What's next
# -----------
#
# -  See `GIN
#    example <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gin>`__
#    for an end-to-end graph classification model.
#

# Thumbnail credits: DGL
# sphinx_gallery_thumbnail_path = '_static/blitz_5_graph_classification.png'