# train_dist_unsupervised.py
1
2
import argparse
import time
3
from contextlib import contextmanager
4

5
6
7
8
9
10
11
import numpy as np
import sklearn.linear_model as lm
import sklearn.metrics as skm
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
12
13
14
15
16
import tqdm
import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn

17
class DistSAGE(nn.Module):
    """GraphSAGE model (mean aggregation) used by the distributed trainer.

    in_feats: size of the input node features.
    n_hidden: size of every hidden representation.
    n_classes: output size of the last layer (this script passes the hidden
        size here, so the output is a node embedding).
    n_layers: number of SAGEConv layers.
    activation: callable applied after every layer except the last.
    dropout: dropout probability applied after every layer except the last.
    """

    def __init__(
        self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
    ):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        if n_layers == 1:
            # A single layer maps input features straight to the output size
            # (previously this case silently built two layers).
            self.layers.append(dglnn.SAGEConv(in_feats, n_classes, "mean"))
        else:
            self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
            for _ in range(1, n_layers - 1):
                self.layers.append(
                    dglnn.SAGEConv(n_hidden, n_hidden, "mean")
                )
            self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        """Run the model on sampled blocks.

        blocks: one block per layer; ``zip`` stops at the shorter of
            ``self.layers`` and ``blocks``.
        x: input features of the source nodes of the first block.
        """
        h = x
        for i, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if i != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without
        neighbor sampling).

        g : the entire graph.
        x : the input of entire node set.
        batch_size : number of destination nodes per minibatch.
        device : device each minibatch is computed on.

        Returns the DistTensor named "h_last" of shape
        (g.num_nodes(), self.n_classes) holding the last layer's output.

        The inference code is written in a fashion that it could handle any
        number of nodes and layers.
        """
        # During inference with sampling, multi-layer blocks are very
        # inefficient because lots of computations in the first few layers are
        # repeated. Therefore, we compute the representation of all nodes layer
        # by layer. The nodes on each layer are of course split in batches.
        # TODO: can we standardize this?
        nodes = dgl.distributed.node_split(
            np.arange(g.num_nodes()),
            g.get_partition_book(),
            force_even=True,
        )
        # Full-neighbor sampler; identical for every layer.
        sampler = dgl.dataloading.NeighborSampler([-1])
        # Hidden-layer outputs alternate between two buffers. Layer i reads
        # the output of layer i-1, so it must never write into that same
        # tensor: later minibatches (on this or another trainer) may still
        # need the old rows. Reusing a single buffer corrupted the results
        # whenever the model had three or more layers.
        hidden_bufs = {}
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            if i == last:
                y = dgl.distributed.DistTensor(
                    (g.num_nodes(), self.n_classes),
                    th.float32,
                    "h_last",
                    persistent=True,
                )
            else:
                slot = i % 2
                if slot not in hidden_bufs:
                    hidden_bufs[slot] = dgl.distributed.DistTensor(
                        (g.num_nodes(), self.n_hidden),
                        th.float32,
                        "h" if slot == 0 else "h_alt",
                        persistent=True,
                    )
                y = hidden_bufs[slot]

            # Create dataloader
            dataloader = dgl.dataloading.DistNodeDataLoader(
                g,
                nodes,
                sampler,
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
            )

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(device)
                h = x[input_nodes].to(device)
                h_dst = h[: block.number_of_dst_nodes()]
                h = layer(block, (h, h_dst))
                if i != last:
                    h = self.activation(h)
                    h = self.dropout(h)

                y[output_nodes] = h.cpu()

            x = y
            # Every trainer must finish writing this layer before any of them
            # reads it as the next layer's input.
            g.barrier()
        return y

    @contextmanager
    def join(self):
        """dummy join for standalone"""
        yield

109

110
111
112
113
def load_subtensor(g, input_nodes, device):
    """
    Copy the input features of a set of nodes onto ``device``.

    g: graph whose ``ndata["features"]`` holds the node features.
    input_nodes: ids of the nodes whose features are needed.
    device: target device (CPU or GPU).

    Returns the gathered feature rows, placed on ``device``. Labels are not
    copied.
    """
    batch_inputs = g.ndata["features"][input_nodes].to(device)
    return batch_inputs

117

118
119
120
class CrossEntropyLoss(nn.Module):
    """Binary cross-entropy loss for link prediction with negative sampling.

    Each edge is scored by the dot product of its endpoint representations.
    Edges of ``pos_graph`` are targets 1 and edges of ``neg_graph`` are
    targets 0.
    """

    def forward(self, block_outputs, pos_graph, neg_graph):
        """Return the mean BCE-with-logits loss over positive and negative
        edges.

        block_outputs: node representations produced by the model, assigned
            as ``ndata["h"]`` of both graphs.
        pos_graph / neg_graph: graphs holding the positive / negative edges.
        """
        pos_score = self._edge_scores(pos_graph, block_outputs)
        neg_score = self._edge_scores(neg_graph, block_outputs)
        score = th.cat([pos_score, neg_score])
        # Floating-point targets are built directly; the former
        # ``.long()`` -> ``.float()`` round trip produced the same values.
        label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)])
        return F.binary_cross_entropy_with_logits(score, label)

    @staticmethod
    def _edge_scores(graph, h):
        """Dot-product score of every edge of ``graph`` given node
        representations ``h``; the graph itself is left unmodified."""
        with graph.local_scope():
            graph.ndata["h"] = h
            graph.apply_edges(fn.u_dot_v("h", "h", "score"))
            return graph.edata["score"]

136

137
138
139
140
141
142
143
144
145
146
147
148
149
150
def generate_emb(model, g, inputs, batch_size, device):
    """
    Compute embeddings for every node of ``g`` via ``model.inference``.

    model : model exposing ``inference``; it is switched to eval mode first.
    g : The entire graph.
    inputs : The features of all the nodes.
    batch_size : Number of nodes to compute at the same time.
    device : The device to run inference on.

    Returns whatever ``model.inference`` returns.
    """
    model.eval()
    # No gradients are needed while producing embeddings.
    with th.no_grad():
        return model.inference(g, inputs, batch_size, device)

151

152
153
154
def compute_acc(emb, labels, train_nids, val_nids, test_nids):
    """
    Compute node-classification accuracy of the learned embeddings.

    We first train a LogisticRegression model on the (standardized)
    embeddings of the training nodes; the validation and test accuracies are
    then computed from that model's predictions.

    emb: The pretrained embeddings; only the first ``labels.shape[0]`` rows
        are used.
    labels: The ground truth
    train_nids: The training set node ids
    val_nids: The validation set node ids
    test_nids: The test set node ids

    Returns (validation accuracy, test accuracy).
    """
    emb = emb[np.arange(labels.shape[0])].cpu().numpy()
    train_nids = train_nids.cpu().numpy()
    val_nids = val_nids.cpu().numpy()
    test_nids = test_nids.cpu().numpy()
    labels = labels.cpu().numpy()

    # Standardize each dimension. A constant dimension has zero std, and
    # dividing by it would yield NaNs that make LogisticRegression.fit fail,
    # so such dimensions are only centered.
    mean = emb.mean(0, keepdims=True)
    std = emb.std(0, keepdims=True)
    std[std == 0] = 1.0
    emb = (emb - mean) / std

    # NOTE(review): ``multi_class`` is deprecated in recent scikit-learn
    # releases; kept here to preserve the current behavior.
    lr = lm.LogisticRegression(multi_class="multinomial", max_iter=10000)
    lr.fit(emb[train_nids], labels[train_nids])

    pred = lr.predict(emb)
    eval_acc = skm.accuracy_score(labels[val_nids], pred[val_nids])
    test_acc = skm.accuracy_score(labels[test_nids], pred[test_nids])
    return eval_acc, test_acc

183

184
185
def run(args, device, data):
    """
    Train DistSAGE with an unsupervised link-prediction loss, then evaluate
    and save the resulting node embeddings.

    args: parsed command-line arguments (see the ``__main__`` block).
    device: torch device used for training and inference.
    data: the tuple packed by ``main``; see the unpacking below.

    Every trainer trains and takes part in the distributed inference. The
    process with ``g.rank() == 0`` evaluates the embeddings. In distributed
    mode only rank 0 writes "emb.pt"; in standalone mode the process does.
    """
    # Unpack data
    (
        train_eids,
        train_nids,
        in_feats,
        g,
        global_train_nid,
        global_valid_nid,
        global_test_nid,
        labels,
    ) = data
    # Create sampler
    neg_sampler = dgl.dataloading.negative_sampler.Uniform(args.num_negs)
    sampler = dgl.dataloading.NeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(",")]
    )

    # Create dataloader
    # With --remove_edge every edge id is declared to be its own reverse, so
    # "reverse_id" exclusion drops the minibatch's own edges from the
    # sampled neighborhoods.
    exclude = "reverse_id" if args.remove_edge else None
    reverse_eids = th.arange(g.num_edges()) if args.remove_edge else None
    dataloader = dgl.dataloading.DistEdgeDataLoader(
        g,
        train_eids,
        sampler,
        negative_sampler=neg_sampler,
        exclude=exclude,
        reverse_eids=reverse_eids,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
    )
    # Define model and optimizer
    # The output size equals the hidden size: the model produces embeddings.
    model = DistSAGE(
        in_feats,
        args.num_hidden,
        args.num_hidden,
        args.num_layers,
        F.relu,
        args.dropout,
    )
    model = model.to(device)
    if not args.standalone:
        if args.num_gpus == -1:
            model = th.nn.parallel.DistributedDataParallel(model)
        else:
            dev_id = g.rank() % args.num_gpus
            model = th.nn.parallel.DistributedDataParallel(
                model, device_ids=[dev_id], output_device=dev_id
            )
    loss_fcn = CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    for epoch in range(args.num_epochs):
        num_seeds = 0
        num_inputs = 0

        step_time = []
        sample_t = []
        feat_copy_t = []
        forward_t = []
        backward_t = []
        update_t = []
        iter_tput = []

        start = time.time()
        # Under DDP, join() lets trainers with different numbers of
        # minibatches finish together; DistSAGE.join is a no-op stand-in.
        with model.join():
            # Loop over the dataloader to sample the computation dependency
            # graph as a list of blocks.
            for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(
                dataloader
            ):
                tic_step = time.time()
                sample_t.append(tic_step - start)

                copy_t = time.time()
                pos_graph = pos_graph.to(device)
                neg_graph = neg_graph.to(device)
                blocks = [block.to(device) for block in blocks]
                batch_inputs = load_subtensor(g, input_nodes, device)
                copy_time = time.time()
                feat_copy_t.append(copy_time - copy_t)

                # Compute loss and prediction
                batch_pred = model(blocks, batch_inputs)
                loss = loss_fcn(batch_pred, pos_graph, neg_graph)
                forward_end = time.time()
                optimizer.zero_grad()
                loss.backward()
                compute_end = time.time()
                forward_t.append(forward_end - copy_time)
                backward_t.append(compute_end - forward_end)

                # Apply the update (under DDP, gradients were already
                # all-reduced across trainers during backward()).
                optimizer.step()
                update_t.append(time.time() - compute_end)

                pos_edges = pos_graph.num_edges()

                step_t = time.time() - start
                step_time.append(step_t)
                iter_tput.append(pos_edges / step_t)
                num_seeds += pos_edges
                # Previously never updated, so "#inputs" always printed 0.
                num_inputs += len(input_nodes)
                if step % args.log_every == 0:
                    # Skip the first few (warm-up) steps when averaging the
                    # throughput, but fall back to all steps so early log
                    # lines do not print NaN from an empty mean.
                    tput = iter_tput[3:] if len(iter_tput) > 3 else iter_tput
                    print(
                        "[{}] Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed "
                        "(samples/sec) {:.4f} | time {:.3f}s | sample {:.3f} | "
                        "copy {:.3f} | forward {:.3f} | backward {:.3f} | "
                        "update {:.3f}".format(
                            g.rank(),
                            epoch,
                            step,
                            loss.item(),
                            np.mean(tput),
                            np.sum(step_time[-args.log_every:]),
                            np.sum(sample_t[-args.log_every:]),
                            np.sum(feat_copy_t[-args.log_every:]),
                            np.sum(forward_t[-args.log_every:]),
                            np.sum(backward_t[-args.log_every:]),
                            np.sum(update_t[-args.log_every:]),
                        )
                    )
                start = time.time()

        print(
            "[{}]Epoch Time(s): {:.4f}, sample: {:.4f}, data copy: {:.4f}, "
            "forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, "
            "#inputs: {}".format(
                g.rank(),
                np.sum(step_time),
                np.sum(sample_t),
                np.sum(feat_copy_t),
                np.sum(forward_t),
                np.sum(backward_t),
                np.sum(update_t),
                num_seeds,
                num_inputs,
            )
        )

    # evaluate the embedding using LogisticRegression
    pred = generate_emb(
        model if args.standalone else model.module,
        g,
        g.ndata["features"],
        args.batch_size_eval,
        device,
    )
    if g.rank() == 0:
        eval_acc, test_acc = compute_acc(
            pred, labels, global_train_nid, global_valid_nid, global_test_nid
        )
        print("eval acc {:.4f}; test acc {:.4f}".format(eval_acc, test_acc))

    # NOTE(review): ``pred`` is the DistTensor returned by inference; confirm
    # that pickling it with th.save stores the embedding values as intended.
    if not args.standalone:
        # sync for eval and test
        th.distributed.barrier()
        g._client.barrier()

        # save features into file
        if g.rank() == 0:
            th.save(pred, "emb.pt")
    else:
        th.save(pred, "emb.pt")

354
355

def main(args):
    """Initialize the distributed runtime, load the graph and run training.

    args: parsed command-line arguments (see the ``__main__`` block).
    """
    dgl.distributed.initialize(args.ip_config)
    if not args.standalone:
        th.distributed.init_process_group(backend="gloo")
    g = dgl.distributed.DistGraph(
        args.graph_name, part_config=args.part_config
    )
    print("rank:", g.rank())
    print("number of edges", g.num_edges())

    # Every edge is a training edge; split them evenly across trainers.
    train_eids = dgl.distributed.edge_split(
        th.ones((g.num_edges(),), dtype=th.bool),
        g.get_partition_book(),
        force_even=True,
    )
    train_nids = dgl.distributed.node_split(
        th.ones((g.num_nodes(),), dtype=th.bool), g.get_partition_book()
    )
    all_nids = np.arange(g.num_nodes())

    def global_nids(mask_name):
        # 1-D int64 ids of every node whose ``mask_name`` entry is nonzero.
        # ``as_tuple=True`` keeps the result 1-D even when exactly one node
        # is selected; the former ``.squeeze()`` made that case 0-d, which
        # broke ``.shape[0]`` below.
        mask = g.ndata[mask_name][all_nids]
        return th.nonzero(mask, as_tuple=True)[0]

    global_train_nid = global_nids("train_mask")
    global_valid_nid = global_nids("val_mask")
    global_test_nid = global_nids("test_mask")
    labels = g.ndata["labels"][all_nids]
    if args.num_gpus == -1:
        device = th.device("cpu")
    else:
        dev_id = g.rank() % args.num_gpus
        device = th.device("cuda:" + str(dev_id))

    # Pack data
    in_feats = g.ndata["features"].shape[1]
    print("number of train {}".format(global_train_nid.shape[0]))
    print("number of valid {}".format(global_valid_nid.shape[0]))
    print("number of test {}".format(global_test_nid.shape[0]))
    data = (
        train_eids,
        train_nids,
        in_feats,
        g,
        global_train_nid,
        global_valid_nid,
        global_test_nid,
        labels,
    )
    run(args, device, data)
    print("parent ends")

410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Distributed unsupervised GraphSAGE"
    )
    parser.add_argument("--graph_name", type=str, help="graph name")
    parser.add_argument("--id", type=int, help="the partition id")
    parser.add_argument(
        "--ip_config", type=str, help="The file for IP configuration"
    )
    parser.add_argument(
        "--part_config", type=str, help="The path to the partition config file"
    )
    parser.add_argument("--n_classes", type=int, help="the number of classes")
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=-1,
        help="the number of GPU device. Use -1 for CPU training",
    )
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--num_hidden", type=int, default=16)
    # "--num_layers" is an alias matching the underscore style of the other
    # flags; dest comes from the first long option, i.e. ``num_layers``.
    parser.add_argument("--num-layers", "--num_layers", type=int, default=2)
    parser.add_argument("--fan_out", type=str, default="10,25")
    parser.add_argument("--batch_size", type=int, default=1000)
    parser.add_argument("--batch_size_eval", type=int, default=100000)
    parser.add_argument("--log_every", type=int, default=20)
    parser.add_argument("--eval_every", type=int, default=5)
    parser.add_argument("--lr", type=float, default=0.003)
    parser.add_argument("--dropout", type=float, default=0.5)
    # "--local-rank" is accepted too, since newer PyTorch launchers pass the
    # hyphenated spelling; dest stays ``local_rank``.
    parser.add_argument(
        "--local_rank",
        "--local-rank",
        type=int,
        help="get rank of the process",
    )
    parser.add_argument(
        "--standalone", action="store_true", help="run in the standalone mode"
    )
    parser.add_argument("--num_negs", type=int, default=1)
    parser.add_argument(
        "--remove_edge",
        default=False,
        action="store_true",
        help="whether to remove edges during sampling",
    )

    args = parser.parse_args()
    print(args)
    main(args)