"docs/source/_toctree.yml" did not exist on "d05b508356914ed8a576b9ec78708cd910529d34"
sampler.py 3.27 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import argparse, time, math
import numpy as np
import mxnet as mx
from mxnet import gluon
from functools import partial
import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgl.contrib.sampling import SamplerPool
import time

class MySamplerPool(SamplerPool):
    """Sampler pool whose worker streams training NodeFlows to a receiver."""

    def worker(self, args):
        """User-defined worker function.

        Loads the dataset described by ``args`` and builds a read-only
        ``DGLGraph``. It then loops forever: each pass runs
        ``NeighborSampler`` over the training nodes, sends every NodeFlow to
        receiver ``0`` (the address in ``args.ip``), and then calls
        ``sender.signal(0)``.

        Parameters
        ----------
        args : argparse.Namespace
            Parsed command-line options. Reads ``model``, ``n_layers``,
            ``ip``, ``self_loop``, ``dataset``, ``batch_size`` and
            ``num_neighbors``, plus whatever ``load_data`` reads.

        Raises
        ------
        ValueError
            If ``args.model`` is not one of ``gcn_ns``, ``gcn_cv`` or
            ``graphsage_cv``.
        """
        is_shuffle = True
        self_loop = False

        # The number of sampling hops (and whether the sampler adds self
        # loops) is determined by the model the trainer runs.
        if args.model == "gcn_ns":
            number_hops = args.n_layers + 1
        elif args.model == "gcn_cv":
            number_hops = args.n_layers
        elif args.model == "graphsage_cv":
            # Previously assigned to a misspelled `num_hops`, which left
            # number_hops at 1 for graphsage_cv.
            number_hops = args.n_layers
            self_loop = True
        else:
            # Previously only printed and kept sampling with 1 hop, which
            # silently produced NodeFlows of the wrong depth.
            raise ValueError(
                "unknown model. Please choose from gcn_ns, gcn_cv, graphsage_cv")

        # Start sender: receiver id 0 maps to the address given by --ip.
        namebook = {0: args.ip}
        sender = dgl.contrib.sampling.SamplerSender(namebook)

        # load and preprocess dataset
        data = load_data(args)

        ctx = mx.cpu()

        # Optionally add an explicit (i, i) edge for every node. reddit
        # datasets are skipped here; presumably they are handled differently
        # upstream (TODO confirm). Assumes data.graph is a networkx-style
        # graph where len() is the node count.
        if args.self_loop and not args.dataset.startswith('reddit'):
            data.graph.add_edges_from([(i, i) for i in range(len(data.graph))])

        # Indices of the training nodes, used as sampling seeds.
        train_nid = (mx.nd.array(np.nonzero(data.train_mask)[0])
                     .astype(np.int64)
                     .as_in_context(ctx))

        # Build a read-only graph for neighbor sampling.
        g = DGLGraph(data.graph, readonly=True)

        while True:
            # One full iteration of the sampler covers every training seed
            # once; the receiver is signalled after each such pass.
            sampler = dgl.contrib.sampling.NeighborSampler(
                g, args.batch_size,
                args.num_neighbors,
                neighbor_type='in',
                shuffle=is_shuffle,
                num_workers=32,
                num_hops=number_hops,
                add_self_loop=self_loop,
                seed_nodes=train_nid)
            for idx, nf in enumerate(sampler):
                print("send train nodeflow: %d" % (idx))
                sender.send(nf, 0)
            sender.signal(0)


def main(args):
    """Start a ``MySamplerPool`` using ``args.num_sampler`` as the worker count.

    ``args`` is forwarded unchanged to ``SamplerPool.start``; presumably each
    worker then runs ``MySamplerPool.worker(args)`` (verify against DGL's
    ``SamplerPool``).
    """
    sampler_count = args.num_sampler
    MySamplerPool().start(sampler_count, args)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    # Adds the dataset-related options (e.g. --dataset) read by load_data.
    register_data_args(parser)
    # The model determines how many hops the sampler worker uses. Reject a
    # missing or misspelled model at parse time instead of inside a worker
    # process.
    parser.add_argument("--model", type=str, required=True,
                        choices=["gcn_ns", "gcn_cv", "graphsage_cv"],
                        help="select a model. Valid models: gcn_ns, gcn_cv, graphsage_cv")
    parser.add_argument("--batch-size", type=int, default=1000,
                        help="batch size")
    parser.add_argument("--num-neighbors", type=int, default=3,
                        help="number of neighbors to be sampled")
    parser.add_argument("--self-loop", action='store_true',
                        help="graph self-loop (default=False)")
    parser.add_argument("--n-layers", type=int, default=1,
                        help="number of hidden gcn layers")
    # Address of receiver 0; the default has the form host:port.
    parser.add_argument("--ip", type=str, default='127.0.0.1:50051',
                        help="IP address")
    parser.add_argument("--num-sampler", type=int, default=1,
                        help="number of sampler")
    args = parser.parse_args()

    print(args)

    main(args)