import argparse
import time

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from torch.nn.parallel import DistributedDataParallel

import dgl
import dgl.function as fn
import dgl.multiprocessing as mp
import dgl.nn.pytorch as dglnn
from dgl.data import RedditDataset

from model import SAGE, compute_acc_unsupervised as compute_acc
from negative_sampler import NegativeSampler

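# Unsupervised GraphSAGE loss: score each candidate edge with the dot product
# of its endpoint embeddings (fn.u_dot_v), then apply binary cross-entropy
# with observed edges labeled 1 and negatively sampled edges labeled 0.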
class CrossEntropyLoss(nn.Module):
    def forward(self, block_outputs, pos_graph, neg_graph):
        with pos_graph.local_scope():
            pos_graph.ndata['h'] = block_outputs
            pos_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            pos_score = pos_graph.edata['score']
        with neg_graph.local_scope():
            neg_graph.ndata['h'] = block_outputs
            neg_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            neg_score = neg_graph.edata['score']

        score = th.cat([pos_score, neg_score])
        label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)])
        loss = F.binary_cross_entropy_with_logits(score, label)
        return loss

def evaluate(model, g, nfeat, labels, train_nids, val_nids, test_nids, device):
    """
    Evaluate the model on the train/validation/test node sets.

    g : The entire graph.
    nfeat : The features of all the nodes.
    labels : The labels of all the nodes.
    train_nids, val_nids, test_nids : Node IDs of the train/validation/test sets.
    device : The GPU device to evaluate on.

    Uses the module-level ``args`` (parsed in ``__main__``) for the inference
    batch size and number of sampling workers.
    """
    model.eval()
    with th.no_grad():
        # single gpu
        if isinstance(model, SAGE):
            pred = model.inference(g, nfeat, device, args.batch_size, args.num_workers)
        # multi gpu
        else:
            pred = model.module.inference(g, nfeat, device, args.batch_size, args.num_workers)
    model.train()
    return compute_acc(pred, labels, train_nids, val_nids, test_nids)

#### Entry point
def run(proc_id, n_gpus, args, devices, data):
    # Set up this process's device; for multi-GPU runs, initialize
    # single-machine distributed training over a localhost TCP rendezvous.
    device = devices[proc_id]
    if n_gpus > 1:
        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
            master_ip='127.0.0.1', master_port='12345')
        world_size = n_gpus
        th.distributed.init_process_group(backend="nccl",
                                          init_method=dist_init_method,
                                          world_size=world_size,
                                          rank=proc_id)

    # Unpack data
    train_mask, val_mask, test_mask, n_classes, g = data
    nfeat = g.ndata.pop('feat')
    labels = g.ndata.pop('label')

    # Move node features and labels to the GPU up front unless --data-cpu is set.
    if not args.data_cpu:
        nfeat = nfeat.to(device)
        labels = labels.to(device)
    in_feats = nfeat.shape[1]

    # Convert the boolean masks into node ID tensors.
    train_nid = th.nonzero(train_mask, as_tuple=True)[0]
    val_nid = th.nonzero(val_mask, as_tuple=True)[0]
    test_nid = th.nonzero(test_mask, as_tuple=True)[0]

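    # Use every edge as a training seed: each seed edge becomes a positive
    # example for the link-prediction objective.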
    # Create PyTorch DataLoader for constructing blocks
    n_edges = g.num_edges()
    train_seeds = th.arange(n_edges)

    if args.sample_gpu:
        assert n_gpus > 0, "Must have GPUs to enable GPU sampling"
        train_seeds = train_seeds.to(device)
        g = g.to(device)

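    # Each minibatch from the EdgeDataLoader carries: the IDs of the input
    # nodes whose features are needed, a positive graph of the seed edges, a
    # negative graph from the negative sampler, and one message-flow-graph
    # block per GraphSAGE layer connecting the inputs to the seeds.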
    # Create sampler
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(',')])
    dataloader = dgl.dataloading.EdgeDataLoader(
        g, train_seeds, sampler, exclude='reverse_id',
        # For each edge with ID e in the Reddit dataset, the reverse edge is e ± |E|/2.
        reverse_eids=th.cat([
            th.arange(n_edges // 2, n_edges),
            th.arange(0, n_edges // 2)]).to(train_seeds),
        negative_sampler=NegativeSampler(g, args.num_negs, args.neg_share),
        device=device,
        use_ddp=n_gpus > 1,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        num_workers=args.num_workers)

    # Define model and optimizer
    model = SAGE(in_feats, args.num_hidden, args.num_hidden, args.num_layers, F.relu, args.dropout)
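    # Note: hidden and output sizes are both args.num_hidden; the model
    # produces node embeddings rather than class logits.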
    model = model.to(device)
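    # Wrap the model so gradients are averaged across the per-GPU processes.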
    if n_gpus > 1:
        model = DistributedDataParallel(model, device_ids=[device], output_device=device)
    loss_fcn = CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    avg = 0
    iter_pos = []
    iter_neg = []
    iter_d = []
    iter_t = []
    best_eval_acc = 0
    best_test_acc = 0
    for epoch in range(args.num_epochs):
        # With DDP, tell the dataloader which epoch this is so each epoch
        # gets a different shuffle across processes.
        if n_gpus > 1:
            dataloader.set_epoch(epoch)
        tic = time.time()

        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.

        tic_step = time.time()
        for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(dataloader):
            # Slice the node features for this batch and move everything to
            # the training device.
            batch_inputs = nfeat[input_nodes].to(device)
            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device)
            blocks = [block.int().to(device) for block in blocks]
            d_step = time.time()

            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, pos_graph, neg_graph)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

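            # Record per-iteration throughput (edges/sec) and the split
            # between data-loading time and compute time.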
            t = time.time()
            pos_edges = pos_graph.num_edges()
            neg_edges = neg_graph.num_edges()
            iter_pos.append(pos_edges / (t - tic_step))
            iter_neg.append(neg_edges / (t - tic_step))
            iter_d.append(d_step - tic_step)
            iter_t.append(t - d_step)
            if step % args.log_every == 0:
                gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
                # The first three iterations are excluded from the running
                # means to avoid counting warm-up overhead.
                print('[{}] Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f} | {:.4f} | Load {:.4f} | Train {:.4f} | GPU {:.1f} MB'.format(
                    proc_id, epoch, step, loss.item(), np.mean(iter_pos[3:]), np.mean(iter_neg[3:]), np.mean(iter_d[3:]), np.mean(iter_t[3:]), gpu_mem_alloc))
            tic_step = time.time()

            if step % args.eval_every == 0 and proc_id == 0:
                eval_acc, test_acc = evaluate(model, g, nfeat, labels, train_nid, val_nid, test_nid, device)
                print('Eval Acc {:.4f} Test Acc {:.4f}'.format(eval_acc, test_acc))
                # Keep the test accuracy measured at the best validation accuracy.
                if eval_acc > best_eval_acc:
                    best_eval_acc = eval_acc
                    best_test_acc = test_acc
                print('Best Eval Acc {:.4f} Test Acc {:.4f}'.format(best_eval_acc, best_test_acc))
        toc = time.time()
        if proc_id == 0:
            print('Epoch Time(s): {:.4f}'.format(toc - tic))
        # The first five epochs are treated as warm-up and excluded from the
        # average epoch time.
        if epoch >= 5:
            avg += toc - tic

        # Keep all processes in lockstep at epoch boundaries.
        if n_gpus > 1:
            th.distributed.barrier()

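    # Mean epoch time over the measured (post-warm-up) epochs; at this point
    # ``epoch`` is num_epochs - 1, so epoch - 4 epochs were timed.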
    if proc_id == 0:
        print('Avg epoch time: {}'.format(avg / (epoch - 4)))

def main(args, devices):
    # load reddit data
    data = RedditDataset(self_loop=False)
    n_classes = data.num_classes
    g = data[0]
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']

    # Create csr/coo/csc formats before launching training processes with multi-gpu.
    # This avoids creating certain formats in each sub-process, which saves memory and CPU.
    g.create_formats_()
    # Pack data
    data = train_mask, val_mask, test_mask, n_classes, g

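    # Dispatch: CPU-only, single-GPU in-process, or one spawned training
    # process per GPU.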
    n_gpus = len(devices)
    if devices[0] == -1:
        run(0, 0, args, ['cpu'], data)
    elif n_gpus == 1:
        run(0, n_gpus, args, devices, data)
    else:
        procs = []
        for proc_id in range(n_gpus):
            p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
            p.join()


if __name__ == '__main__':
    argparser = argparse.ArgumentParser("multi-gpu training")
    argparser.add_argument('--gpu', type=str, default='0',
                           help="GPU, can be a list of gpus for multi-gpu training,"
                                " e.g., 0,1,2,3; -1 for CPU")
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--num-negs', type=int, default=1)
    argparser.add_argument('--neg-share', default=False, action='store_true',
                           help="Share the same set of negative samples across "
                                "the positive node pairs.")
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=10000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=1000)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=0,
                           help="Number of sampling processes. Use 0 for no extra process.")
    argparser.add_argument('--sample-gpu', action='store_true',
                           help="Perform the sampling process on the GPU. Must have 0 workers.")
    argparser.add_argument('--data-cpu', action='store_true',
                           help="By default the script puts all node features and labels "
                                "on GPU when using it to save time for data copy. This may "
                                "be undesired if they cannot fit in GPU memory at once. "
                                "This flag disables that.")
    args = argparser.parse_args()

    devices = list(map(int, args.gpu.split(',')))

    main(args, devices)