import argparse
import time
from functools import partial

import dgl
import dgl.nn.pytorch as dglnn

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from ogb.nodeproppred import DglNodePropPredDataset
from sampler import ClusterIter, subgraph_collate_fn
from torch.utils.data import DataLoader

#### GraphSAGE model


class SAGE(nn.Module):
    def __init__(
        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
    ):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
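        # n_layers SAGEConv layers in total: in_feats -> n_hidden,
        # (n_layers - 2) hidden layers, and n_hidden -> n_classes.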
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, g, x):
        h = x
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h
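
    # ``forward`` runs on sampled cluster subgraphs during training, while
    # ``inference`` below propagates over the full graph for evaluation.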

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without
        neighbor sampling): every layer propagates over the entire graph.
        g : the entire graph.
        x : the input features of the entire node set.
        ``batch_size`` and ``device`` are accepted for interface
        compatibility but are unused; inference runs as one full-graph pass.
        """
        h = x
        for l, conv in enumerate(self.layers):
            h = conv(g, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)

        return h


def compute_acc(pred, labels):
    """
    Compute the accuracy of prediction given the labels.
    """
    return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)


def evaluate(model, g, labels, val_nid, test_nid, batch_size, device):
    """
    Evaluate the model on the validation set specified by ``val_mask``.
    g : The entire graph.
    inputs : The features of all the nodes.
    labels : The labels of all the nodes.
    val_mask : A 0-1 mask indicating which nodes do we actually compute the accuracy for.
    batch_size : Number of nodes to compute at the same time.
    device : The GPU device to evaluate on.
    """
    model.eval()
    with th.no_grad():
        inputs = g.ndata["feat"]
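        # Run full-graph inference on the CPU so the whole graph and its
        # features do not need to fit in GPU memory.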
        model = model.cpu()
        pred = model.inference(g, inputs, batch_size, device)
    model.train()
    return (
        compute_acc(pred[val_nid], labels[val_nid]),
        compute_acc(pred[test_nid], labels[test_nid]),
        pred,
    )


def load_subtensor(g, labels, seeds, input_nodes, device):
    """
    Copys features and labels of a set of nodes onto GPU.
    """
    batch_inputs = g.ndata["feat"][input_nodes].to(device)
    batch_labels = labels[seeds].to(device)
    return batch_inputs, batch_labels


#### Entry point
def run(args, device, data):
    # Unpack data
    (
        train_nid,
        val_nid,
        test_nid,
        in_feats,
        labels,
        n_classes,
        g,
        cluster_iterator,
    ) = data

    # Define model and optimizer
    model = SAGE(
        in_feats,
        args.num_hidden,
        n_classes,
        args.num_layers,
        F.relu,
        args.dropout,
    )
    model = model.to(device)
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    # Training loop
    avg = 0
    best_eval_acc = 0
    best_test_acc = 0
    for epoch in range(args.num_epochs):
        iter_load = 0
        iter_far = 0
        iter_back = 0
        tic = time.time()

        # Loop over the dataloader: each step yields a single subgraph built
        # from a batch of graph partitions (clusters).
        tic_start = time.time()
        for step, cluster in enumerate(cluster_iterator):
            cluster = cluster.int().to(device)
            mask = cluster.ndata["train_mask"].to(device)
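            # Skip clusters that contain no training nodes.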
            if mask.sum() == 0:
                continue
            feat = cluster.ndata["feat"].to(device)
            batch_labels = cluster.ndata["labels"].to(device)
            tic_step = time.time()

            # Forward pass on the whole cluster subgraph; restrict the loss
            # to the training nodes inside this cluster.
            batch_pred = model(cluster, feat)
            batch_pred = batch_pred[mask]
            batch_labels = batch_labels[mask]
            loss = loss_fcn(batch_pred, batch_labels)
            optimizer.zero_grad()
            tic_far = time.time()
            loss.backward()
            optimizer.step()
            tic_back = time.time()
            iter_load += tic_step - tic_start
            iter_far += tic_far - tic_step
            iter_back += tic_back - tic_far

            tic_start = time.time()
            if step % args.log_every == 0:
                acc = compute_acc(batch_pred, batch_labels)
                gpu_mem_alloc = (
                    th.cuda.max_memory_allocated() / 1000000
                    if th.cuda.is_available()
                    else 0
                )
                print(
                    "Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | GPU {:.1f} MB".format(
                        epoch, step, loss.item(), acc.item(), gpu_mem_alloc
                    )
                )

        toc = time.time()
        print(
            "Epoch Time(s): {:.4f} Load {:.4f} Forward {:.4f} Backward {:.4f}".format(
                toc - tic, iter_load, iter_far, iter_back
            )
        )
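        # Exclude the first five epochs from the average epoch time (warm-up).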
        if epoch >= 5:
            avg += toc - tic

        if epoch % args.eval_every == 0 and epoch != 0:
            eval_acc, test_acc, pred = evaluate(
                model, g, labels, val_nid, test_nid, args.val_batch_size, device
            )
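            # evaluate() moved the model to the CPU; move it back for training.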
            model = model.to(device)
            if args.save_pred:
                np.savetxt(
                    args.save_pred + "%02d" % epoch,
                    pred.argmax(1).cpu().numpy(),
                    "%d",
                )
            print("Eval Acc {:.4f}".format(eval_acc))
            if eval_acc > best_eval_acc:
                best_eval_acc = eval_acc
                best_test_acc = test_acc
            print(
                "Best Eval Acc {:.4f} Test Acc {:.4f}".format(
                    best_eval_acc, best_test_acc
                )
            )
    print("Avg epoch time: {}".format(avg / (epoch - 4)))
    return best_test_acc


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(
        description="Cluster-GCN training on ogbn-products"
    )
    argparser.add_argument(
        "--gpu",
        type=int,
        default=0,
        help="GPU device ID. Use -1 for CPU training",
    )
    argparser.add_argument("--num-epochs", type=int, default=30)
    argparser.add_argument("--num-hidden", type=int, default=256)
    argparser.add_argument("--num-layers", type=int, default=3)
    argparser.add_argument("--batch-size", type=int, default=32)
    argparser.add_argument("--val-batch-size", type=int, default=10000)
    argparser.add_argument("--log-every", type=int, default=20)
    argparser.add_argument("--eval-every", type=int, default=1)
    argparser.add_argument("--lr", type=float, default=0.001)
    argparser.add_argument("--dropout", type=float, default=0.5)
    argparser.add_argument("--save-pred", type=str, default="")
    argparser.add_argument("--wd", type=float, default=0)
    argparser.add_argument("--num_partitions", type=int, default=15000)
    args = argparser.parse_args()

    if args.gpu >= 0:
        device = th.device("cuda:%d" % args.gpu)
    else:
        device = th.device("cpu")

    # load ogbn-products data
    data = DglNodePropPredDataset(name="ogbn-products")
    splitted_idx = data.get_idx_split()
    train_idx, val_idx, test_idx = (
        splitted_idx["train"],
        splitted_idx["valid"],
        splitted_idx["test"],
    )
    graph, labels = data[0]
    labels = labels[:, 0]
    num_nodes = train_idx.shape[0] + val_idx.shape[0] + test_idx.shape[0]
    assert num_nodes == graph.num_nodes()
    graph.ndata["labels"] = labels
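
    # Build boolean train/valid/test node masks from the OGB split indices.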
    mask = th.zeros(num_nodes, dtype=th.bool)
    mask[train_idx] = True
    graph.ndata["train_mask"] = mask
    mask = th.zeros(num_nodes, dtype=th.bool)
    mask[val_idx] = True
    graph.ndata["valid_mask"] = mask
    mask = th.zeros(num_nodes, dtype=th.bool)
    mask[test_idx] = True
    graph.ndata["test_mask"] = mask

    # Touch the graph in each sparse format once so DGL materializes them
    # here; the cached formats are then shared with the dataloader workers
    # instead of being rebuilt in every worker process.
    graph.in_degrees(0)
    graph.out_degrees(0)
    graph.find_edges(0)

    # Partition the graph into args.num_partitions clusters; each index of
    # the iterator yields the node set of one partition.
    cluster_iter_data = ClusterIter(
        "ogbn-products",
        graph,
        args.num_partitions,
        args.batch_size,
        th.cat([train_idx, val_idx, test_idx]),
    )
    # Each DataLoader batch draws args.batch_size partitions at random and
    # subgraph_collate_fn merges them into a single training subgraph.
    cluster_iterator = DataLoader(
        cluster_iter_data,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=4,
        collate_fn=partial(subgraph_collate_fn, graph),
    )
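
    # Optional sanity check: draw one batch and inspect the merged subgraph.
    # first_cluster = next(iter(cluster_iterator))
    # print(first_cluster.num_nodes(), first_cluster.num_edges())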

    in_feats = graph.ndata["feat"].shape[1]
    print("Number of input features:", in_feats)
    n_classes = (labels.max() + 1).item()
    # Pack data
    data = (
        train_idx,
        val_idx,
        test_idx,
        in_feats,
        labels,
        n_classes,
        graph,
        cluster_iterator,
    )

    # Run 10 times
    test_accs = []
    for i in range(10):
        test_accs.append(run(args, device, data))
        print(
            "Average test accuracy:", np.mean(test_accs), "±", np.std(test_accs)
        )