# train_sampling_unsupervised.py
import argparse
import time

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from torch.nn.parallel import DistributedDataParallel

import dgl
import dgl.function as fn
import dgl.multiprocessing as mp
import dgl.nn.pytorch as dglnn
from dgl.data import RedditDataset

from model import SAGE, compute_acc_unsupervised as compute_acc
from negative_sampler import NegativeSampler

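# The loss below casts unsupervised GraphSAGE training as binary edge
# classification: for a positive edge (u, v) and a sampled negative (u, v'),
# the per-edge objective is, in sketch form,
#
#     loss = -log sigmoid(h_u . h_v) - log(1 - sigmoid(h_u . h_v'))
#
# which is binary cross entropy with logits over dot-product edge scores.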
class CrossEntropyLoss(nn.Module):
    """Binary cross-entropy loss over positive and negative edge scores.

    Edge scores are dot products of the two endpoint representations;
    positive edges are labeled 1 and negative edges 0.
    """
    def forward(self, block_outputs, pos_graph, neg_graph):
        with pos_graph.local_scope():
            pos_graph.ndata['h'] = block_outputs
            pos_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            pos_score = pos_graph.edata['score']
        with neg_graph.local_scope():
            neg_graph.ndata['h'] = block_outputs
            neg_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            neg_score = neg_graph.edata['score']

        score = th.cat([pos_score, neg_score])
        label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)])
        loss = F.binary_cross_entropy_with_logits(score, label)
        return loss
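# Shape note: ``fn.u_dot_v`` keeps a trailing dimension, so ``pos_score`` and
# ``neg_score`` are (E_pos, 1) and (E_neg, 1); ``block_outputs`` is already
# aligned with the compacted node IDs shared by ``pos_graph`` and ``neg_graph``.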

def evaluate(model, g, nfeat, labels, train_nids, val_nids, test_nids, device, args):
    """
    Evaluate the model on the train, validation and test sets.
    g : The entire graph.
    nfeat : The features of all the nodes.
    labels : The labels of all the nodes.
    train_nids, val_nids, test_nids : Node IDs of the train/validation/test sets.
    device : The device to evaluate on.
    args : Parsed command-line arguments (batch size, number of workers).
    """
    model.eval()
    with th.no_grad():
        # single gpu
        if isinstance(model, SAGE):
            pred = model.inference(g, nfeat, device, args.batch_size, args.num_workers)
        # multi gpu
        else:
            pred = model.module.inference(g, nfeat, device, args.batch_size, args.num_workers)
    model.train()
    return compute_acc(pred, labels, train_nids, val_nids, test_nids)

#### Entry point
def run(proc_id, n_gpus, args, devices, data):
    device = devices[proc_id]
    if n_gpus > 1:
        # Initialize the NCCL process group; each process drives one GPU.
        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
            master_ip='127.0.0.1', master_port='12345')
        world_size = n_gpus
        th.distributed.init_process_group(backend="nccl",
                                          init_method=dist_init_method,
                                          world_size=world_size,
                                          rank=proc_id)
    # Unpack data
    train_mask, val_mask, test_mask, n_classes, g = data
    nfeat = g.ndata.pop('feat')
    labels = g.ndata.pop('label')
    in_feats = nfeat.shape[1]

    train_nid = th.nonzero(train_mask, as_tuple=True)[0]
    val_nid = th.nonzero(val_mask, as_tuple=True)[0]
    test_nid = th.nonzero(test_mask, as_tuple=True)[0]

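    # Unsupervised GraphSAGE trains on edges rather than labeled nodes: every
    # edge of the graph is a positive example, so the seed set passed to the
    # dataloader below is simply all edge IDs.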
    # Create PyTorch DataLoader for constructing blocks
    n_edges = g.num_edges()
    train_seeds = th.arange(n_edges)

    # Create sampler
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(',')])
    dataloader = dgl.dataloading.EdgeDataLoader(
        g, train_seeds, sampler, exclude='reverse_id',
        # For each edge with ID e in the Reddit dataset, the reverse edge is
        # e ± |E|/2; excluding it keeps a seed edge's own reverse out of the
        # sampled neighborhood, which would otherwise leak the answer.
        reverse_eids=th.cat([
            th.arange(n_edges // 2, n_edges),
            th.arange(0, n_edges // 2)]).to(train_seeds),
        negative_sampler=NegativeSampler(g, args.num_negs, args.neg_share),
        device=device,
        use_ddp=n_gpus > 1,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers)
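    # Each minibatch yields (input_nodes, pos_graph, neg_graph, blocks):
    # ``pos_graph`` holds the sampled seed edges, ``neg_graph`` the edges drawn
    # by the negative sampler, and both are compacted to share the node set of
    # the last block's output, so one forward pass scores both.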

    # Define model and optimizer
    model = SAGE(in_feats, args.num_hidden, args.num_hidden, args.num_layers, F.relu, args.dropout)
    model = model.to(device)
    if n_gpus > 1:
        model = DistributedDataParallel(model, device_ids=[device], output_device=device)
    loss_fcn = CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    avg = 0
    iter_pos = []
    iter_neg = []
    iter_d = []
    iter_t = []
    best_eval_acc = 0
    best_test_acc = 0
    for epoch in range(args.num_epochs):
        if n_gpus > 1:
            # Make the DDP dataloader reshuffle seed edges every epoch.
            dataloader.set_epoch(epoch)
        tic = time.time()

        # Loop over the dataloader to sample the computation dependency graph
        # as a list of blocks.
        tic_step = time.time()
        for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(dataloader):
            batch_inputs = nfeat[input_nodes].to(device)
            d_step = time.time()

            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device)
            blocks = [block.int().to(device) for block in blocks]
            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, pos_graph, neg_graph)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

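            # Under DDP, ``loss.backward()`` above also averages gradients
            # across processes, so no extra synchronization is needed here.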
            t = time.time()
            pos_edges = pos_graph.num_edges()
            neg_edges = neg_graph.num_edges()
            # Per-iteration throughput and the data-loading/compute split.
            iter_pos.append(pos_edges / (t - tic_step))
            iter_neg.append(neg_edges / (t - tic_step))
            iter_d.append(d_step - tic_step)
            iter_t.append(t - d_step)
            if step % args.log_every == 0:
                gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
                # Means skip the first 3 iterations to discard warm-up noise.
                print('[{}]Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f}|{:.4f} | Load {:.4f} | Train {:.4f} | GPU {:.1f} MB'.format(
                    proc_id, epoch, step, loss.item(), np.mean(iter_pos[3:]), np.mean(iter_neg[3:]), np.mean(iter_d[3:]), np.mean(iter_t[3:]), gpu_mem_alloc))
            tic_step = time.time()

            if step % args.eval_every == 0 and proc_id == 0:
                eval_acc, test_acc = evaluate(model, g, nfeat, labels, train_nid, val_nid, test_nid, device, args)
                print('Eval Acc {:.4f} Test Acc {:.4f}'.format(eval_acc, test_acc))
                if eval_acc > best_eval_acc:
                    best_eval_acc = eval_acc
                    best_test_acc = test_acc
                print('Best Eval Acc {:.4f} Test Acc {:.4f}'.format(best_eval_acc, best_test_acc))
        toc = time.time()
        if proc_id == 0:
            print('Epoch Time(s): {:.4f}'.format(toc - tic))
        # Exclude the first 5 epochs from the timing average as warm-up.
        if epoch >= 5:
            avg += toc - tic
        if n_gpus > 1:
            th.distributed.barrier()

    if proc_id == 0:
        print('Avg epoch time: {}'.format(avg / (epoch - 4)))

def main(args, devices):
    # load reddit data
    data = RedditDataset(self_loop=False)
    n_classes = data.num_classes
    g = data[0]
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']

    # Create csr/coo/csc formats before launching training processes with multi-gpu.
    # This avoids creating certain formats in each sub-process, which saves memory and CPU.
    g.create_formats_()
    # Pack data
    data = train_mask, val_mask, test_mask, n_classes, g

    n_gpus = len(devices)
    if devices[0] == -1:
        # CPU-only training in the current process.
        run(0, 0, args, ['cpu'], data)
    elif n_gpus == 1:
        run(0, n_gpus, args, devices, data)
    else:
        # Spawn one training process per GPU.
        procs = []
        for proc_id in range(n_gpus):
            p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
            p.join()


if __name__ == '__main__':
    argparser = argparse.ArgumentParser("multi-gpu training")
    argparser.add_argument("--gpu", type=str, default='0',
                           help="GPU, can be a list of gpus for multi-gpu training,"
                                " e.g., 0,1,2,3; -1 for CPU")
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--num-negs', type=int, default=1,
                           help="Number of negative samples per positive edge.")
    argparser.add_argument('--neg-share', default=False, action='store_true',
                           help="Share negative nodes across the positive nodes in a batch.")
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=10000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=1000)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=0,
                           help="Number of sampling processes. Use 0 for no extra process.")
    args = argparser.parse_args()

    devices = list(map(int, args.gpu.split(',')))

    main(args, devices)
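
# Example invocations:
#   python train_sampling_unsupervised.py --gpu 0          # single GPU
#   python train_sampling_unsupervised.py --gpu 0,1,2,3    # one process per GPU
#   python train_sampling_unsupervised.py --gpu -1         # CPU only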