"src/vscode:/vscode.git/clone" did not exist on "417927f554e748c99fa9f7d6d637934ac331ea40"
2_node_classification.py 9.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
"""
Single Machine Multi-GPU Minibatch Node Classification
======================================================

In this tutorial, you will learn how to use multiple GPUs in training a
graph neural network (GNN) for node classification.

(Time estimate: 8 minutes)

This tutorial assumes that you have read the :doc:`Training GNN with Neighbor
Sampling for Node Classification <../large/L1_large_node_classification>`
tutorial. It also assumes that you know the basics of training general
models on multiple GPUs with ``DistributedDataParallel``.

.. note::

   See `this tutorial <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__
   from PyTorch for general multi-GPU training with ``DistributedDataParallel``.  Also,
   see the first section of :doc:`the multi-GPU graph classification
   tutorial <1_graph_classification>`
   for an overview of using ``DistributedDataParallel`` with DGL.

"""


######################################################################
# Loading Dataset
# ---------------
#
# OGB already prepared the data as a ``DGLGraph`` object. The following code is
# copy-pasted from the :doc:`Training GNN with Neighbor Sampling for Node
# Classification <../large/L1_large_node_classification>`
# tutorial.
#
import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import numpy as np
import sklearn.metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from dgl.nn import SAGEConv
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset("ogbn-arxiv")

graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
graph.ndata["label"] = node_labels[:, 0]

node_features = graph.ndata["feat"]
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()

idx_split = dataset.get_idx_split()
train_nids = idx_split["train"]
valid_nids = idx_split["valid"]
test_nids = idx_split["test"]  # Test node IDs, not used in the tutorial though.


######################################################################
# Defining Model
# --------------
#
# The model will be again identical to the :doc:`Training GNN with Neighbor
# Sampling for Node Classification <../large/L1_large_node_classification>`
# tutorial.
#


class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type="mean")
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        h_dst = x[: mfgs[0].num_dst_nodes()]
        h = self.conv1(mfgs[0], (x, h_dst))
        h = F.relu(h)
        h_dst = h[: mfgs[1].num_dst_nodes()]
        h = self.conv2(mfgs[1], (h, h_dst))
        return h


######################################################################
# Defining Training Procedure
# ---------------------------
#
# The training procedure will be slightly different from what you saw
# previously, in that you will need to
#
# * Initialize a distributed training context with ``torch.distributed``.
# * Wrap your model with ``torch.nn.parallel.DistributedDataParallel``.
# * Add a ``use_ddp=True`` argument to the DGL dataloader you wish to run
#   together with DDP.
#
# You will also need to wrap the training loop inside a function so that
# you can spawn subprocesses to run it.
#


def run(proc_id, devices):
    # Initialize distributed training context.
    dev_id = devices[proc_id]
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
        master_ip="127.0.0.1", master_port="12345"
    )
    if torch.cuda.device_count() < 1:
        device = torch.device("cpu")
        torch.distributed.init_process_group(
            backend="gloo",
            init_method=dist_init_method,
            world_size=len(devices),
            rank=proc_id,
        )
    else:
        torch.cuda.set_device(dev_id)
        device = torch.device("cuda:" + str(dev_id))
        torch.distributed.init_process_group(
            backend="nccl",
            init_method=dist_init_method,
            world_size=len(devices),
            rank=proc_id,
        )

    # Define training and validation dataloader, copied from the previous tutorial
    # but with one line of difference: use_ddp to enable distributed data parallel
    # data loading.
    sampler = dgl.dataloading.NeighborSampler([4, 4])
    train_dataloader = dgl.dataloading.DataLoader(
        # The following arguments are specific to DataLoader.
        graph,  # The graph
        train_nids,  # The node IDs to iterate over in minibatches
        sampler,  # The neighbor sampler
        device=device,  # Put the sampled MFGs on CPU or GPU
        use_ddp=True,  # Make it work with distributed data parallel
        # The following arguments are inherited from PyTorch DataLoader.
        batch_size=1024,  # Per-device batch size.
        # The effective batch size is this number times the number of GPUs.
        shuffle=True,  # Whether to shuffle the nodes for every epoch
        drop_last=False,  # Whether to drop the last incomplete batch
        num_workers=0,  # Number of sampler processes
    )
    valid_dataloader = dgl.dataloading.DataLoader(
        graph,
        valid_nids,
        sampler,
        device=device,
        use_ddp=False,
        batch_size=1024,
        shuffle=False,
        drop_last=False,
        num_workers=0,
    )

    model = Model(num_features, 128, num_classes).to(device)
    # Wrap the model with distributed data parallel module.
    if device == torch.device("cpu"):
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=None, output_device=None
        )
    else:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[device], output_device=device
        )

    # Define optimizer
    opt = torch.optim.Adam(model.parameters())

    best_accuracy = 0
    best_model_path = "./model.pt"

    # Copied from previous tutorial with changes highlighted.
    for epoch in range(10):
        model.train()

        with tqdm.tqdm(train_dataloader) as tq:
            for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
                # feature copy from CPU to GPU takes place here
                inputs = mfgs[0].srcdata["feat"]
                labels = mfgs[-1].dstdata["label"]

                predictions = model(mfgs, inputs)

                loss = F.cross_entropy(predictions, labels)
                opt.zero_grad()
                loss.backward()
                opt.step()

                accuracy = sklearn.metrics.accuracy_score(
                    labels.cpu().numpy(),
                    predictions.argmax(1).detach().cpu().numpy(),
                )

                tq.set_postfix(
                    {"loss": "%.03f" % loss.item(), "acc": "%.03f" % accuracy},
                    refresh=False,
                )

        model.eval()

        # Evaluate on only the first GPU.
        if proc_id == 0:
            predictions = []
            labels = []
            with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
                for input_nodes, output_nodes, mfgs in tq:
                    inputs = mfgs[0].srcdata["feat"]
                    labels.append(mfgs[-1].dstdata["label"].cpu().numpy())
                    predictions.append(
                        model(mfgs, inputs).argmax(1).cpu().numpy()
                    )
                predictions = np.concatenate(predictions)
                labels = np.concatenate(labels)
                accuracy = sklearn.metrics.accuracy_score(labels, predictions)
                print("Epoch {} Validation Accuracy {}".format(epoch, accuracy))
                if best_accuracy < accuracy:
                    best_accuracy = accuracy
                    torch.save(model.state_dict(), best_model_path)

        # Note that this tutorial does not train the whole model to the end.
        break


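######################################################################
# In the training loop above, only the process with ``proc_id == 0`` runs
# validation.  If you would rather have every process take part, one common
# pattern is to count correct predictions locally on each rank and sum the
# counts with ``torch.distributed.all_reduce`` before computing the accuracy.
# The snippet below is only a rough sketch, not part of the original
# tutorial: ``num_correct`` and ``num_total`` stand for hypothetical local
# counts, and it assumes each rank iterates over its own shard of the
# validation set (for example, a validation dataloader built with
# ``use_ddp=True``).
#
# .. code:: python
#
#    # Pack the local statistics into a tensor on the backend's device
#    # (CPU for gloo, the current GPU for NCCL).
#    stats = torch.tensor(
#        [num_correct, num_total], dtype=torch.float64, device=device
#    )
#    # Sum the statistics over all ranks in place.
#    torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.SUM)
#    if proc_id == 0:
#        print("Validation accuracy", (stats[0] / stats[1]).item())
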
######################################################################
# Spawning Trainer Processes
# --------------------------
#
# A typical scenario for multi-GPU training with DDP is to replicate the
# model once per GPU, and spawn one trainer process per GPU.
#
# Normally, DGL maintains only one sparse matrix representation (usually COO)
# for each graph, and will create new formats when some APIs are called for
# efficiency.  For instance, calling ``in_degrees`` will create a CSC
# representation for the graph, and calling ``out_degrees`` will create a
# CSR representation.  A consequence is that if a graph is shared to
# trainer processes via copy-on-write *before* having its CSC/CSR
# created, each trainer will create its own CSC/CSR replica once ``in_degrees``
# or ``out_degrees`` is called.  To avoid this, you need to create
# all sparse matrix representations beforehand using the ``create_formats_``
# method:
#

graph.create_formats_()
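

######################################################################
# If you are curious which sparse formats the graph now has, you can inspect
# them with ``DGLGraph.formats()``.  This is just a quick sanity check, not
# part of the original tutorial, and the exact output depends on your DGL
# version:
#
# .. code:: python
#
#    print(graph.formats())
#    # e.g. {'created': ['coo', 'csc', 'csr'], 'not created': []}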


######################################################################
# Then you can spawn the subprocesses to train with multiple GPUs.
#
#
# .. code:: python
#
#    # Say you have four GPUs.
#    if __name__ == '__main__':
#        num_gpus = 4
#        import torch.multiprocessing as mp
#        mp.spawn(run, args=(list(range(num_gpus)),), nprocs=num_gpus)
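#
# If you prefer to pick up however many GPUs are visible at runtime instead
# of hard-coding four, a small variant (a sketch, not part of the original
# tutorial) is:
#
# .. code:: python
#
#    if __name__ == '__main__':
#        import torch.multiprocessing as mp
#        # Fall back to a single process when no GPU is found; the ``run``
#        # function above already switches to CPU training in that case.
#        num_gpus = max(1, torch.cuda.device_count())
#        mp.spawn(run, args=(list(range(num_gpus)),), nprocs=num_gpus)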

# Thumbnail credits: Stanford CS224W Notes
# sphinx_gallery_thumbnail_path = '_static/blitz_1_introduction.png'