partition_graph.py
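"""Partition the ogbn-mag graph for distributed training.

A typical invocation (all flags are defined in __main__ below):

    python partition_graph.py --dataset ogbn-mag --num_parts 4 --balance_train
"""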
import dgl
import numpy as np
import torch as th
import argparse
import time

from ogb.nodeproppred import DglNodePropPredDataset

def load_ogb(dataset, global_norm):
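    """Load ogbn-mag, add reverse edge types, and return a homogeneous graph
    with train/val/test masks and per-node labels attached to ``ndata``."""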
    if dataset == 'ogbn-mag':
        dataset = DglNodePropPredDataset(name=dataset)
        split_idx = dataset.get_idx_split()
        train_idx = split_idx["train"]['paper']
        val_idx = split_idx["valid"]['paper']
        test_idx = split_idx["test"]['paper']
        hg_orig, labels = dataset[0]
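        # Add a reverse edge type ('rev-*') for every canonical edge type so
        # that messages can flow in both directions of each relation.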
        subgs = {}
        for etype in hg_orig.canonical_etypes:
            u, v = hg_orig.all_edges(etype=etype)
            subgs[etype] = (u, v)
            subgs[(etype[2], 'rev-'+etype[1], etype[0])] = (v, u)
        hg = dgl.heterograph(subgs)
        hg.nodes['paper'].data['feat'] = hg_orig.nodes['paper'].data['feat']
        paper_labels = labels['paper'].squeeze()

        num_rels = len(hg.canonical_etypes)
        num_of_ntype = len(hg.ntypes)
        num_classes = dataset.num_classes
        category = 'paper'
        print('Number of relations: {}'.format(num_rels))
        print('Number of classes: {}'.format(num_classes))
        print('Number of training nodes: {}'.format(len(train_idx)))
        print('Number of validation nodes: {}'.format(len(val_idx)))
        print('Number of test nodes: {}'.format(len(test_idx)))

        # Only 'paper' nodes carry input features in ogbn-mag.
        # Compute a per-edge-type norm (1 / destination in-degree) and store
        # it on the edges.
        if not global_norm:
            for canonical_etype in hg.canonical_etypes:
                u, v, eid = hg.all_edges(form='all', etype=canonical_etype)
                _, inverse_index, count = th.unique(v, return_inverse=True, return_counts=True)
                degrees = count[inverse_index]
                norm = th.ones(eid.shape[0]) / degrees
                norm = norm.unsqueeze(1)
                hg.edges[canonical_etype].data['norm'] = norm

        # get target category id
        category_id = len(hg.ntypes)
        for i, ntype in enumerate(hg.ntypes):
            if ntype == category:
                category_id = i

        # Carry the per-edge-type norm only when it was computed above; with
        # --global-norm the norm is computed later on the homogeneous graph.
        g = dgl.to_homogeneous(hg, edata=['norm'] if not global_norm else None)
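        # With --global-norm, compute a single 1/in-degree norm over the
        # homogeneous graph instead of one norm per edge type.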
        if global_norm:
            u, v, eid = g.all_edges(form='all')
            _, inverse_index, count = th.unique(v, return_inverse=True, return_counts=True)
            degrees = count[inverse_index]
            norm = th.ones(eid.shape[0]) / degrees
            norm = norm.unsqueeze(1)
            g.edata['norm'] = norm

        node_ids = th.arange(g.number_of_nodes())
        # find out the target node ids
        node_tids = g.ndata[dgl.NTYPE]
        loc = (node_tids == category_id)
        target_idx = node_ids[loc]
        train_idx = target_idx[train_idx]
        val_idx = target_idx[val_idx]
        test_idx = target_idx[test_idx]
        train_mask = th.zeros((g.number_of_nodes(),), dtype=th.bool)
        train_mask[train_idx] = True
        val_mask = th.zeros((g.number_of_nodes(),), dtype=th.bool)
        val_mask[val_idx] = True
        test_mask = th.zeros((g.number_of_nodes(),), dtype=th.bool)
        test_mask[test_idx] = True
        g.ndata['train_mask'] = train_mask
        g.ndata['val_mask'] = val_mask
        g.ndata['test_mask'] = test_mask

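        # Non-target nodes get the placeholder label -1; only 'paper' nodes
        # carry real labels.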
        labels = th.full((g.number_of_nodes(),), -1, dtype=paper_labels.dtype)
        labels[target_idx] = paper_labels
        g.ndata['labels'] = labels
        return g
    else:
        raise("Do not support other ogbn datasets.")

if __name__ == '__main__':
    argparser = argparse.ArgumentParser("Partition builtin graphs")
    argparser.add_argument('--dataset', type=str, default='ogbn-mag',
                           help='datasets: ogbn-mag')
    argparser.add_argument('--num_parts', type=int, default=4,
                           help='number of partitions')
    argparser.add_argument('--part_method', type=str, default='metis',
                           help='the partition method')
    argparser.add_argument('--balance_train', action='store_true',
                           help='balance the training size in each partition.')
    argparser.add_argument('--undirected', action='store_true',
                           help='turn the graph into an undirected graph '
                                '(note: currently unused by this script).')
    argparser.add_argument('--balance_edges', action='store_true',
                           help='balance the number of edges in each partition.')
    argparser.add_argument('--global-norm', default=False, action='store_true',
                           help='use a global norm instead of a per-edge-type norm')
    args = argparser.parse_args()

    start = time.time()
    g = load_ogb(args.dataset, args.global_norm)

    print('load {} takes {:.3f} seconds'.format(args.dataset, time.time() - start))
    print('|V|={}, |E|={}'.format(g.number_of_nodes(), g.number_of_edges()))
    print('train: {}, valid: {}, test: {}'.format(th.sum(g.ndata['train_mask']),
                                                  th.sum(g.ndata['val_mask']),
                                                  th.sum(g.ndata['test_mask'])))

    if args.balance_train:
        balance_ntypes = g.ndata['train_mask']
    else:
        balance_ntypes = None

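    # Partition with the chosen method (METIS by default) and write the
    # partitions plus the partition config JSON to the 'data' directory.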
    dgl.distributed.partition_graph(g, args.dataset, args.num_parts, 'data',
                                    part_method=args.part_method,
                                    balance_ntypes=balance_ntypes,
                                    balance_edges=args.balance_edges)