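"""Transductive distributed GraphSAGE training with learnable node embeddings.

Instead of reading node features from the graph, this script trains an
embedding table per node (either DGL's DistEmbedding or a PyTorch sparse
Embedding) jointly with the DistSAGE model imported from ``train_dist.py``.
"""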
import argparse
import time

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import dgl
from dgl.distributed import DistEmbedding

from train_dist import DistSAGE, compute_acc


def initializer(shape, dtype):
    arr = th.zeros(shape, dtype=dtype)
    arr.uniform_(-1, 1)
    return arr


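# Input layer for transductive training: the graph carries no node features,
# so every node gets a trainable embedding vector. With --dgl_sparse the table
# is a dgl.distributed.DistEmbedding partitioned across machines; otherwise it
# is a local torch.nn.Embedding with sparse gradients.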
class DistEmb(nn.Module):
    def __init__(
        self, num_nodes, emb_size, dgl_sparse_emb=False, dev_id="cpu"
    ):
        super().__init__()
        self.dev_id = dev_id
        self.emb_size = emb_size
        self.dgl_sparse_emb = dgl_sparse_emb
        if dgl_sparse_emb:
            self.sparse_emb = DistEmbedding(
                num_nodes, emb_size, name="sage", init_func=initializer
            )
        else:
            self.sparse_emb = th.nn.Embedding(num_nodes, emb_size, sparse=True)
            nn.init.uniform_(self.sparse_emb.weight, -1.0, 1.0)

    def forward(self, idx):
        # Embeddings are stored on CPU; look them up with CPU indices and
        # move the result to the target device.
        idx = idx.cpu()
        if self.dgl_sparse_emb:
            return self.sparse_emb(idx, device=self.dev_id)
        else:
            return self.sparse_emb(idx).to(self.dev_id)


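# Materialize embeddings for every node into a DistTensor so that
# DistSAGE.inference() can read them as input features during evaluation.
# Each trainer writes the rows of its evenly split share of nodes, in chunks
# of 1024 to bound peak memory.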
def load_embs(standalone, emb_layer, g):
    nodes = dgl.distributed.node_split(
        np.arange(g.num_nodes()), g.get_partition_book(), force_even=True
    )
    x = dgl.distributed.DistTensor(
        (
            g.num_nodes(),
            emb_layer.module.emb_size
            if isinstance(emb_layer, th.nn.parallel.DistributedDataParallel)
            else emb_layer.emb_size,
        ),
        th.float32,
        "eval_embs",
        persistent=True,
    )
    num_nodes = nodes.shape[0]
    for i in range((num_nodes + 1023) // 1024):
        idx = nodes[
            i * 1024: (i + 1) * 1024
            if (i + 1) * 1024 < num_nodes
            else num_nodes
        ]
        embeds = emb_layer(idx).cpu()
        x[idx] = embeds

    if not standalone:
        g.barrier()

    return x


def evaluate(
    standalone,
    model,
    emb_layer,
    g,
    labels,
    val_nid,
    test_nid,
    batch_size,
    device,
):
    """
    Evaluate the model on the validation and test node sets.

    standalone : Whether the script runs without distributed training.
    model : The (possibly DDP-wrapped) GraphSAGE model.
    emb_layer : The embedding layer providing the input features.
    g : The entire graph.
    labels : The labels of all the nodes.
    val_nid : Node IDs used for validation.
    test_nid : Node IDs used for testing.
    batch_size : Number of nodes to compute at the same time.
    device : The device to evaluate on.
    """
    if not standalone:
        model = model.module
    model.eval()
    emb_layer.eval()
    with th.no_grad():
        inputs = load_embs(standalone, emb_layer, g)
        pred = model.inference(g, inputs, batch_size, device)
    model.train()
    emb_layer.train()
    return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(
        pred[test_nid], labels[test_nid]
    )


def run(args, device, data):
    # Unpack data
    train_nid, val_nid, test_nid, n_classes, g = data
    sampler = dgl.dataloading.NeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(",")]
    )
    dataloader = dgl.dataloading.DistNodeDataLoader(
        g,
        train_nid,
        sampler,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
    )
    # Define model and optimizer
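    # The graph provides no input features; a trainable embedding of size
    # args.num_hidden per node acts as the model input instead.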
    emb_layer = DistEmb(
        g.num_nodes(),
        args.num_hidden,
        dgl_sparse_emb=args.dgl_sparse,
        dev_id=device,
    )
    model = DistSAGE(
        args.num_hidden,
        args.num_hidden,
        n_classes,
        args.num_layers,
        F.relu,
        args.dropout,
    )
    model = model.to(device)
    if not args.standalone:
        if args.num_gpus == -1:
            model = th.nn.parallel.DistributedDataParallel(model)
            if not args.dgl_sparse:
                emb_layer = th.nn.parallel.DistributedDataParallel(emb_layer)
        else:
            dev_id = g.rank() % args.num_gpus
            model = th.nn.parallel.DistributedDataParallel(
                model, device_ids=[dev_id], output_device=dev_id
            )
            if not args.dgl_sparse:
                emb_layer = th.nn.parallel.DistributedDataParallel(emb_layer)
    loss_fcn = nn.CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
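    # The dense model parameters use Adam, but the embedding table needs a
    # sparse optimizer: dgl.distributed.optim.SparseAdam for DGL sparse
    # embeddings, th.optim.SparseAdam for PyTorch sparse embeddings (reached
    # through .module when the layer is wrapped in DDP).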
    if args.dgl_sparse:
        emb_optimizer = dgl.distributed.optim.SparseAdam(
            [emb_layer.sparse_emb], lr=args.sparse_lr
        )
        print("optimize DGL sparse embedding:", emb_layer.sparse_emb)
    elif args.standalone:
        emb_optimizer = th.optim.SparseAdam(
            list(emb_layer.sparse_emb.parameters()), lr=args.sparse_lr
        )
        print("optimize PyTorch sparse embedding:", emb_layer.sparse_emb)
    else:
        emb_optimizer = th.optim.SparseAdam(
            list(emb_layer.module.sparse_emb.parameters()), lr=args.sparse_lr
        )
        print(
            "optimize PyTorch sparse embedding:", emb_layer.module.sparse_emb
        )

    # Training loop
    iter_tput = []
    for epoch in range(args.num_epochs):
        tic = time.time()

        sample_time = 0
        forward_time = 0
        backward_time = 0
        update_time = 0
        num_seeds = 0
        num_inputs = 0
        start = time.time()
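        # DDP's join() context manager tolerates trainers that draw different
        # numbers of mini-batches from their local shard of training nodes.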
        with model.join():
            # Loop over the dataloader to sample the computation dependency
            # graph as a list of blocks.
            step_time = []
            for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
                tic_step = time.time()
                sample_time += tic_step - start
                num_seeds += len(blocks[-1].dstdata[dgl.NID])
                num_inputs += len(blocks[0].srcdata[dgl.NID])
                blocks = [block.to(device) for block in blocks]
                batch_labels = g.ndata["labels"][seeds].long().to(device)
                # Compute loss and prediction
                start = time.time()
                batch_inputs = emb_layer(input_nodes)
                batch_pred = model(blocks, batch_inputs)
                loss = loss_fcn(batch_pred, batch_labels)
                forward_end = time.time()
                emb_optimizer.zero_grad()
                optimizer.zero_grad()
                loss.backward()
                compute_end = time.time()
                forward_time += forward_end - start
                backward_time += compute_end - forward_end

                emb_optimizer.step()
                optimizer.step()
                update_time += time.time() - compute_end

                step_t = time.time() - tic_step
                step_time.append(step_t)
                iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t)
                if step % args.log_every == 0:
                    acc = compute_acc(batch_pred, batch_labels)
                    gpu_mem_alloc = (
                        th.cuda.max_memory_allocated() / 1000000
                        if th.cuda.is_available()
                        else 0
                    )
                    print(
                        "Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
                        "Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU "
                        "{:.1f} MB | time {:.3f} s".format(
                            g.rank(),
                            epoch,
                            step,
                            loss.item(),
                            acc.item(),
                            np.mean(iter_tput[3:]),
                            gpu_mem_alloc,
                            np.sum(step_time[-args.log_every:]),
                        )
                    )
                start = time.time()

        toc = time.time()
        print(
            "Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, "
            "forward: {:.4f}, backward: {:.4f}, update: {:.4f}, "
            "#seeds: {}, #inputs: {}".format(
                g.rank(),
                toc - tic,
                sample_time,
                forward_time,
                backward_time,
                update_time,
                num_seeds,
                num_inputs,
            )
        )

        if (epoch + 1) % args.eval_every == 0:
            start = time.time()
            val_acc, test_acc = evaluate(
                args.standalone,
                model,
                emb_layer,
                g,
                g.ndata["labels"],
                val_nid,
                test_nid,
                args.batch_size_eval,
                device,
            )
            print(
                "Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format(
                    g.rank(), val_acc, test_acc, time.time() - start
                )
            )


def main(args):
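    # dgl.distributed.initialize() has to run before
    # th.distributed.init_process_group() so the connections to the DGL graph
    # servers are set up first.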
    dgl.distributed.initialize(args.ip_config)
    if not args.standalone:
        th.distributed.init_process_group(backend="gloo")
    g = dgl.distributed.DistGraph(
        args.graph_name, part_config=args.part_config
    )
    print("rank:", g.rank())

    pb = g.get_partition_book()
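    # Split the train/val/test node IDs evenly across trainers; a trainer's
    # share may include nodes stored on remote partitions.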
    train_nid = dgl.distributed.node_split(
        g.ndata["train_mask"], pb, force_even=True
    )
    val_nid = dgl.distributed.node_split(
        g.ndata["val_mask"], pb, force_even=True
    )
    test_nid = dgl.distributed.node_split(
        g.ndata["test_mask"], pb, force_even=True
    )
    local_nid = pb.partid2nids(pb.partid).detach().numpy()
    print(
        "part {}, train: {} (local: {}), val: {} (local: {}), test: {} "
        "(local: {})".format(
            g.rank(),
            len(train_nid),
            len(np.intersect1d(train_nid.numpy(), local_nid)),
            len(val_nid),
            len(np.intersect1d(val_nid.numpy(), local_nid)),
            len(test_nid),
            len(np.intersect1d(test_nid.numpy(), local_nid)),
        )
    )
    if args.num_gpus == -1:
        device = th.device("cpu")
    else:
        dev_id = g.rank() % args.num_gpus
        device = th.device("cuda:" + str(dev_id))
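    # Derive the number of classes from the labels; NaN marks unlabeled nodes.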
    labels = g.ndata["labels"][np.arange(g.num_nodes())]
    n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
    print("#classes:", n_classes)

    # Pack data
    data = train_nid, val_nid, test_nid, n_classes, g
    run(args, device, data)
    print("parent ends")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Distributed GraphSAGE")
    parser.add_argument("--graph_name", type=str, help="graph name")
    parser.add_argument("--id", type=int, help="the partition id")
    parser.add_argument(
        "--ip_config", type=str, help="The file for IP configuration"
    )
    parser.add_argument(
        "--part_config", type=str, help="The path to the partition config file"
    )
    parser.add_argument("--n_classes", type=int, help="the number of classes")
    parser.add_argument(
        "--num_gpus",
        type=int,
        default=-1,
        help="the number of GPU devices. Use -1 for CPU training",
    )
    parser.add_argument("--num_epochs", type=int, default=20)
    parser.add_argument("--num_hidden", type=int, default=16)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--fan_out", type=str, default="10,25")
    parser.add_argument("--batch_size", type=int, default=1000)
    parser.add_argument("--batch_size_eval", type=int, default=100000)
    parser.add_argument("--log_every", type=int, default=20)
    parser.add_argument("--eval_every", type=int, default=5)
    parser.add_argument("--lr", type=float, default=0.003)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument(
        "--local_rank", type=int, help="local rank of the trainer process"
    )
    parser.add_argument(
        "--standalone", action="store_true", help="run in the standalone mode"
    )
    parser.add_argument(
        "--dgl_sparse",
        action="store_true",
        help="Whether to use DGL sparse embedding",
    )
    parser.add_argument(
        "--sparse_lr", type=float, default=1e-2,
        help="learning rate for the sparse embedding optimizer"
    )
    args = parser.parse_args()

    print(args)
    main(args)
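# A sketch of how such a script is commonly launched with DGL's distributed
# launcher; the workspace path, graph name, partition config and IP config
# below are illustrative placeholders, not values required by this script:
#
#   python3 ~/dgl/tools/launch.py \
#       --workspace ~/workspace \
#       --num_trainers 1 --num_samplers 0 --num_servers 1 \
#       --part_config data/mygraph.json \
#       --ip_config ip_config.txt \
#       "python3 train_dist_transductive.py --graph_name mygraph \
#        --ip_config ip_config.txt --num_epochs 30 --dgl_sparse"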