# train_sampling_unsupervised.py
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
from dgl.data import RedditDataset
from torch.nn.parallel import DistributedDataParallel
import tqdm

from utils import thread_wrapped_func
from model import SAGE, compute_acc_unsupervised as compute_acc
from negative_sampler import NegativeSampler

class CrossEntropyLoss(nn.Module):
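    """Unsupervised GraphSAGE loss.

    Scores each edge by the dot product of its endpoint embeddings and applies
    binary cross-entropy, labeling the edges of ``pos_graph`` as 1 and the
    negatively sampled edges of ``neg_graph`` as 0.
    """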
    def forward(self, block_outputs, pos_graph, neg_graph):
        with pos_graph.local_scope():
            pos_graph.ndata['h'] = block_outputs
            pos_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            pos_score = pos_graph.edata['score']
        with neg_graph.local_scope():
            neg_graph.ndata['h'] = block_outputs
            neg_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            neg_score = neg_graph.edata['score']

        score = th.cat([pos_score, neg_score])
        label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)]).long()
        loss = F.binary_cross_entropy_with_logits(score, label.float())
        return loss

def evaluate(model, g, nfeat, labels, train_nids, val_nids, test_nids, device):
    """
    Evaluate the model on the train/validation/test node sets.

    g : The entire graph.
    nfeat : The features of all the nodes.
    labels : The labels of all the nodes.
    train_nids, val_nids, test_nids : Node IDs of the train/validation/test sets.
    device : The GPU device to evaluate on.

    Relies on the module-level ``args`` for the inference batch size and number of workers.
    """
    model.eval()
    with th.no_grad():
        # single gpu
        if isinstance(model, SAGE):
            pred = model.inference(g, nfeat, device, args.batch_size, args.num_workers)
        # multi gpu
        else:
            pred = model.module.inference(g, nfeat, device, args.batch_size, args.num_workers)
    model.train()
    return compute_acc(pred, labels, train_nids, val_nids, test_nids)

#### Entry point
def run(proc_id, n_gpus, args, devices, data):
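    """Training loop executed by each process.

    ``proc_id`` selects this process's device from ``devices``; with more than
    one GPU, each process joins a NCCL process group and trains on its own
    shard of the edge seeds.
    """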
    # Pick this process's device; with multiple GPUs, also join the NCCL
    # process group used for distributed training.
    device = devices[proc_id]
    if n_gpus > 1:
        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
            master_ip='127.0.0.1', master_port='12345')
        world_size = n_gpus
        th.distributed.init_process_group(backend="nccl",
                                          init_method=dist_init_method,
                                          world_size=world_size,
                                          rank=proc_id)

    # Unpack data
    train_mask, val_mask, test_mask, n_classes, g = data
    nfeat = g.ndata.pop('feat')
    labels = g.ndata.pop('label')
    in_feats = nfeat.shape[1]

    train_nid = th.nonzero(train_mask, as_tuple=True)[0]
    val_nid = th.nonzero(val_mask, as_tuple=True)[0]
    test_nid = th.nonzero(test_mask, as_tuple=True)[0]

    # Create PyTorch DataLoader for constructing blocks
    n_edges = g.num_edges()
    train_seeds = np.arange(n_edges)
    if n_gpus > 0:
        # Shard the edge seeds so each process trains on a disjoint slice.
        num_per_gpu = (train_seeds.shape[0] + n_gpus - 1) // n_gpus
        train_seeds = train_seeds[proc_id * num_per_gpu:
                                  min((proc_id + 1) * num_per_gpu, train_seeds.shape[0])]

    # Create sampler; each entry of --fan-out is the number of neighbors
    # sampled for the corresponding GraphSAGE layer.
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(',')])
    dataloader = dgl.dataloading.EdgeDataLoader(
        g, train_seeds, sampler, exclude='reverse_id',
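        # Exclude the reverse of each sampled edge from its neighborhood so the
        # model cannot see the edge it is being trained to predict.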
        # For each edge with ID e in Reddit dataset, the reverse edge is e ± |E|/2.
        reverse_eids=th.cat([
            th.arange(n_edges // 2, n_edges),
            th.arange(0, n_edges // 2)]),
        negative_sampler=NegativeSampler(g, args.num_negs, args.neg_share),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        pin_memory=True,
        num_workers=args.num_workers)

    # Define model and optimizer
    model = SAGE(in_feats, args.num_hidden, args.num_hidden, args.num_layers, F.relu, args.dropout)
    model = model.to(device)
    if n_gpus > 1:
        model = DistributedDataParallel(model, device_ids=[device], output_device=device)
    loss_fcn = CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    avg = 0
    # Per-iteration throughput (edges/s) and data-loading / compute time stats.
    iter_pos = []
    iter_neg = []
    iter_d = []
    iter_t = []
    best_eval_acc = 0
    best_test_acc = 0
    for epoch in range(args.num_epochs):
        tic = time.time()

        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
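        # With a negative sampler attached, each iteration yields the input node
        # IDs, a positive pair graph, a negative pair graph, and the blocks (MFGs)
        # needed to compute embeddings for the seed nodes.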

        tic_step = time.time()
        for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(dataloader):
            batch_inputs = nfeat[input_nodes].to(device)
            d_step = time.time()

            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device)
            blocks = [block.int().to(device) for block in blocks]
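            # Blocks are ordered from the input layer to the output layer.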
            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, pos_graph, neg_graph)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t = time.time()
            pos_edges = pos_graph.num_edges()
            neg_edges = neg_graph.num_edges()
            iter_pos.append(pos_edges / (t - tic_step))
            iter_neg.append(neg_edges / (t - tic_step))
            iter_d.append(d_step - tic_step)
            iter_t.append(t - d_step)
            if step % args.log_every == 0:
                gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
                print('[{}]Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f}|{:.4f} | Load {:.4f}| train {:.4f} | GPU {:.1f} MB'.format(
                    proc_id, epoch, step, loss.item(), np.mean(iter_pos[3:]), np.mean(iter_neg[3:]), np.mean(iter_d[3:]), np.mean(iter_t[3:]), gpu_mem_alloc))
            tic_step = time.time()

            if step % args.eval_every == 0 and proc_id == 0:
                eval_acc, test_acc = evaluate(model, g, nfeat, labels, train_nid, val_nid, test_nid, device)
                print('Eval Acc {:.4f} Test Acc {:.4f}'.format(eval_acc, test_acc))
                if eval_acc > best_eval_acc:
                    best_eval_acc = eval_acc
                    best_test_acc = test_acc
                print('Best Eval Acc {:.4f} Test Acc {:.4f}'.format(best_eval_acc, best_test_acc))
        toc = time.time()
        if proc_id == 0:
            print('Epoch Time(s): {:.4f}'.format(toc - tic))
        if epoch >= 5:
            avg += toc - tic
        if n_gpus > 1:
            th.distributed.barrier()

    if proc_id == 0:
        print('Avg epoch time: {}'.format(avg / (epoch - 4)))

def main(args, devices):
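    """Load the Reddit dataset, pack the shared data, and run training on CPU,
    a single GPU, or multiple GPUs (one process per GPU)."""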
    # load reddit data
    data = RedditDataset(self_loop=False)
    n_classes = data.num_classes
    g = data[0]
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']

    # Create csr/coo/csc formats before launching training processes with multi-gpu.
    # This avoids creating certain formats in each sub-process, which saves memory and CPU.
    g.create_formats_()
    # Pack data
    data = train_mask, val_mask, test_mask, n_classes, g

    n_gpus = len(devices)
    if devices[0] == -1:
        run(0, 0, args, ['cpu'], data)
    elif n_gpus == 1:
        run(0, n_gpus, args, devices, data)
    else:
        procs = []
        for proc_id in range(n_gpus):
            p = mp.Process(target=thread_wrapped_func(run),
                           args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
            p.join()


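# Example invocations:
#   python train_sampling_unsupervised.py --gpu 0            # single GPU
#   python train_sampling_unsupervised.py --gpu 0,1,2,3      # multi-GPU
#   python train_sampling_unsupervised.py --gpu -1            # CPU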
if __name__ == '__main__':
    argparser = argparse.ArgumentParser("multi-gpu training")
    argparser.add_argument("--gpu", type=str, default='0',
                           help="GPU, can be a list of gpus for multi-gpu training,"
                                " e.g., 0,1,2,3; -1 for CPU")
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--num-negs', type=int, default=1)
    argparser.add_argument('--neg-share', default=False, action='store_true',
                           help="share negative nodes across positive nodes")
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=10000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=1000)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=0,
                           help="Number of sampling processes. Use 0 for no extra process.")
    args = argparser.parse_args()

    devices = list(map(int, args.gpu.split(',')))

    main(args, devices)