# train_sampling_unsupervised.py
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
from _thread import start_new_thread
from functools import wraps
from dgl.data import RedditDataset
from torch.nn.parallel import DistributedDataParallel
import tqdm
import traceback
import sklearn.linear_model as lm
import sklearn.metrics as skm

from utils import thread_wrapped_func

class NegativeSampler(object):
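    """Per-edge negative sampler.

    For every positive edge it draws ``k`` negative destination nodes from a
    distribution proportional to in_degree**0.75 (as in word2vec).
    """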
    def __init__(self, g, k, neg_share=False):
        self.weights = g.in_degrees().float() ** 0.75
        self.k = k
        self.neg_share = neg_share

    def __call__(self, g, eids):
        src, _ = g.find_edges(eids)
        n = len(src)
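        # With neg_share, consecutive groups of k positive edges reuse the same k
        # negatives; otherwise each positive edge gets k independently drawn negatives.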
        if self.neg_share and n % self.k == 0:
            dst = self.weights.multinomial(n, replacement=True)
            dst = dst.view(-1, 1, self.k).expand(-1, self.k, -1).flatten()
        else:
            dst = self.weights.multinomial(n * self.k, replacement=True)
        src = src.repeat_interleave(self.k)
        return src, dst

def load_subtensor(g, input_nodes, device):
    """
    Copies the features of a set of nodes onto the given device.
    """
    batch_inputs = g.ndata['features'][input_nodes].to(device)
    return batch_inputs

class SAGE(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
        g : the entire graph.
        x : the input of entire node set.

        The inference code is written in a fashion that it could handle any number of nodes and
        layers.
        """
        # During inference with sampling, multi-layer blocks are very inefficient because
        # lots of computations in the first few layers are repeated.
        # Therefore, we compute the representation of all nodes layer by layer.  The nodes
        # on each layer are of course splitted in batches.
        # TODO: can we standardize this?
        nodes = th.arange(g.number_of_nodes())
        for l, layer in enumerate(self.layers):
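            # Output buffer for this layer over all nodes (n_hidden, or n_classes for the last layer).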
            y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)

            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
            dataloader = dgl.dataloading.NodeDataLoader(
                g,
                th.arange(g.number_of_nodes()),
                sampler,
                batch_size=batch_size,
                shuffle=True,
                drop_last=False,
                num_workers=args.num_workers)

            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(device)

                h = x[input_nodes].to(device)
                h = layer(block, h)
                if l != len(self.layers) - 1:
                    h = self.activation(h)
                    h = self.dropout(h)

                y[output_nodes] = h.cpu()

            x = y
        return y

class CrossEntropyLoss(nn.Module):
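    """Unsupervised GraphSAGE loss: score each edge by the dot product of its endpoint
    embeddings, then apply binary cross-entropy with label 1 for edges of the positive
    graph and label 0 for edges of the sampled negative graph."""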
    def forward(self, block_outputs, pos_graph, neg_graph):
        with pos_graph.local_scope():
            pos_graph.ndata['h'] = block_outputs
            pos_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            pos_score = pos_graph.edata['score']
        with neg_graph.local_scope():
            neg_graph.ndata['h'] = block_outputs
            neg_graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            neg_score = neg_graph.edata['score']

        score = th.cat([pos_score, neg_score])
        label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)]).long()
        loss = F.binary_cross_entropy_with_logits(score, label.float())
        return loss

def compute_acc(emb, labels, train_nids, val_nids, test_nids):
    """
    Fit a logistic-regression classifier on the training-node embeddings and
    return the micro-F1 scores on the validation and test nodes.
    """
    emb = emb.cpu().numpy()
    labels = labels.cpu().numpy()
    train_nids = train_nids.cpu().numpy()
    train_labels = labels[train_nids]
    val_nids = val_nids.cpu().numpy()
    val_labels = labels[val_nids]
    test_nids = test_nids.cpu().numpy()
    test_labels = labels[test_nids]

    emb = (emb - emb.mean(0, keepdims=True)) / emb.std(0, keepdims=True)

    lr = lm.LogisticRegression(multi_class='multinomial', max_iter=10000)
    lr.fit(emb[train_nids], train_labels)

    pred = lr.predict(emb)
    f1_micro_eval = skm.f1_score(val_labels, pred[val_nids], average='micro')
    f1_micro_test = skm.f1_score(test_labels, pred[test_nids], average='micro')
    return f1_micro_eval, f1_micro_test

def evaluate(model, g, inputs, labels, train_nids, val_nids, test_nids, batch_size, device):
    """
    Evaluate the model by computing full-graph node embeddings and fitting a
    logistic-regression classifier on top of them.
    g : The entire graph.
    inputs : The features of all the nodes.
    labels : The labels of all the nodes.
    train_nids, val_nids, test_nids : Node IDs of the train/validation/test splits.
    batch_size : Number of nodes to compute at the same time.
    device : The GPU device to evaluate on.
    """
    model.eval()
    with th.no_grad():
        # single gpu
        if isinstance(model, SAGE):
            pred = model.inference(g, inputs, batch_size, device)
        # multi gpu
        else:
            pred = model.module.inference(g, inputs, batch_size, device)
    model.train()
    return compute_acc(pred, labels, train_nids, val_nids, test_nids)

#### Entry point
def run(proc_id, n_gpus, args, devices, data):
    # Unpack data
    device = devices[proc_id]
    if n_gpus > 1:
        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
            master_ip='127.0.0.1', master_port='12345')
        world_size = n_gpus
        th.distributed.init_process_group(backend="nccl",
                                          init_method=dist_init_method,
                                          world_size=world_size,
                                          rank=proc_id)
    train_mask, val_mask, test_mask, in_feats, labels, n_classes, g = data

    train_nid = th.LongTensor(np.nonzero(train_mask)).squeeze()
    val_nid = th.LongTensor(np.nonzero(val_mask)).squeeze()
    test_nid = th.LongTensor(np.nonzero(test_mask)).squeeze()

    #train_nid = th.LongTensor(np.nonzero(train_mask)[0])
    #val_nid = th.LongTensor(np.nonzero(val_mask)[0])
    #test_nid = th.LongTensor(np.nonzero(test_mask)[0])

    # Create PyTorch DataLoader for constructing blocks
    n_edges = g.number_of_edges()
    train_seeds = np.arange(n_edges)
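    # In the multi-GPU case, each process trains on its own contiguous slice of the edges.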
    if n_gpus > 0:
        num_per_gpu = (train_seeds.shape[0] + n_gpus -1) // n_gpus
        train_seeds = train_seeds[proc_id * num_per_gpu :
                                  (proc_id + 1) * num_per_gpu \
                                  if (proc_id + 1) * num_per_gpu < train_seeds.shape[0]
                                  else train_seeds.shape[0]]

    # Create sampler
    sampler = dgl.dataloading.MultiLayerNeighborSampler(
        [int(fanout) for fanout in args.fan_out.split(',')])
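    # With a negative_sampler, EdgeDataLoader yields (input_nodes, pos_graph, neg_graph, blocks)
    # for each batch; exclude='reverse_id' also drops the seed edges and their reverses
    # from the sampled blocks to avoid information leakage.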
    dataloader = dgl.dataloading.EdgeDataLoader(
        g, train_seeds, sampler, exclude='reverse_id',
        # For each edge with ID e in Reddit dataset, the reverse edge is e ± |E|/2.
        reverse_eids=th.cat([
            th.arange(n_edges // 2, n_edges),
            th.arange(0, n_edges // 2)]),
        negative_sampler=NegativeSampler(g, args.num_negs),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        pin_memory=True,
        num_workers=args.num_workers)

    # Define model and optimizer
    model = SAGE(in_feats, args.num_hidden, args.num_hidden, args.num_layers, F.relu, args.dropout)
    model = model.to(device)
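    # With multiple GPUs, wrap the model in DDP so gradients are synchronized across processes.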
    if n_gpus > 1:
        model = DistributedDataParallel(model, device_ids=[device], output_device=device)
    loss_fcn = CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    avg = 0
    iter_pos = []
    iter_neg = []
    iter_d = []
    iter_t = []
    best_eval_acc = 0
    best_test_acc = 0
    for epoch in range(args.num_epochs):
        tic = time.time()

        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.

        tic_step = time.time()
        for step, (input_nodes, pos_graph, neg_graph, blocks) in enumerate(dataloader):
            batch_inputs = load_subtensor(g, input_nodes, device)
            d_step = time.time()

            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device)
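            # Cast the blocks to int32 IDs and move them to the target device.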
            blocks = [block.int().to(device) for block in blocks]
            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, pos_graph, neg_graph)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t = time.time()
            pos_edges = pos_graph.number_of_edges()
            neg_edges = neg_graph.number_of_edges()
            iter_pos.append(pos_edges / (t - tic_step))
            iter_neg.append(neg_edges / (t - tic_step))
            iter_d.append(d_step - tic_step)
            iter_t.append(t - d_step)
            if step % args.log_every == 0:
                gpu_mem_alloc = th.cuda.max_memory_allocated() / (1024 * 1024) if th.cuda.is_available() else 0
                print('[{}]Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f}|{:.4f} | Load {:.4f}| train {:.4f} | GPU {:.1f} MiB'.format(
                    proc_id, epoch, step, loss.item(), np.mean(iter_pos[3:]), np.mean(iter_neg[3:]), np.mean(iter_d[3:]), np.mean(iter_t[3:]), gpu_mem_alloc))
            tic_step = time.time()

            if step % args.eval_every == 0 and proc_id == 0:
                eval_acc, test_acc = evaluate(model, g, g.ndata['features'], labels, train_nid, val_nid, test_nid, args.batch_size, device)
                print('Eval Acc {:.4f} Test Acc {:.4f}'.format(eval_acc, test_acc))
                if eval_acc > best_eval_acc:
                    best_eval_acc = eval_acc
                    best_test_acc = test_acc
                print('Best Eval Acc {:.4f} Test Acc {:.4f}'.format(best_eval_acc, best_test_acc))
        toc = time.time()
        if proc_id == 0:
            print('Epoch Time(s): {:.4f}'.format(toc - tic))
        # Skip the first five epochs when accumulating the average epoch time.
        if epoch >= 5:
            avg += toc - tic

        if n_gpus > 1:
            th.distributed.barrier()
    if proc_id == 0 and epoch >= 5:
        print('Avg epoch time: {}'.format(avg / (epoch - 4)))

def main(args, devices):
    # load reddit data
    data = RedditDataset(self_loop=True)
    n_classes = data.num_classes
    g = data[0]
    features = g.ndata['feat']
    in_feats = features.shape[1]
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    g.ndata['features'] = features
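    # Materialize the graph's sparse formats up front so that forked worker
    # processes do not each rebuild them.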
    g.create_format_()
    # Pack data
    data = train_mask, val_mask, test_mask, in_feats, labels, n_classes, g

    n_gpus = len(devices)
    if devices[0] == -1:
        run(0, 0, args, ['cpu'], data)
    elif n_gpus == 1:
        run(0, n_gpus, args, devices, data)
    else:
        procs = []
        for proc_id in range(n_gpus):
            p = mp.Process(target=thread_wrapped_func(run),
                           args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
            p.join()


if __name__ == '__main__':
    argparser = argparse.ArgumentParser("multi-gpu training")
    argparser.add_argument("--gpu", type=str, default='0',
            help="GPU, can be a list of gpus for multi-gpu trianing, e.g., 0,1,2,3; -1 for CPU")
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--num-negs', type=int, default=1)
    argparser.add_argument('--neg-share', default=False, action='store_true',
        help="share negative destination nodes across positive edges in the same group")
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=10000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=1000)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=0,
        help="Number of sampling processes. Use 0 for no extra process.")
    args = argparser.parse_args()

    devices = list(map(int, args.gpu.split(',')))

    main(args, devices)
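
# Example invocations (flags as defined above; the Reddit dataset is downloaded
# automatically by dgl.data.RedditDataset on first use):
#   python train_sampling_unsupervised.py --gpu 0
#   python train_sampling_unsupervised.py --gpu 0,1,2,3 --num-workers 4
#   python train_sampling_unsupervised.py --gpu -1   # CPU-only run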