# train_sampling_unsupervised.py
# Unsupervised GraphSAGE training on the Reddit dataset with neighbor sampling
# and negative sampling (CPU, single-GPU or multi-GPU).
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
from _thread import start_new_thread
from functools import wraps
from dgl.data import RedditDataset
16
from torch.nn.parallel import DistributedDataParallel
17
18
19
20
21
import tqdm
import traceback
import sklearn.linear_model as lm
import sklearn.metrics as skm

22
23
from utils import thread_wrapped_func

24
25
26
27
28
29
30
31
32
33
34
35
#### Negative sampler

class NegativeSampler(object):
    """Draw negative node IDs with probability proportional to in_degree ** 0.75.

    The sampling weights are computed once from ``g`` at construction time and
    reused for every call.
    """

    def __init__(self, g):
        # Unnormalized weights; ``multinomial`` normalizes them internally.
        degrees = g.in_degrees().float()
        self.weights = degrees.pow(0.75)

    def __call__(self, num_samples):
        """Return ``num_samples`` node IDs drawn with replacement."""
        return th.multinomial(self.weights, num_samples, replacement=True)

#### Neighbor sampler

class NeighborSampler(object):
    """
    Collate function that turns a mini-batch of edge IDs into training inputs.

    ``sample_blocks`` returns ``(pos_graph, neg_graph, blocks)``:
      * ``pos_graph`` has one edge per seed edge (head -> tail);
      * ``neg_graph`` has ``num_negs`` edges per seed edge (head -> negative tail);
      * ``blocks`` holds one sampled block per entry of ``fanouts``, ordered
        from the input layer to the output layer.

    g : the full graph to sample from.
    fanouts : number of neighbors sampled per node, one entry per layer.
    num_negs : number of negative tails per positive edge.
    neg_share : when True AND the batch size is divisible by ``num_negs``, each
        group of ``num_negs`` consecutive heads shares the same ``num_negs``
        negative tails, so only ``n_edges`` negatives are drawn per batch
        instead of ``n_edges * num_negs``.  Otherwise every head gets its own
        independently drawn negatives.
    """

    def __init__(self, g, fanouts, num_negs, neg_share=False):
        self.g = g
        self.fanouts = fanouts
        self.neg_sampler = NegativeSampler(g)
        self.num_negs = num_negs
        self.neg_share = neg_share

    def sample_blocks(self, seed_edges):
        n_edges = len(seed_edges)
        seed_edges = th.LongTensor(np.asarray(seed_edges))
        heads, tails = self.g.find_edges(seed_edges)

        # Each head is repeated ``num_negs`` times, once per negative tail.
        neg_heads = heads.view(-1, 1).expand(n_edges, self.num_negs).flatten()
        if self.neg_share and n_edges % self.num_negs == 0:
            # Draw ``n_edges`` negatives and give the same ``num_negs`` of them
            # to every head in a group of ``num_negs`` consecutive heads.
            n_groups = n_edges // self.num_negs
            neg_tails = self.neg_sampler(n_edges)
            neg_tails = neg_tails.view(-1, 1, self.num_negs).expand(
                n_groups, self.num_negs, self.num_negs).flatten()
        else:
            neg_tails = self.neg_sampler(self.num_negs * n_edges)

        # Maintain the correspondence between heads, tails and negative tails as two
        # graphs.
        # pos_graph contains the correspondence between each head and its positive tail.
        # neg_graph contains the correspondence between each head and its negative tails.
        # Both pos_graph and neg_graph are first constructed with the same node space as
        # the original graph.  Then they are compacted together with dgl.compact_graphs.
        pos_graph = dgl.graph((heads, tails), num_nodes=self.g.number_of_nodes())
        neg_graph = dgl.graph((neg_heads, neg_tails), num_nodes=self.g.number_of_nodes())
        pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph])

        # Endpoint pairs (both directions) of every positive and negative edge.
        # They are the same for every layer, so build them once.
        exclude_src = th.cat([heads, tails, neg_heads, neg_tails])
        exclude_dst = th.cat([tails, heads, neg_tails, neg_heads])

        # Obtain the node IDs being used in either pos_graph or neg_graph.  Since they
        # are compacted together, pos_graph and neg_graph share the same compacted node
        # space.
        seeds = pos_graph.ndata[dgl.NID]
        blocks = []
        for fanout in self.fanouts:
            # For each seed node, sample ``fanout`` neighbors.
            frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout, replace=True)
            # Remove all edges between heads and tails, as well as heads and neg_tails,
            # so the pairs being scored are not used for message passing.
            _, _, edge_ids = frontier.edge_ids(exclude_src, exclude_dst, return_uv=True)
            frontier = dgl.remove_edges(frontier, edge_ids)
            # Then we compact the frontier into a bipartite graph for message passing.
            block = dgl.to_block(frontier, seeds)
            # Pre-generate CSR format so that it can be used in training directly.
            block.in_degree(0)
            # Obtain the seed nodes for next layer.
            seeds = block.srcdata[dgl.NID]
            # Layers are sampled output-first, so prepend to keep input-first order.
            blocks.insert(0, block)

        return pos_graph, neg_graph, blocks


def load_subtensor(g, input_nodes, device):
    """
    Gather the input features of ``input_nodes`` and move them to ``device``.

    Only node features (``g.ndata['features']``) are copied; no labels are loaded.
    """
    all_features = g.ndata['features']
    return all_features[input_nodes].to(device)


class SAGE(nn.Module):
    """
    Multi-layer GraphSAGE model with mean aggregation.

    in_feats : input feature size.
    n_hidden : output size of every layer except the last.
    n_classes : output size of the last layer.
    n_layers : number of SAGEConv layers; must be >= 1.
    activation : callable applied after every layer except the last.
    dropout : dropout probability applied after every layer except the last.
    """

    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super().__init__()
        if n_layers < 1:
            raise ValueError('n_layers must be >= 1, got {}'.format(n_layers))
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        # Layer i maps dims[i] -> dims[i + 1].  With a single layer this goes
        # straight from the input size to the output size (previously
        # n_layers == 1 silently built two layers).
        dims = [in_feats] + [n_hidden] * (n_layers - 1) + [n_classes]
        self.layers = nn.ModuleList(
            [dglnn.SAGEConv(dims[i], dims[i + 1], 'mean') for i in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, x):
        """Mini-batch forward pass: ``blocks[i]`` is consumed by ``self.layers[i]``."""
        h = x
        last = len(self.layers) - 1
        for i, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if i != last:
                h = self.activation(h)
                h = self.dropout(h)
        return h

    def inference(self, g, x, batch_size, device):
        """
        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
        g : the entire graph.
        x : the input of entire node set.
        batch_size : number of destination nodes computed per step.
        device : device the layer computation runs on; results are gathered on CPU.

        The inference code is written in a fashion that it could handle any number of nodes and
        layers.
        """
        # During inference with sampling, multi-layer blocks are very inefficient because
        # lots of computations in the first few layers are repeated.
        # Therefore, we compute the representation of all nodes layer by layer.  The nodes
        # on each layer are of course split in batches.
        # TODO: can we standardize this?
        nodes = th.arange(g.number_of_nodes())
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            out_dim = self.n_classes if i == last else self.n_hidden
            y = th.zeros(g.number_of_nodes(), out_dim)

            for start in tqdm.trange(0, len(nodes), batch_size):
                end = start + batch_size
                batch_nodes = nodes[start:end]
                # Block with every in-neighbor of the batch nodes as sources.
                # NOTE(review): the block is not moved to ``device`` (same as
                # the training loop); presumably relies on the DGL version in use.
                block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
                input_nodes = block.srcdata[dgl.NID]

                h = x[input_nodes].to(device)
                h = layer(block, h)
                if i != last:
                    h = self.activation(h)
                    h = self.dropout(h)

                y[start:end] = h.cpu()

            # This layer's output is the next layer's input.
            x = y
        return y

class CrossEntropyLoss(nn.Module):
    """
    Binary cross-entropy over dot-product edge scores.

    The score of an edge u -> v is the dot product of the representations of
    u and v taken from ``block_outputs``.  Edges of ``pos_graph`` are labelled
    1 and edges of ``neg_graph`` 0; both graphs must be indexed in the same
    node space as ``block_outputs``.
    """

    def forward(self, block_outputs, pos_graph, neg_graph):
        edge_scores = []
        for graph in (pos_graph, neg_graph):
            # local_scope keeps the temporary 'h'/'score' fields off the graph.
            with graph.local_scope():
                graph.ndata['h'] = block_outputs
                graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
                edge_scores.append(graph.edata['score'])
        pos_score, neg_score = edge_scores

        logits = th.cat([pos_score, neg_score])
        targets = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)]).float()
        return F.binary_cross_entropy_with_logits(logits, targets)

def compute_acc(emb, labels, train_nids, val_nids, test_nids):
    """
    Score node embeddings with a linear probe.

    A multinomial logistic regression is fitted on the standardized embeddings
    of ``train_nids`` and evaluated on ``val_nids`` and ``test_nids``.

    emb : embeddings of all nodes (torch tensor, any device).
    labels : labels of all nodes (torch tensor, any device).
    train_nids, val_nids, test_nids : node-ID tensors of the three splits.

    Returns ``(f1_micro_val, f1_micro_test)``.  Micro-averaged F1 equals
    accuracy for single-label multi-class prediction.
    """
    # Move everything to numpy once; labels are indexed by the numpy ID arrays
    # below, which also works when ``labels`` lives on a GPU.
    emb = emb.cpu().numpy()
    labels = labels.cpu().numpy()
    train_nids = train_nids.cpu().numpy()
    val_nids = val_nids.cpu().numpy()
    test_nids = test_nids.cpu().numpy()

    # Standardize each dimension.  Constant dimensions (std == 0) are only
    # centered; dividing by 0 would produce NaNs that LogisticRegression rejects.
    std = emb.std(0, keepdims=True)
    std[std == 0] = 1
    emb = (emb - emb.mean(0, keepdims=True)) / std

    lr = lm.LogisticRegression(multi_class='multinomial', max_iter=10000)
    lr.fit(emb[train_nids], labels[train_nids])

    # Only validation and test rows are scored, so only those are predicted.
    pred_val = lr.predict(emb[val_nids])
    pred_test = lr.predict(emb[test_nids])
    f1_micro_eval = skm.f1_score(labels[val_nids], pred_val, average='micro')
    f1_micro_test = skm.f1_score(labels[test_nids], pred_test, average='micro')
    return f1_micro_eval, f1_micro_test

def evaluate(model, g, inputs, labels, train_nids, val_nids, test_nids, batch_size, device):
    """
    Evaluate the learned node embeddings with a linear probe.

    Full-neighbor inference is run on every node of ``g``, then ``compute_acc``
    fits a classifier on ``train_nids`` and scores ``val_nids`` / ``test_nids``.

    model : a ``SAGE`` model, possibly wrapped in DistributedDataParallel.
    g : The entire graph.
    inputs : The features of all the nodes.
    labels : The labels of all the nodes.
    train_nids, val_nids, test_nids : node IDs of the train/validation/test splits.
    batch_size : Number of nodes to compute at the same time.
    device : The device to run inference on.

    Returns the ``(validation, test)`` pair produced by ``compute_acc``.
    """
    # Multi-GPU training wraps the model in DDP; ``inference`` lives on the
    # wrapped module.
    net = model.module if isinstance(model, DistributedDataParallel) else model
    model.eval()
    try:
        with th.no_grad():
            pred = net.inference(g, inputs, batch_size, device)
    finally:
        # Restore training mode even if inference fails.
        model.train()
    return compute_acc(pred, labels, train_nids, val_nids, test_nids)

#### Entry point
227
def run(proc_id, n_gpus, args, devices, data):
    """
    Train the unsupervised GraphSAGE model in a single process.

    proc_id : rank of this process; only rank 0 runs evaluation.
    n_gpus : number of training processes.  0 means CPU, 1 a single GPU, and
        more than 1 enables DistributedDataParallel over NCCL on localhost.
    args : parsed command-line arguments.
    devices : devices indexed by ``proc_id``.
    data : tuple ``(train_mask, val_mask, test_mask, in_feats, labels, n_classes, g)``.
    """
    device = devices[proc_id]
    if n_gpus > 1:
        dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
            master_ip='127.0.0.1', master_port='12345')
        world_size = n_gpus
        th.distributed.init_process_group(backend="nccl",
                                          init_method=dist_init_method,
                                          world_size=world_size,
                                          rank=proc_id)

    # Unpack data
    train_mask, val_mask, test_mask, in_feats, labels, n_classes, g = data

    train_nid = th.LongTensor(np.nonzero(train_mask)[0])
    val_nid = th.LongTensor(np.nonzero(val_mask)[0])
    test_nid = th.LongTensor(np.nonzero(test_mask)[0])

    # Create sampler
    fanouts = [int(fanout) for fanout in args.fan_out.split(',')]
    sampler = NeighborSampler(g, fanouts, args.num_negs, args.neg_share)

    # Every edge is a training seed.  With one or more GPUs, each process takes a
    # contiguous share of ceil(num_edges / n_gpus) edges; slicing clamps the last one.
    train_seeds = np.arange(g.number_of_edges())
    if n_gpus > 0:
        num_per_gpu = (train_seeds.shape[0] + n_gpus - 1) // n_gpus
        train_seeds = train_seeds[proc_id * num_per_gpu:(proc_id + 1) * num_per_gpu]

    # Create PyTorch DataLoader for constructing blocks
    dataloader = DataLoader(
        dataset=train_seeds,
        batch_size=args.batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
        pin_memory=True,
        num_workers=args.num_workers)

    # Define model and optimizer.  The output size is ``num_hidden`` (node
    # embeddings, not class logits), so ``n_classes`` is not used here.
    model = SAGE(in_feats, args.num_hidden, args.num_hidden, args.num_layers, F.relu, args.dropout)
    model = model.to(device)
    if n_gpus > 1:
        model = DistributedDataParallel(model, device_ids=[device], output_device=device)
    loss_fcn = CrossEntropyLoss()
    loss_fcn = loss_fcn.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Training loop
    n_warmup_epochs = 5  # epochs excluded from the average epoch time
    avg = 0              # total time of the epochs after the warm-up ones
    iter_pos = []  # positive edges per second, per step
    iter_neg = []  # negative edges per second, per step
    iter_d = []    # data-loading time per step
    iter_t = []    # compute time per step
    best_eval_acc = 0
    best_test_acc = 0
    for epoch in range(args.num_epochs):
        tic = time.time()

        # Loop over the dataloader to sample the computation dependency graph as a list of
        # blocks.
        tic_step = time.time()
        for step, (pos_graph, neg_graph, blocks) in enumerate(dataloader):
            # The nodes for input lie at the LHS side of the first block.
            # The nodes for output lie at the RHS side of the last block.
            input_nodes = blocks[0].srcdata[dgl.NID]
            batch_inputs = load_subtensor(g, input_nodes, device)
            d_step = time.time()

            # Compute loss and prediction
            batch_pred = model(blocks, batch_inputs)
            loss = loss_fcn(batch_pred, pos_graph, neg_graph)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t = time.time()
            pos_edges = pos_graph.number_of_edges()
            neg_edges = neg_graph.number_of_edges()
            iter_pos.append(pos_edges / (t - tic_step))
            iter_neg.append(neg_edges / (t - tic_step))
            iter_d.append(d_step - tic_step)
            iter_t.append(t - d_step)
            if step % args.log_every == 0:
                # Bytes -> MiB, matching the unit printed below.
                gpu_mem_alloc = th.cuda.max_memory_allocated() / (1024 * 1024) if th.cuda.is_available() else 0
                print('[{}]Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f}|{:.4f} | Load {:.4f}| train {:.4f} | GPU {:.1f} MiB'.format(
                    proc_id, epoch, step, loss.item(), np.mean(iter_pos[3:]), np.mean(iter_neg[3:]), np.mean(iter_d[3:]), np.mean(iter_t[3:]), gpu_mem_alloc))
            tic_step = time.time()

            if step % args.eval_every == 0 and proc_id == 0:
                eval_acc, test_acc = evaluate(model, g, g.ndata['features'], labels, train_nid, val_nid, test_nid, args.batch_size, device)
                print('Eval Acc {:.4f} Test Acc {:.4f}'.format(eval_acc, test_acc))
                if eval_acc > best_eval_acc:
                    best_eval_acc = eval_acc
                    best_test_acc = test_acc
                print('Best Eval Acc {:.4f} Test Acc {:.4f}'.format(best_eval_acc, best_test_acc))

        # Previously ``avg`` was never accumulated, so the average was always 0.
        toc = time.time()
        if epoch >= n_warmup_epochs:
            avg += toc - tic
        if n_gpus > 1:
            th.distributed.barrier()

    # Guard against dividing by zero (or a negative count) for short runs.
    if args.num_epochs > n_warmup_epochs:
        print('Avg epoch time: {}'.format(avg / (args.num_epochs - n_warmup_epochs)))


def main(args, devices):
    """
    Load the Reddit dataset, build the graph and launch training.

    args : parsed command-line arguments, forwarded to ``run``.
    devices : list of GPU ids parsed from ``--gpu``.  ``[-1]`` selects CPU
        training; more than one id spawns one training process per GPU.
    """
    # load reddit data
    data = RedditDataset(self_loop=True)
    train_mask = data.train_mask
    val_mask = data.val_mask
    test_mask = data.test_mask
    features = th.Tensor(data.features)
    in_feats = features.shape[1]
    labels = th.LongTensor(data.labels)
    n_classes = data.num_labels
    # Construct graph
    g = dgl.graph(data.graph.all_edges())
    g.ndata['features'] = features
    # Pack data
    data = train_mask, val_mask, test_mask, in_feats, labels, n_classes, g

    n_gpus = len(devices)
    # The three cases are mutually exclusive.  Previously the CPU case fell
    # through into the single-GPU branch (running again with device -1), and a
    # stale trailing ``run(args, device, data)`` raised NameError.
    if devices[0] == -1:
        run(0, 0, args, ['cpu'], data)
    elif n_gpus == 1:
        run(0, n_gpus, args, devices, data)
    else:
        procs = []
        for proc_id in range(n_gpus):
            p = mp.Process(target=thread_wrapped_func(run),
                           args=(proc_id, n_gpus, args, devices, data))
            p.start()
            procs.append(p)
        for p in procs:
            p.join()


if __name__ == '__main__':
    # The first positional argument of ArgumentParser is ``prog``; pass the
    # text as the description so the usage line shows the real script name.
    argparser = argparse.ArgumentParser(description="multi-gpu training")
    argparser.add_argument("--gpu", type=str, default='0',
            help="GPU, can be a list of gpus for multi-gpu training, e.g., 0,1,2,3; -1 for CPU")
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--num-negs', type=int, default=1)
    argparser.add_argument('--neg-share', default=False, action='store_true',
        help="sharing neg nodes for positive nodes")
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=10000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=1000)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=0,
        help="Number of sampling processes. Use 0 for no extra process.")
    args = argparser.parse_args()

    # "0,1,2,3" -> [0, 1, 2, 3]; "-1" -> [-1], which selects CPU training.
    devices = list(map(int, args.gpu.split(',')))

    main(args, devices)