"tests/remote/test_remote_decode.py" did not exist on "694f9658c1f511e323bf86cd88af0a2e2b0fee9b"
train_sampling_unsupervised.py 13 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
from dgl.data import RedditDataset
13
from torch.nn.parallel import DistributedDataParallel
14
15
16
17
import tqdm
import sklearn.linear_model as lm
import sklearn.metrics as skm

18
19
from utils import thread_wrapped_func

20
class NegativeSampler(object):
    """
    Negative sampler that pairs every positive edge with ``k`` sampled destination nodes.

    Destinations are drawn with replacement with probability proportional to
    ``in_degree ** 0.75`` of the graph passed to the constructor.
    """

    def __init__(self, g, k, neg_share=False):
        """
        g : graph whose in-degrees define the sampling weights.
        k : number of negative destinations generated per positive edge.
        neg_share : if True, every group of ``k`` consecutive positive edges shares the same
            ``k`` sampled destinations (only applied when the batch size is a multiple of ``k``).
        """
        self.weights = g.in_degrees().float() ** 0.75
        self.k = k
        self.neg_share = neg_share

    def __call__(self, g, eids):
        """
        Return ``(src, dst)`` tensors of length ``len(eids) * k`` describing the negative edges.

        ``src`` repeats the source node of every positive edge ``k`` times in a row.
        """
        src, _ = g.find_edges(eids)
        n = len(src)
        if n == 0:
            # multinomial() refuses to draw zero samples, so answer an empty batch directly.
            return src, self.weights.new_empty(0, dtype=th.int64)

        if self.neg_share and n % self.k == 0:
            # Draw n destinations, split them in n/k groups of k and replicate every group k
            # times, so that the k positive edges of a group see the same k negatives.
            dst = self.weights.multinomial(n, replacement=True)
            dst = dst.view(-1, 1, self.k).expand(-1, self.k, -1).flatten()
        else:
            dst = self.weights.multinomial(n * self.k, replacement=True)
        src = src.repeat_interleave(self.k)
        return src, dst
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59

class SAGE(nn.Module):
    """GraphSAGE encoder made of a stack of mean-aggregator ``SAGEConv`` layers."""

    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        """
        in_feats : size of the input node features.
        n_hidden : output size of every layer except the last one.
        n_classes : output size of the last layer.
        n_layers : number of SAGEConv layers.
        activation : callable applied after every layer except the last one.
        dropout : dropout probability applied after every layer except the last one.
        """
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        if n_layers > 1:
            self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
            for _ in range(1, n_layers - 1):
                self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
            self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        else:
            # A single layer maps the input straight to the output size; building the
            # input and output layers unconditionally would yield two layers here.
            self.layers.append(dglnn.SAGEConv(in_feats, n_classes, 'mean'))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        """
        Apply layer ``i`` to ``blocks[i]``; activation and dropout follow every layer but the last.

        blocks : one block per layer, in layer order.
        x : input features handed to the first layer.
        """
        h = x
        last = len(self.layers) - 1
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != last:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, device, batch_size=None, num_workers=None):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
        g : the entire graph.
        x : the input of entire node set.
        device : device the per-batch computation runs on; results are gathered on CPU.
        batch_size : number of output nodes per batch; defaults to ``args.batch_size`` of the
            module-level ``args`` created by the script entry point.
        num_workers : number of dataloader workers; defaults to ``args.num_workers`` likewise.

        The inference code is written in a fashion that it could handle any number of nodes and
        layers.
        """
        # During inference with sampling, multi-layer blocks are very inefficient because
        # lots of computations in the first few layers are repeated.
        # Therefore, we compute the representation of all nodes layer by layer.  The nodes
        # on each layer are of course split in batches.
        # TODO: can we standardize this?
        if batch_size is None:
            batch_size = args.batch_size
        if num_workers is None:
            num_workers = args.num_workers

        last = len(self.layers) - 1
        for l, layer in enumerate(self.layers):
            y = th.zeros(g.num_nodes(), self.n_hidden if l != last else self.n_classes)

            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
            dataloader = dgl.dataloading.NodeDataLoader(
                g,
                th.arange(g.num_nodes()),
                sampler,
                batch_size=batch_size,
                # Results are written back by node ID, so the visiting order is irrelevant.
                shuffle=False,
                drop_last=False,
                num_workers=num_workers)

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(device)

                h = x[input_nodes].to(device)
                h = layer(block, h)
                if l != last:
                    h = self.activation(h)
                    h = self.dropout(h)

                y[output_nodes] = h.cpu()

            # The output of this layer is the input of the next one.
            x = y
        return y

class CrossEntropyLoss(nn.Module):
    """Binary cross-entropy over edge scores: positive-graph edges are labelled 1, negative-graph edges 0."""

    @staticmethod
    def _edge_scores(graph, node_repr):
        # Attach the node representations temporarily and score every edge with
        # the u_dot_v message function on its two endpoints.
        with graph.local_scope():
            graph.ndata['h'] = node_repr
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            return graph.edata['score']

    def forward(self, block_outputs, pos_graph, neg_graph):
        """
        block_outputs : node representations assigned to the nodes of both graphs.
        pos_graph : graph holding the positive edges.
        neg_graph : graph holding the negative edges.
        """
        positive = self._edge_scores(pos_graph, block_outputs)
        negative = self._edge_scores(neg_graph, block_outputs)

        logits = th.cat([positive, negative])
        targets = th.cat([th.ones_like(positive), th.zeros_like(negative)]).long()
        return F.binary_cross_entropy_with_logits(logits, targets.float())

def compute_acc(emb, labels, train_nids, val_nids, test_nids):
    """
    Score the embeddings with a logistic-regression classifier.

    The embeddings are standardized per dimension, a multinomial logistic regression is
    fitted on the training nodes, and the micro-averaged F1 of its predictions is reported.

    emb : embeddings of all the nodes.
    labels : labels of all the nodes.
    train_nids, val_nids, test_nids : node IDs of the training / validation / test sets.

    Returns ``(f1_micro_eval, f1_micro_test)``.
    """
    emb = emb.cpu().numpy()
    labels = labels.cpu().numpy()
    train_nids = train_nids.cpu().numpy()
    train_labels = labels[train_nids]
    val_nids = val_nids.cpu().numpy()
    val_labels = labels[val_nids]
    test_nids = test_nids.cpu().numpy()
    test_labels = labels[test_nids]

    # A constant dimension has zero standard deviation; dividing by it would fill the
    # column with NaN and make the classifier fail, so such columns are only centered.
    std = emb.std(0, keepdims=True)
    std[std == 0] = 1.0
    emb = (emb - emb.mean(0, keepdims=True)) / std

    lr = lm.LogisticRegression(multi_class='multinomial', max_iter=10000)
    lr.fit(emb[train_nids], train_labels)

    pred = lr.predict(emb)
    f1_micro_eval = skm.f1_score(val_labels, pred[val_nids], average='micro')
    f1_micro_test = skm.f1_score(test_labels, pred[test_nids], average='micro')
    return f1_micro_eval, f1_micro_test

146
def evaluate(model, g, nfeat, labels, train_nids, val_nids, test_nids, device):
    """
    Compute embeddings for all the nodes and score them with ``compute_acc``.

    model : a SAGE model, or a wrapper exposing the SAGE model as ``model.module``.
    g : The entire graph.
    nfeat : The features of all the nodes.
    labels : The labels of all the nodes.
    train_nids, val_nids, test_nids : node IDs of the training / validation / test sets.
    device : The device to run inference on.

    Returns whatever ``compute_acc`` returns for the inferred embeddings.
    """
    # single gpu: the model itself; multi gpu: the wrapped model lives in ``.module``.
    net = getattr(model, 'module', model)

    was_training = model.training
    model.eval()
    try:
        with th.no_grad():
            pred = net.inference(g, nfeat, device)
    finally:
        # Put the model back in its previous mode even when inference raises.
        model.train(was_training)
    return compute_acc(pred, labels, train_nids, val_nids, test_nids)

#### Entry point
167
def run(proc_id, n_gpus, args, devices, data):
    """
    Train the unsupervised GraphSAGE model in one process.

    proc_id : index of this process; selects ``devices[proc_id]`` and is the distributed rank.
    n_gpus : number of GPU processes; 0 means training on CPU.
    args : parsed command-line arguments.
    devices : one device per process.
    data : tuple ``(train_mask, val_mask, test_mask, n_classes, g)``.
    """
    device = devices[proc_id]
    if n_gpus > 0:
        # Bind the process to its GPU so that collectives and the memory statistics
        # printed below refer to this device rather than to the default one.
        th.cuda.set_device(device)
    if n_gpus > 1:
        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
            master_ip='127.0.0.1', master_port='12345')
        world_size = n_gpus
        th.distributed.init_process_group(backend="nccl",
                                          init_method=dist_init_method,
                                          world_size=world_size,
                                          rank=proc_id)

    # Unpack data
    train_mask, val_mask, test_mask, n_classes, g = data
    nfeat = g.ndata.pop('feat')
    labels = g.ndata.pop('label')
    in_feats = nfeat.shape[1]

    # IDs of the nodes selected by each mask, always as 1-D tensors.
    train_nid = th.nonzero(th.as_tensor(train_mask), as_tuple=True)[0]
    val_nid = th.nonzero(th.as_tensor(val_mask), as_tuple=True)[0]
    test_nid = th.nonzero(th.as_tensor(test_mask), as_tuple=True)[0]

    # Create PyTorch DataLoader for constructing blocks
    n_edges = g.num_edges()
    train_seeds = np.arange(n_edges)
    if n_gpus > 0:
        # Every process trains on its own contiguous slice of the edge IDs; slicing
        # clamps the end of the last slice to the number of edges.
        num_per_gpu = (train_seeds.shape[0] + n_gpus - 1) // n_gpus
        train_seeds = train_seeds[proc_id * num_per_gpu:(proc_id + 1) * num_per_gpu]

    # Create sampler
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(',')])
    dataloader = dgl.dataloading.EdgeDataLoader(
        g, train_seeds, sampler, exclude='reverse_id',
        # For each edge with ID e in Reddit dataset, the reverse edge is e ± |E|/2.
        reverse_eids=th.cat([
            th.arange(n_edges // 2, n_edges),
            th.arange(0, n_edges // 2)]),
        # Forward --neg-share, which is otherwise parsed but never used.
        negative_sampler=NegativeSampler(g, args.num_negs, args.neg_share),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        pin_memory=True,
        num_workers=args.num_workers)

    # Define model and optimizer
    model = SAGE(in_feats, args.num_hidden, args.num_hidden, args.num_layers, F.relu, args.dropout)
    model = model.to(device)
    if n_gpus > 1:
        model = DistributedDataParallel(model, device_ids=[device], output_device=device)
    loss_fcn = CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    avg = 0
    timed_epochs = 0
    iter_pos = []
    iter_neg = []
    iter_d = []
    iter_t = []
    best_eval_acc = 0
    best_test_acc = 0
    for epoch in range(args.num_epochs):
        tic = time.time()

        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
        tic_step = time.time()
        for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(dataloader):
            batch_inputs = nfeat[input_nodes].to(device)
            d_step = time.time()

            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device)
            blocks = [block.int().to(device) for block in blocks]
            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, pos_graph, neg_graph)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t = time.time()
            pos_edges = pos_graph.num_edges()
            neg_edges = neg_graph.num_edges()
            iter_pos.append(pos_edges / (t - tic_step))
            iter_neg.append(neg_edges / (t - tic_step))
            iter_d.append(d_step - tic_step)
            iter_t.append(t - d_step)
            if step % args.log_every == 0:
                gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
                # The first 3 iterations are left out of the running means.
                print('[{}]Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f}|{:.4f} | Load {:.4f}| train {:.4f} | GPU {:.1f} MB'.format(
                    proc_id, epoch, step, loss.item(), np.mean(iter_pos[3:]), np.mean(iter_neg[3:]), np.mean(iter_d[3:]), np.mean(iter_t[3:]), gpu_mem_alloc))

            if step % args.eval_every == 0 and proc_id == 0:
                eval_acc, test_acc = evaluate(model, g, nfeat, labels, train_nid, val_nid, test_nid, device)
                print('Eval Acc {:.4f} Test Acc {:.4f}'.format(eval_acc, test_acc))
                if eval_acc > best_eval_acc:
                    best_eval_acc = eval_acc
                    best_test_acc = test_acc
                print('Best Eval Acc {:.4f} Test Acc {:.4f}'.format(best_eval_acc, best_test_acc))

            # Restart the step timer only after logging and evaluation, so that their
            # duration is not accounted as data-loading time of the next step.
            tic_step = time.time()

        toc = time.time()
        if proc_id == 0:
            print('Epoch Time(s): {:.4f}'.format(toc - tic))
        if epoch >= 5:
            # The first 5 epochs are left out of the average.
            avg += toc - tic
            timed_epochs += 1
        if n_gpus > 1:
            th.distributed.barrier()

    if proc_id == 0:
        if timed_epochs > 0:
            print('Avg epoch time: {}'.format(avg / timed_epochs))
        else:
            # Happens with 5 epochs or fewer: no epoch has been timed.
            print('Avg epoch time: n/a (fewer than 6 epochs were run)')
280

281
282
def main(args, devices):
    """
    Load the Reddit dataset and launch training on the requested devices.

    args : parsed command-line arguments, handed over to ``run``.
    devices : list of GPU IDs; a first entry of -1 selects CPU training.
    """
    # load reddit data
    data = RedditDataset(self_loop=False)
    n_classes = data.num_classes
    g = data[0]
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']

    # Create csr/coo/csc formats before launching training processes with multi-gpu.
    # This avoids creating certain formats in each sub-process, which saves memory and CPU.
    g.create_formats_()
    # Pack data
    data = train_mask, val_mask, test_mask, n_classes, g

    n_gpus = len(devices)
    if devices[0] == -1:
        # CPU training runs in the current process.
        run(0, 0, args, ['cpu'], data)
    elif n_gpus == 1:
        run(0, n_gpus, args, devices, data)
    else:
        # One training process per GPU; wait for all of them to finish.
        procs = []
        for proc_id in range(n_gpus):
            p = mp.Process(target=thread_wrapped_func(run),
                           args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
            p.join()


312
313
if __name__ == '__main__':
    # Command-line entry point: parse the options, then train on the listed devices.
    argparser = argparse.ArgumentParser("multi-gpu training")
    argparser.add_argument("--gpu", type=str, default='0',
                           help="GPU, can be a list of gpus for multi-gpu trianing,"
                                " e.g., 0,1,2,3; -1 for CPU")
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--num-negs', type=int, default=1)
    argparser.add_argument('--neg-share', default=False, action='store_true',
                           help="sharing neg nodes for positive nodes")
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=10000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=1000)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=0,
                           help="Number of sampling processes. Use 0 for no extra process.")
    args = argparser.parse_args()

    # "--gpu 0,1,2" -> [0, 1, 2]
    devices = list(map(int, args.gpu.split(',')))

    main(args, devices)