"""
Single Machine Multi-GPU Minibatch Graph Classification
=======================================================

In this tutorial, you will learn how to use multiple GPUs in training a
graph neural network (GNN) for graph classification. This tutorial assumes
knowledge of GNNs for graph classification; if you are not familiar with the
topic, we recommend working through
:doc:`Training a GNN for Graph Classification <../blitz/5_graph_classification>` first.

(Time estimate: 8 minutes)

To use a single GPU in training a GNN, we need to put the model, graph(s), and other
tensors (e.g. labels) on the same GPU:

.. code:: python

    import torch

    # Use the first GPU
    device = torch.device("cuda:0")
    model = model.to(device)
    graph = graph.to(device)
    labels = labels.to(device)

The node and edge features in the graphs, if any, will also be on the GPU.
After that, the forward computation, backward computation and parameter
update will take place on the GPU. For graph classification, this repeats
for each minibatch gradient descent step.
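
For reference, a minimal single-GPU minibatch training loop may look like the
sketch below; the names ``dataloader``, ``model``, ``criterion`` and
``optimizer`` are assumed to be defined elsewhere, and the node feature key
``"attr"`` is only an example.

.. code:: python

    for batched_graph, labels in dataloader:
        # Move the minibatch to the same GPU as the model
        batched_graph = batched_graph.to(device)
        labels = labels.to(device)
        feats = batched_graph.ndata["attr"]

        pred = model(batched_graph, feats)
        loss = criterion(pred, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()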

Using multiple GPUs allows performing more computation per unit of time. It
is like having a team work together, where each GPU is a team member. We need
to distribute the computation workload across GPUs and let them synchronize
the efforts regularly. PyTorch provides convenient APIs for this task with
multiple processes, one per GPU, and we can use them in conjunction with DGL.

Intuitively, we can distribute the workload along the dimension of data. This
allows multiple GPUs to perform the forward and backward computation of
multiple gradient descents in parallel. To distribute a dataset across
multiple GPUs, we need to partition it into multiple mutually exclusive
subsets of a similar size, one per GPU. We need to repeat the random
partition every epoch to guarantee randomness. We can use
:func:`~dgl.dataloading.pytorch.GraphDataLoader`, which wraps some PyTorch
APIs and handles data loading for graph classification.
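
For example, constructing the loader for the training subset with
``use_ddp=True`` gives each process its own partition of the training set.
A minimal sketch (``train_set`` stands for the training subset prepared later
in this tutorial, and ``num_epochs`` is an assumed name):

.. code:: python

    from dgl.dataloading import GraphDataLoader

    train_loader = GraphDataLoader(
        train_set, use_ddp=True, batch_size=32, shuffle=True
    )

    for epoch in range(num_epochs):
        # Re-shuffle the per-process partition at the start of every epoch
        train_loader.set_epoch(epoch)
        ...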

Once all GPUs have finished the backward computation for their own minibatches,
we need to synchronize the model parameter update across them. Specifically,
this involves collecting gradients from all GPUs, averaging them and updating
the model parameters on each GPU. We can wrap a PyTorch model with
:func:`~torch.nn.parallel.DistributedDataParallel` so that the model
parameter update will invoke gradient synchronization first under the hood.
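
Wrapping the model is a one-line change; a minimal sketch, assuming ``model``
has already been moved to the GPU assigned to the current process:

.. code:: python

    from torch.nn.parallel import DistributedDataParallel

    # Gradients are averaged across processes during backward()
    model = DistributedDataParallel(
        model, device_ids=[device], output_device=device
    )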

.. image:: https://data.dgl.ai/tutorial/mgpu_gc.png
  :width: 450px
  :align: center

That’s the core idea behind this tutorial. We will explore it in more detail
with a complete example below.

.. note::

   See `this tutorial <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__
   from PyTorch for general multi-GPU training with ``DistributedDataParallel``.

Distributed Process Group Initialization
----------------------------------------

For communication between multiple processes in multi-GPU training, we need
to initialize the distributed backend at the beginning of each process. We use
`world_size` to refer to the number of processes and `rank` to refer to the
process ID, which should be an integer from `0` to `world_size - 1`.
"""

import torch.distributed as dist


def init_process_group(world_size, rank):
    dist.init_process_group(
        backend="gloo",  # change to 'nccl' for multiple GPUs
        init_method="tcp://127.0.0.1:12345",
        world_size=world_size,
        rank=rank,
    )


###############################################################################
# Data Loader Preparation
# -----------------------
#
# We split the dataset into training, validation and test subsets. In dataset
# splitting, we need to use the same random seed across processes to ensure
# the same split. We follow the common practice of training with multiple GPUs
# and evaluating with a single GPU, thus only setting `use_ddp` to True in the
# :func:`~dgl.dataloading.pytorch.GraphDataLoader` for the training set, where
# `ddp` stands for :func:`~torch.nn.parallel.DistributedDataParallel`.
#

from dgl.data import split_dataset
from dgl.dataloading import GraphDataLoader


def get_dataloaders(dataset, seed, batch_size=32):
    # Use an 80:10:10 train-val-test split
    train_set, val_set, test_set = split_dataset(
        dataset, frac_list=[0.8, 0.1, 0.1], shuffle=True, random_state=seed
    )
    train_loader = GraphDataLoader(
        train_set, use_ddp=True, batch_size=batch_size, shuffle=True
    )
    val_loader = GraphDataLoader(val_set, batch_size=batch_size)
    test_loader = GraphDataLoader(test_set, batch_size=batch_size)

    return train_loader, val_loader, test_loader


###############################################################################
# Model Initialization
# --------------------
#
# For this tutorial, we use a simplified Graph Isomorphism Network (GIN).
#

import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import GINConv, SumPooling


class GIN(nn.Module):
    def __init__(self, input_size=1, num_classes=2):
        super(GIN, self).__init__()

        self.conv1 = GINConv(
            nn.Linear(input_size, num_classes), aggregator_type="sum"
        )
        self.conv2 = GINConv(
            nn.Linear(num_classes, num_classes), aggregator_type="sum"
        )
        self.pool = SumPooling()

    def forward(self, g, feats):
        feats = self.conv1(g, feats)
        feats = F.relu(feats)
        feats = self.conv2(g, feats)

        return self.pool(g, feats)


###############################################################################
# To ensure the same initial model parameters across processes, we need to set
# the same random seed before model initialization. Once we construct a model
# instance, we wrap it with :func:`~torch.nn.parallel.DistributedDataParallel`.
#

import torch
from torch.nn.parallel import DistributedDataParallel


def init_model(seed, device):
    torch.manual_seed(seed)
    model = GIN().to(device)
    if device.type == "cpu":
        model = DistributedDataParallel(model)
    else:
        model = DistributedDataParallel(
            model, device_ids=[device], output_device=device
        )

    return model


###############################################################################
# Main Function for Each Process
# ------------------------------
#
# Define the model evaluation function as in the single-GPU setting.
#


def evaluate(model, dataloader, device):
    model.eval()

    total = 0
    total_correct = 0

    for bg, labels in dataloader:
        bg = bg.to(device)
        labels = labels.to(device)
        # Get input node features
        feats = bg.ndata.pop("attr")
        with torch.no_grad():
            pred = model(bg, feats)
        _, pred = torch.max(pred, 1)
        total += len(labels)
        total_correct += (pred == labels).sum().cpu().item()

    return 1.0 * total_correct / total


###############################################################################
# Define the main function for each process.
#

from torch.optim import Adam


def main(rank, world_size, dataset, seed=0):
    init_process_group(world_size, rank)
    if torch.cuda.is_available():
        device = torch.device("cuda:{:d}".format(rank))
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    model = init_model(seed, device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.01)

    train_loader, val_loader, test_loader = get_dataloaders(dataset, seed)
    for epoch in range(5):
        model.train()
        # The line below ensures all processes use a different
        # random ordering in data loading for each epoch.
        train_loader.set_epoch(epoch)

        total_loss = 0
        for bg, labels in train_loader:
            bg = bg.to(device)
            labels = labels.to(device)
            feats = bg.ndata.pop("attr")
            pred = model(bg, feats)

            loss = criterion(pred, labels)
            total_loss += loss.cpu().item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print("Loss: {:.4f}".format(total_loss))

        val_acc = evaluate(model, val_loader, device)
        print("Val acc: {:.4f}".format(val_acc))

    test_acc = evaluate(model, test_loader, device)
    print("Test acc: {:.4f}".format(test_acc))
    dist.destroy_process_group()


###############################################################################
# Finally, we load the dataset and launch the processes.
#
# .. code:: python
#
#    if __name__ == '__main__':
#        import torch.multiprocessing as mp
#
#        from dgl.data import GINDataset
#
#        num_gpus = 4
#        dataset = GINDataset(name='IMDBBINARY', self_loop=False)
#        mp.spawn(main, args=(num_gpus, dataset), nprocs=num_gpus)

# Thumbnail credits: DGL
# sphinx_gallery_thumbnail_path = '_static/blitz_5_graph_classification.png'