# train_sampling_unsupervised.py
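"""Unsupervised GraphSAGE training with neighbor sampling.

Trains a GraphSAGE encoder with a link-prediction objective (observed edges as
positive pairs, sampled node pairs as negatives) on Reddit or ogbn-products.
Supports CPU, single-GPU, and multi-GPU (DistributedDataParallel) training, with
the graph and node features kept on CPU, on GPU, or in pinned host memory (UVA).
"""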
import os
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dgl.multiprocessing as mp
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
from torch.nn.parallel import DistributedDataParallel
import tqdm

from model import SAGE, compute_acc_unsupervised as compute_acc
from negative_sampler import NegativeSampler
from load_graph import load_reddit, load_ogb

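# Unsupervised link-prediction loss: each edge is scored by the dot product of its
# endpoint embeddings, and binary cross-entropy is applied with edges of the positive
# graph labeled 1 and edges of the negative (sampled) graph labeled 0.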
class CrossEntropyLoss(nn.Module):
    def forward(self, block_outputs, pos_graph, neg_graph):
        with pos_graph.local_scope():
            pos_graph.ndata['h'] = block_outputs
            pos_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            pos_score = pos_graph.edata['score']
        with neg_graph.local_scope():
            neg_graph.ndata['h'] = block_outputs
            neg_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            neg_score = neg_graph.edata['score']

        score = th.cat([pos_score, neg_score])
        label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)]).long()
        loss = F.binary_cross_entropy_with_logits(score, label.float())
        return loss

def evaluate(model, g, nfeat, labels, train_nids, val_nids, test_nids, device):
    """
    Evaluate the model on the validation set specified by ``val_mask``.
    g : The entire graph.
    inputs : The features of all the nodes.
    labels : The labels of all the nodes.
    val_mask : A 0-1 mask indicating which nodes do we actually compute the accuracy for.
    device : The GPU device to evaluate on.
    """
    model.eval()
    with th.no_grad():
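        # The SAGE model's `inference` is assumed to compute embeddings for all nodes
        # layer by layer with full neighborhoods (see the companion model.py).
        # Note: this relies on the module-level `args` for batch size and workers.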
        # single gpu
        if isinstance(model, SAGE):
            pred = model.inference(g, nfeat, device, args.batch_size, args.num_workers)
        # multi gpu
        else:
            pred = model.module.inference(g, nfeat, device, args.batch_size, args.num_workers)
    model.train()
    return compute_acc(pred, labels, train_nids, val_nids, test_nids)

#### Entry point
def run(proc_id, n_gpus, args, devices, data):
    # Set up the device and, for multi-GPU runs, the distributed process group
    device = th.device(devices[proc_id])
    if n_gpus > 0:
        th.cuda.set_device(device)
    if n_gpus > 1:
        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
            master_ip='127.0.0.1', master_port='12345')
        world_size = n_gpus
        th.distributed.init_process_group(backend="nccl",
                                          init_method=dist_init_method,
                                          world_size=world_size,
                                          rank=proc_id)
    # Unpack data
    train_nid, val_nid, test_nid, n_classes, g, nfeat, labels = data

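    # Place node features and labels according to --data-device: 'gpu' copies them
    # to device memory up front, while 'uva' keeps them in pinned host memory and
    # lets the GPU read them zero-copy through UnifiedTensor.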
    if args.data_device == 'gpu':
        nfeat = nfeat.to(device)
        labels = labels.to(device)
    elif args.data_device == 'uva':
        nfeat = dgl.contrib.UnifiedTensor(nfeat, device=device)
        labels = dgl.contrib.UnifiedTensor(labels, device=device)
    in_feats = nfeat.shape[1]

    # Create PyTorch DataLoader for constructing blocks
    n_edges = g.num_edges()
    train_seeds = th.arange(n_edges)

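    # Place the graph according to --graph-device: 'gpu' moves the whole graph to
    # device memory and 'uva' pins it in host memory for zero-copy GPU access.
    # In both cases sampling runs in the main process, so no worker processes are used.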
    if args.graph_device == 'gpu':
        train_seeds = train_seeds.to(device)
        g = g.to(device)
        args.num_workers = 0
    elif args.graph_device == 'uva':
        train_seeds = train_seeds.to(device)
        g.pin_memory_()
        args.num_workers = 0

    # Create sampler
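    # EdgeDataLoader yields (input_nodes, pos_graph, neg_graph, blocks) per batch:
    # pos_graph holds the sampled training edges, neg_graph the negative pairs drawn
    # by NegativeSampler, and blocks the message-passing dependency graphs.
    # exclude='reverse_id' with reverse_eids removes each seed edge and its reverse
    # from the sampled neighborhood to prevent information leakage.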
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(',')])
    dataloader = dgl.dataloading.EdgeDataLoader(
        g, train_seeds, sampler, exclude='reverse_id',
        # For each edge with ID e in Reddit dataset, the reverse edge is e ± |E|/2.
        reverse_eids=th.cat([
            th.arange(n_edges // 2, n_edges),
            th.arange(0, n_edges // 2)]).to(train_seeds),
        negative_sampler=NegativeSampler(g, args.num_negs, args.neg_share,
                                         device if args.graph_device == 'uva' else None),
        device=device,
        use_ddp=n_gpus > 1,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers)

    # Define model and optimizer
    model = SAGE(in_feats, args.num_hidden, args.num_hidden, args.num_layers, F.relu, args.dropout)
    model = model.to(device)
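    # Wrap the model with DistributedDataParallel so gradients are averaged across GPUs.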
    if n_gpus > 1:
        model = DistributedDataParallel(model, device_ids=[device], output_device=device)
    loss_fcn = CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    avg = 0
    iter_pos = []
    iter_neg = []
    iter_d = []
    iter_t = []
    best_eval_acc = 0
    best_test_acc = 0
    for epoch in range(args.num_epochs):
        tic = time.time()

        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
        tic_step = time.time()
        for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(dataloader):
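            # Gather the input features of the sampled source nodes and move the
            # positive/negative pair graphs and blocks to the training device.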
            batch_inputs = nfeat[input_nodes].to(device)
            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device)
            blocks = [block.int().to(device) for block in blocks]
            d_step = time.time()

            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, pos_graph, neg_graph)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t = time.time()
            pos_edges = pos_graph.num_edges()
            neg_edges = neg_graph.num_edges()
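            # Record throughput (positive/negative edges per second) and the per-iteration
            # data-loading vs. training time; early iterations are skipped when logging.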
            iter_pos.append(pos_edges / (t - tic_step))
            iter_neg.append(neg_edges / (t - tic_step))
            iter_d.append(d_step - tic_step)
            iter_t.append(t - d_step)
            if step % args.log_every == 0 and proc_id == 0:
                gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
                print('[{}]Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f}|{:.4f} | Load {:.4f} | train {:.4f} | GPU {:.1f} MB'.format(
                    proc_id, epoch, step, loss.item(), np.mean(iter_pos[3:]), np.mean(iter_neg[3:]), np.mean(iter_d[3:]), np.mean(iter_t[3:]), gpu_mem_alloc))
            tic_step = time.time()

            if step % args.eval_every == 0 and proc_id == 0:
                eval_acc, test_acc = evaluate(model, g, nfeat, labels, train_nid, val_nid, test_nid, device)
                print('Eval Acc {:.4f} Test Acc {:.4f}'.format(eval_acc, test_acc))
                if eval_acc > best_eval_acc:
                    best_eval_acc = eval_acc
                    best_test_acc = test_acc
                print('Best Eval Acc {:.4f} Test Acc {:.4f}'.format(best_eval_acc, best_test_acc))
        toc = time.time()
        if proc_id == 0:
            print('Epoch Time(s): {:.4f}'.format(toc - tic))
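        # Average epoch time over epochs after the first five, which serve as warm-up.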
        if epoch >= 5:
            avg += toc - tic
        if n_gpus > 1:
            th.distributed.barrier()

    if proc_id == 0:
        print('Avg epoch time: {}'.format(avg / (epoch - 4)))

def main(args):
    devices = list(map(int, args.gpu.split(',')))
    n_gpus = len(devices)

    # load dataset
    if args.dataset == 'reddit':
        g, n_classes = load_reddit(self_loop=False)
    elif args.dataset == 'ogbn-products':
        g, n_classes = load_ogb('ogbn-products')
    else:
        raise ValueError('unknown dataset: ' + args.dataset)

    train_nid = g.ndata.pop('train_mask').nonzero().squeeze()
    val_nid = g.ndata.pop('val_mask').nonzero().squeeze()
    test_nid = g.ndata.pop('test_mask').nonzero().squeeze()

    nfeat = g.ndata.pop('features')
    labels = g.ndata.pop('labels')

    # Create csr/coo/csc formats before launching training processes with multi-gpu.
    # This avoids creating certain formats in each sub-process, which saves memory and CPU.
    g.create_formats_()

    # Limit the number of OpenMP threads per process to avoid contention on machines
    # with many cores. Change it to a suitable value for your machine, especially for
    # multi-GPU training.
    os.environ['OMP_NUM_THREADS'] = str(mp.cpu_count() // 2 // n_gpus)
    if n_gpus > 1:
        # Copy the graph to shared memory explicitly before pinning.
        # In other cases, we can just rely on fork's copy-on-write.
        # TODO: the original graph g is not freed.
        if args.graph_device == 'uva':
            g = g.shared_memory('g')
        if args.data_device == 'uva':
            nfeat = nfeat.share_memory_()
            labels = labels.share_memory_()
    # Pack data
    data = train_nid, val_nid, test_nid, n_classes, g, nfeat, labels

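    # Launch training: a single CPU process, a single-GPU process, or one process per GPU.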
    if devices[0] == -1:
        assert args.graph_device == 'cpu', \
               f"Must have GPUs to enable {args.graph_device} sampling."
        assert args.data_device == 'cpu', \
               f"Must have GPUs to enable {args.data_device} feature storage."
        run(0, 0, args, ['cpu'], data)
    elif n_gpus == 1:
        run(0, n_gpus, args, devices, data)
    else:
        procs = []
        for proc_id in range(n_gpus):
            p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
            p.join()


if __name__ == '__main__':
    argparser = argparse.ArgumentParser("multi-gpu training")
    argparser.add_argument("--gpu", type=str, default='0',
                           help="GPU, can be a list of gpus for multi-gpu training,"
                                " e.g., 0,1,2,3; -1 for CPU")
    argparser.add_argument('--dataset', type=str, default='reddit',
                           choices=('reddit', 'ogbn-products'))
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--num-negs', type=int, default=1)
    argparser.add_argument('--neg-share', default=False, action='store_true',
                           help="sharing neg nodes for positive nodes")
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=10000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=1000)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=0,
                           help="Number of sampling processes. Use 0 for no extra process.")
    argparser.add_argument('--graph-device', choices=('cpu', 'gpu', 'uva'), default='cpu',
                           help="Device to perform the sampling. "
                                "Must have 0 workers for 'gpu' and 'uva'")
    argparser.add_argument('--data-device', choices=('cpu', 'gpu', 'uva'), default='gpu',
                           help="By default the script puts all node features and labels "
                                "on GPU when using it to save time for data copy. This may "
                                "be undesired if they cannot fit in GPU memory at once. "
                                "Use 'cpu' to keep the features on host memory and "
                                "'uva' to enable UnifiedTensor (GPU zero-copy access on "
                                "pinned host memory).")
    args = argparser.parse_args()

    main(args)