"""Partition the ogbn-mag graph for DGL distributed training."""
import argparse
import time

import dgl
import numpy as np
import torch as th
from ogb.nodeproppred import DglNodePropPredDataset
def load_ogb(dataset):
    """Load an OGB node-property-prediction dataset as a DGL heterograph.

    Only ``"ogbn-mag"`` is supported.  The returned heterograph contains a
    reverse relation for every canonical edge type, and its ``"paper"`` nodes
    carry the features ``"feat"``, ``"labels"``, ``"train_mask"``,
    ``"val_mask"`` and ``"test_mask"``.

    Parameters
    ----------
    dataset : str
        Name of the OGB dataset; must be ``"ogbn-mag"``.

    Returns
    -------
    dgl.DGLHeteroGraph
        The converted heterograph.

    Raises
    ------
    ValueError
        If ``dataset`` is not ``"ogbn-mag"``.
    """
    # Fix: the original `raise ("...")` raised a plain str, which is itself
    # a TypeError in Python 3; raise a proper exception instead.
    if dataset != "ogbn-mag":
        raise ValueError("Do not support other ogbn datasets.")

    dataset = DglNodePropPredDataset(name=dataset)
    split_idx = dataset.get_idx_split()
    # ogbn-mag's prediction target is the "paper" node type.
    train_idx = split_idx["train"]["paper"]
    val_idx = split_idx["valid"]["paper"]
    test_idx = split_idx["test"]["paper"]

    hg_orig, labels = dataset[0]
    # Add a reverse relation ("rev-<etype>") for every canonical edge type
    # so messages can flow in both directions during training.
    subgs = {}
    for etype in hg_orig.canonical_etypes:
        u, v = hg_orig.all_edges(etype=etype)
        subgs[etype] = (u, v)
        subgs[(etype[2], "rev-" + etype[1], etype[0])] = (v, u)
    hg = dgl.heterograph(subgs)
    hg.nodes["paper"].data["feat"] = hg_orig.nodes["paper"].data["feat"]
    paper_labels = labels["paper"].squeeze()

    print("Number of relations: {}".format(len(hg.canonical_etypes)))
    print("Number of class: {}".format(dataset.num_classes))
    print("Number of train: {}".format(len(train_idx)))
    print("Number of valid: {}".format(len(val_idx)))
    print("Number of test: {}".format(len(test_idx)))

    # Boolean masks over the "paper" nodes marking the split membership.
    num_papers = hg.number_of_nodes("paper")
    train_mask = th.zeros((num_papers,), dtype=th.bool)
    train_mask[train_idx] = True
    val_mask = th.zeros((num_papers,), dtype=th.bool)
    val_mask[val_idx] = True
    test_mask = th.zeros((num_papers,), dtype=th.bool)
    test_mask[test_idx] = True
    hg.nodes["paper"].data["train_mask"] = train_mask
    hg.nodes["paper"].data["val_mask"] = val_mask
    hg.nodes["paper"].data["test_mask"] = test_mask

    hg.nodes["paper"].data["labels"] = paper_labels
    return hg
if __name__ == "__main__":
    # Command-line driver: load the graph, then hand it to DGL's
    # distributed partitioner.
    argparser = argparse.ArgumentParser("Partition builtin graphs")
    add_arg = argparser.add_argument
    add_arg(
        "--dataset", type=str, default="ogbn-mag", help="datasets: ogbn-mag"
    )
    add_arg(
        "--num_parts", type=int, default=4, help="number of partitions"
    )
    add_arg(
        "--part_method", type=str, default="metis", help="the partition method"
    )
    add_arg(
        "--balance_train",
        action="store_true",
        help="balance the training size in each partition.",
    )
    add_arg(
        "--undirected",
        action="store_true",
        help="turn the graph into an undirected graph.",
    )
    add_arg(
        "--balance_edges",
        action="store_true",
        help="balance the number of edges in each partition.",
    )
    add_arg(
        "--num_trainers_per_machine",
        type=int,
        default=1,
        help="the number of trainers per machine. The trainer ids are stored\
                                in the node feature 'trainer_id'",
    )
    add_arg(
        "--output",
        type=str,
        default="data",
        help="Output path of partitioned graph.",
    )
    args = argparser.parse_args()

    # Time the dataset load separately from the partitioning itself.
    t0 = time.time()
    g = load_ogb(args.dataset)
    elapsed = time.time() - t0

    print(f"load {args.dataset} takes {elapsed:.3f} seconds")
    print(f"|V|={g.number_of_nodes()}, |E|={g.number_of_edges()}")
    n_train = th.sum(g.nodes["paper"].data["train_mask"])
    n_val = th.sum(g.nodes["paper"].data["val_mask"])
    n_test = th.sum(g.nodes["paper"].data["test_mask"])
    print(f"train: {n_train}, valid: {n_val}, test: {n_test}")

    # Optionally balance partitions by the number of training "paper" nodes.
    balance_ntypes = (
        {"paper": g.nodes["paper"].data["train_mask"]}
        if args.balance_train
        else None
    )

    dgl.distributed.partition_graph(
        g,
        args.dataset,
        args.num_parts,
        args.output,
        part_method=args.part_method,
        balance_ntypes=balance_ntypes,
        balance_edges=args.balance_edges,
        num_trainers_per_machine=args.num_trainers_per_machine,
    )