# train_sampling_unsupervised.py
1
import os
2
3
4
5
6
7
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
8
import torch.multiprocessing as mp
9
10
11
12
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
13
from torch.nn.parallel import DistributedDataParallel
14

15
16
from model import SAGE, compute_acc_unsupervised as compute_acc
from negative_sampler import NegativeSampler
17
import sys
18
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
19
from load_graph import load_reddit, load_ogb
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36

class CrossEntropyLoss(nn.Module):
    """Binary cross-entropy loss for unsupervised link prediction.

    Edges of ``pos_graph`` are positive examples (target 1) and edges of
    ``neg_graph`` are negative examples (target 0). Each edge is scored by the
    dot product of the representations of its two endpoint nodes.
    """

    def forward(self, block_outputs, pos_graph, neg_graph):
        """Return the mean BCE-with-logits loss over positive and negative edges.

        block_outputs : Node representations produced by the model. They are set
            as ``ndata['h']`` on both graphs, so they must line up with the nodes
            of ``pos_graph`` and ``neg_graph``.
        pos_graph : Graph whose edges are the positive pairs.
        neg_graph : Graph whose edges are the negative pairs.
        """
        pos_score = self._edge_scores(pos_graph, block_outputs)
        neg_score = self._edge_scores(neg_graph, block_outputs)

        score = th.cat([pos_score, neg_score])
        # Build the floating-point targets directly. A long() -> float() round
        # trip would only allocate an extra tensor.
        label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)])
        return F.binary_cross_entropy_with_logits(score, label)

    @staticmethod
    def _edge_scores(graph, h):
        """Score every edge of ``graph`` as the dot product of its endpoint features."""
        # local_scope() keeps the temporary 'h' / 'score' fields off the caller's graph.
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            return graph.edata['score']

37
def evaluate(model, g, nfeat, labels, train_nids, val_nids, test_nids, device, args):
    """
    Compute full-graph node representations with the model and score them with ``compute_acc``.

    model : Either a bare ``SAGE`` model or a wrapper (e.g. DistributedDataParallel)
        exposing it as ``.module``.
    g : The entire graph.
    nfeat : The input features of all the nodes.
    labels : The labels of all the nodes.
    train_nids, val_nids, test_nids : Node IDs of the train / validation / test
        splits, forwarded unchanged to ``compute_acc``.
    device : The device to run inference on.
    args : Parsed command-line arguments; ``batch_size`` and ``num_workers`` are read.

    Returns whatever ``compute_acc`` returns for these inputs (the training loop
    unpacks it as ``(eval_acc, test_acc)``).
    """
    # Single GPU / CPU: the model is a bare SAGE. Multi-GPU: it is wrapped and
    # the underlying SAGE is reachable via ``.module``.
    net = model if isinstance(model, SAGE) else model.module

    model.eval()
    try:
        with th.no_grad():
            pred = net.inference(g, nfeat, device, args.batch_size, args.num_workers)
    finally:
        # Always restore training mode, even if inference raises.
        model.train()
    return compute_acc(pred, labels, train_nids, val_nids, test_nids)

#### Per-process training entry point (called by main directly or via mp.spawn)
58
def run(proc_id, n_gpus, args, devices, data):
    """Train the unsupervised GraphSAGE model in one process.

    Called directly for CPU / single-GPU training, or once per GPU through
    ``mp.spawn`` for multi-GPU training.

    proc_id : Index of this process. It is also its distributed rank and an
        index into ``devices``.
    n_gpus : Number of GPUs in use (0 for CPU). Values above 1 enable NCCL
        process-group initialisation, DistributedDataParallel and a per-epoch barrier.
    args : Parsed command-line arguments.
    devices : Device identifiers accepted by ``th.device`` (GPU ordinals or ``'cpu'``).
    data : Tuple ``(train_nid, val_nid, test_nid, n_classes, g, nfeat, labels)``.
    """
    device = th.device(devices[proc_id])
    if n_gpus > 0:
        th.cuda.set_device(device)
    if n_gpus > 1:
        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
            master_ip='127.0.0.1', master_port='12345')
        world_size = n_gpus
        th.distributed.init_process_group(backend="nccl",
                                          init_method=dist_init_method,
                                          world_size=world_size,
                                          rank=proc_id)

    # Unpack data
    train_nid, val_nid, test_nid, n_classes, g, nfeat, labels = data

    # Place node features according to --data-device.
    if args.data_device == 'gpu':
        nfeat = nfeat.to(device)
    elif args.data_device == 'uva':
        # GPU zero-copy access on pinned host memory.
        nfeat = dgl.contrib.UnifiedTensor(nfeat, device=device)
    in_feats = nfeat.shape[1]

    # Every edge of the graph is a training seed.
    n_edges = g.num_edges()
    train_seeds = th.arange(n_edges)

    # Place the graph according to --graph-device. GPU and UVA sampling
    # require 0 sampling workers.
    if args.graph_device == 'gpu':
        train_seeds = train_seeds.to(device)
        g = g.to(device)
        args.num_workers = 0
    elif args.graph_device == 'uva':
        train_seeds = train_seeds.to(device)
        g.pin_memory_()
        args.num_workers = 0

    # Create sampler
    sampler = dgl.dataloading.NeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(',')])
    sampler = dgl.dataloading.as_edge_prediction_sampler(
        sampler, exclude='reverse_id',
        # For each edge with ID e in Reddit dataset, the reverse edge is e ± |E|/2.
        # NOTE(review): the same edge layout is assumed for ogbn-products; confirm in load_ogb.
        reverse_eids=th.cat([
            th.arange(n_edges // 2, n_edges),
            th.arange(0, n_edges // 2)]).to(train_seeds),
        negative_sampler=NegativeSampler(g, args.num_negs, args.neg_share,
                                         device if args.graph_device == 'uva' else None))
    dataloader = dgl.dataloading.EdgeDataLoader(
        g, train_seeds, sampler,
        device=device,
        use_ddp=n_gpus > 1,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers,
        use_uva=args.graph_device == 'uva')

    # Define model and optimizer
    model = SAGE(in_feats, args.num_hidden, args.num_hidden, args.num_layers, F.relu, args.dropout)
    model = model.to(device)
    if n_gpus > 1:
        model = DistributedDataParallel(model, device_ids=[device], output_device=device)
    loss_fcn = CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    avg = 0
    n_timed_epochs = 0  # number of epoch durations accumulated into ``avg``
    # Per-step throughput / timing history for the progress log. These lists are
    # never reset, so the logged means cover every step so far.
    iter_pos = []
    iter_neg = []
    iter_d = []
    iter_t = []
    best_eval_acc = 0
    best_test_acc = 0
    for epoch in range(args.num_epochs):
        tic = time.time()

        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
        tic_step = time.time()
        for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(dataloader):
            input_nodes = input_nodes.to(nfeat.device)
            batch_inputs = nfeat[input_nodes].to(device)
            blocks = [block.int() for block in blocks]
            d_step = time.time()  # end of data loading for this step

            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, pos_graph, neg_graph)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t = time.time()
            pos_edges = pos_graph.num_edges()
            neg_edges = neg_graph.num_edges()
            iter_pos.append(pos_edges / (t - tic_step))
            iter_neg.append(neg_edges / (t - tic_step))
            iter_d.append(d_step - tic_step)
            iter_t.append(t - d_step)
            if step % args.log_every == 0 and proc_id == 0:
                gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
                # The first 3 recorded steps are left out of the means (presumably as warm-up).
                print('[{}]Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f}|{:.4f} | Load {:.4f}| train {:.4f} | GPU {:.1f} MB'.format(
                    proc_id, epoch, step, loss.item(), np.mean(iter_pos[3:]), np.mean(iter_neg[3:]), np.mean(iter_d[3:]), np.mean(iter_t[3:]), gpu_mem_alloc))
            tic_step = time.time()

        toc = time.time()
        if proc_id == 0:
            print('Epoch Time(s): {:.4f}'.format(toc - tic))
            # Epochs 0-4 are excluded from the average epoch time.
            if epoch >= 5:
                avg += toc - tic
                n_timed_epochs += 1
            if (epoch + 1) % args.eval_every == 0:
                eval_acc, test_acc = evaluate(model, g, nfeat, labels, train_nid, val_nid, test_nid, device, args)
                print('Eval Acc {:.4f} Test Acc {:.4f}'.format(eval_acc, test_acc))
                if eval_acc > best_eval_acc:
                    best_eval_acc = eval_acc
                    best_test_acc = test_acc
                print('Best Eval Acc {:.4f} Test Acc {:.4f}'.format(best_eval_acc, best_test_acc))

        if n_gpus > 1:
            th.distributed.barrier()

    if proc_id == 0:
        # Divide by the number of epochs actually timed. With --num-epochs <= 5
        # no epoch is timed, so there is nothing to average. (Dividing by
        # ``epoch - 4`` would divide by zero or a negative number there, and
        # ``epoch`` is unbound when --num-epochs is 0.)
        if n_timed_epochs > 0:
            print('Avg epoch time: {}'.format(avg / n_timed_epochs))
179

180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
def main(args):
    """Load the dataset, prepare the shared data and launch training.

    ``--gpu -1`` trains on the CPU in this process. A single GPU id trains on
    that GPU in this process. A comma-separated list of ids spawns one
    training process per GPU.

    Raises ValueError for an unknown dataset, or when GPU sampling / feature
    storage is requested for a CPU-only run.
    """
    devices = list(map(int, args.gpu.split(',')))
    n_gpus = len(devices)

    # load dataset
    if args.dataset == 'reddit':
        g, n_classes = load_reddit(self_loop=False)
    elif args.dataset == 'ogbn-products':
        g, n_classes = load_ogb('ogbn-products')
    else:
        raise ValueError('unknown dataset')

    # Turn the 0/1 split masks into node-ID tensors, removing them from the graph.
    train_nid = g.ndata.pop('train_mask').nonzero().squeeze()
    val_nid = g.ndata.pop('val_mask').nonzero().squeeze()
    test_nid = g.ndata.pop('test_mask').nonzero().squeeze()

    nfeat = g.ndata.pop('features')
    labels = g.ndata.pop('labels')

    # Create csr/coo/csc formats before launching training processes with multi-gpu.
    # This avoids creating certain formats in each sub-process, which saves memory and CPU.
    g.create_formats_()

    # this to avoid competition overhead on machines with many cores.
    # Change it to a proper number on your machine, especially for multi-GPU training.
    # Clamp to at least 1: on small machines the integer division can reach 0.
    os.environ['OMP_NUM_THREADS'] = str(max(1, mp.cpu_count() // 2 // n_gpus))

    # Pack data
    data = train_nid, val_nid, test_nid, n_classes, g, nfeat, labels

    if devices[0] == -1:
        # Explicit checks rather than ``assert``, which ``python -O`` strips.
        if args.graph_device != 'cpu':
            raise ValueError(f"Must have GPUs to enable {args.graph_device} sampling.")
        if args.data_device != 'cpu':
            raise ValueError(f"Must have GPUs to enable {args.data_device} feature storage.")
        run(0, 0, args, ['cpu'], data)
    elif n_gpus == 1:
        run(0, n_gpus, args, devices, data)
    else:
        mp.spawn(run, args=(n_gpus, args, devices, data), nprocs=n_gpus)
220
221


222
223
if __name__ == '__main__':
    # The first positional argument of ArgumentParser is ``prog`` (the program
    # name shown in usage), so pass this text as ``description``.
    argparser = argparse.ArgumentParser(description="multi-gpu training")
    argparser.add_argument("--gpu", type=str, default='0',
                           help="GPU, can be a list of gpus for multi-gpu training,"
                                " e.g., 0,1,2,3; -1 for CPU")
    argparser.add_argument('--dataset', type=str, default='reddit',
                           choices=('reddit', 'ogbn-products'))
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--num-negs', type=int, default=1)
    argparser.add_argument('--neg-share', default=False, action='store_true',
                           help="sharing neg nodes for positive nodes")
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=10000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=5)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=0,
                           help="Number of sampling processes. Use 0 for no extra process.")
    argparser.add_argument('--graph-device', choices=('cpu', 'gpu', 'uva'), default='cpu',
                           help="Device to perform the sampling. "
                                "Must have 0 workers for 'gpu' and 'uva'")
    argparser.add_argument('--data-device', choices=('cpu', 'gpu', 'uva'), default='gpu',
                           help="By default the script puts all node features and labels "
                                "on GPU when using it to save time for data copy. This may "
                                "be undesired if they cannot fit in GPU memory at once. "
                                "Use 'cpu' to keep the features on host memory and "
                                "'uva' to enable UnifiedTensor (GPU zero-copy access on "
                                "pinned host memory).")
    args = argparser.parse_args()

    main(args)