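"""run_store_server.py

Start a DGL graph store server that keeps a graph and its node data
(features, labels and train/val/test masks) in shared memory so that
trainer processes can attach to it.
"""
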
import os
import argparse, time, math
import numpy as np
from scipy import sparse as spsp
import mxnet as mx
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

class GraphData:
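    """Wrap a CSR adjacency matrix together with synthetic node data.

    Features and labels are generated randomly; the nodes are split into
    50% train, 25% validation and 25% test masks.
    """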
    def __init__(self, csr, num_feats, graph_name):
        num_nodes = csr.shape[0]
        num_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
        edge_ids = np.arange(0, num_edges, step=1, dtype=np.int64)
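        # Build a read-only graph index from the CSR matrix; "in" marks the CSR as
        # storing in-edges, and the last argument is the shared-memory path derived
        # from the graph name.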
        self.graph = dgl.graph_index.GraphIndex(multigraph=False, readonly=True)
        self.graph.from_csr_matrix(dgl.utils.toindex(csr.indptr),
                                   dgl.utils.toindex(csr.indices), "in",
                                   dgl.contrib.graph_store._get_graph_path(graph_name))
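        # Synthetic node data: Gaussian features and uniform integer labels in [0, num_labels).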
        self.features = mx.nd.random.normal(shape=(csr.shape[0], num_feats))
        self.num_labels = 10
        self.labels = mx.nd.floor(mx.nd.random.uniform(low=0, high=self.num_labels,
                                                       shape=(csr.shape[0],)))
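        # Split the nodes: first 50% train, next 25% validation, last 25% test.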
        self.train_mask = np.zeros((num_nodes,))
        self.train_mask[np.arange(0, int(num_nodes/2), dtype=np.int64)] = 1
        self.val_mask = np.zeros((num_nodes,))
        self.val_mask[np.arange(int(num_nodes/2), int(num_nodes/4*3), dtype=np.int64)] = 1
        self.test_mask = np.zeros((num_nodes,))
        self.test_mask[np.arange(int(num_nodes/4*3), int(num_nodes), dtype=np.int64)] = 1

def main(args):
    # load and preprocess dataset
    if args.graph_file != '':
        csr = mx.nd.load(args.graph_file)[0]
        # the number of edges equals the number of non-zeros in the CSR adjacency matrix
        n_edges = mx.nd.contrib.getnnz(csr).asnumpy()[0]
        graph_name = os.path.basename(args.graph_file)
        data = GraphData(csr, args.num_feats, graph_name)
        csr = None
    else:
        data = load_data(args)
        n_edges = data.graph.number_of_edges()
        graph_name = args.dataset

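    # Self-loops can only be added for the built-in datasets, whose graphs are
    # NetworkX objects; the reddit datasets are skipped.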
    if args.self_loop and not args.dataset.startswith('reddit'):
        data.graph.add_edges_from([(i,i) for i in range(len(data.graph))])

    mem_ctx = mx.cpu()

    features = mx.nd.array(data.features, ctx=mem_ctx)
    labels = mx.nd.array(data.labels, ctx=mem_ctx)
    train_mask = mx.nd.array(data.train_mask, ctx=mem_ctx)
    val_mask = mx.nd.array(data.val_mask, ctx=mem_ctx)
    test_mask = mx.nd.array(data.test_mask, ctx=mem_ctx)
    n_classes = data.num_labels

    n_train_samples = train_mask.sum().asscalar()
    n_val_samples = val_mask.sum().asscalar()
    n_test_samples = test_mask.sum().asscalar()

    print("""----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
              n_train_samples,
              n_val_samples,
              n_test_samples))

    # create the shared-memory graph store server
    print('graph name: ' + graph_name)
    g = dgl.contrib.graph_store.create_graph_store_server(data.graph, graph_name, "shared_mem",
                                                          args.num_workers, False)
    g.ndata['features'] = features
    g.ndata['labels'] = labels
    g.ndata['train_mask'] = train_mask
    g.ndata['val_mask'] = val_mask
    g.ndata['test_mask'] = test_mask
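    # run() blocks, serving graph structure and node data requests from the workers.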
    g.run()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--graph-file", type=str, default="",
            help="graph file")
    parser.add_argument("--num-feats", type=int, default=100,
            help="the number of features")
    parser.add_argument("--self-loop", action='store_true',
            help="graph self-loop (default=False)")
    parser.add_argument("--num-workers", type=int, default=1,
            help="the number of workers")
    args = parser.parse_args()

    main(args)
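
# Example usage (a sketch; the dataset and worker count are illustrative):
#   python3 run_store_server.py --dataset reddit --num-workers 4
#
# Each trainer process would then attach to the store with something like
#   g = dgl.contrib.graph_store.create_graph_from_store(graph_name, "shared_mem")
# assuming the client-side API in dgl.contrib.graph_store.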