gcn_ns_sampler.py 2.93 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import argparse, time, math
import numpy as np
import mxnet as mx
from mxnet import gluon
from functools import partial
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgl.contrib.sampling import SamplerPool
import time

class MySamplerPool(SamplerPool):
    def worker(self, args):
        """User-defined worker function
        """
        # Start sender
        sender = dgl.contrib.sampling.SamplerSender(ip=args.ip, port=args.port)

        # load and preprocess dataset
        data = load_data(args)

        ctx = mx.cpu()

        if args.self_loop and not args.dataset.startswith('reddit'):
            data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])

        train_nid = mx.nd.array(np.nonzero(data.train_mask)[0]).astype(np.int64).as_in_context(ctx)
        test_nid = mx.nd.array(np.nonzero(data.test_mask)[0]).astype(np.int64).as_in_context(ctx)

        # create GCN model
        g = DGLGraph(data.graph, readonly=True)

        for epoch in range(args.n_epochs):
            # Here we onlt send nodeflow for training
            idx = 0
            for nf in dgl.contrib.sampling.NeighborSampler(g, args.batch_size,
                                                           args.num_neighbors,
                                                           neighbor_type='in',
                                                           shuffle=True,
                                                           num_hops=args.n_layers+1,
                                                           seed_nodes=train_nid):
                print("send train nodeflow: %d" %(idx))
                sender.send(nf)
                idx += 1
        

def main(args):
    pool = MySamplerPool()
    pool.start(args.num_sender, args)
 
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--n-epochs", type=int, default=200,
            help="number of training epochs")
    parser.add_argument("--batch-size", type=int, default=1000,
            help="batch size")
    parser.add_argument("--test-batch-size", type=int, default=1000,
            help="test batch size")
    parser.add_argument("--num-neighbors", type=int, default=3,
            help="number of neighbors to be sampled")
    parser.add_argument("--self-loop", action='store_true',
            help="graph self-loop (default=False)")
    parser.add_argument("--n-layers", type=int, default=1,
            help="number of hidden gcn layers")
    parser.add_argument("--ip", type=str, default='127.0.0.1',
            help="ip address of remote trainer machine")
    parser.add_argument("--port", type=int, default=2049,
            help="listen port of remote trainer machine")
    parser.add_argument("--num-sender", type=int, default=1,
            help="total number of sampler sender")
    args = parser.parse_args()

    print(args)

    main(args)