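"""Partition the ogbn-mag heterogeneous graph for DGL distributed training.

Loads ogbn-mag, adds a reverse relation for every edge type, attaches the
"paper" features, labels, and train/val/test masks, and then calls
dgl.distributed.partition_graph() to write the partitions to the output path.

Example invocation (adjust the partition count and output path to your setup):

    python partition_graph.py --dataset ogbn-mag --num_parts 4 \
        --balance_train --balance_edges --output data
"""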
import argparse
import time

import dgl
import numpy as np
import torch as th

from ogb.nodeproppred import DglNodePropPredDataset


def load_ogb(dataset):
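    """Load an OGB node classification dataset as a DGLHeteroGraph.

    Only ogbn-mag is supported. The returned graph contains a reverse
    relation for every edge type and stores the features, labels, and
    train/val/test masks on the "paper" node type.
    """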
    if dataset == "ogbn-mag":
        dataset = DglNodePropPredDataset(name=dataset)
        split_idx = dataset.get_idx_split()
        train_idx = split_idx["train"]["paper"]
        val_idx = split_idx["valid"]["paper"]
        test_idx = split_idx["test"]["paper"]
        hg_orig, labels = dataset[0]
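        # Rebuild the heterograph with a reverse relation ("rev-<name>") for
        # every canonical edge type so edges can be traversed in both directions.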
        subgs = {}
        for etype in hg_orig.canonical_etypes:
            u, v = hg_orig.all_edges(etype=etype)
            subgs[etype] = (u, v)
            subgs[(etype[2], "rev-" + etype[1], etype[0])] = (v, u)
        hg = dgl.heterograph(subgs)
        hg.nodes["paper"].data["feat"] = hg_orig.nodes["paper"].data["feat"]
        paper_labels = labels["paper"].squeeze()

        num_rels = len(hg.canonical_etypes)
        num_of_ntype = len(hg.ntypes)
        num_classes = dataset.num_classes
        category = "paper"
        print("Number of relations: {}".format(num_rels))
        print("Number of class: {}".format(num_classes))
        print("Number of train: {}".format(len(train_idx)))
        print("Number of valid: {}".format(len(val_idx)))
        print("Number of test: {}".format(len(test_idx)))

        # get target category id
        category_id = len(hg.ntypes)
        for i, ntype in enumerate(hg.ntypes):
            if ntype == category:
                category_id = i

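        # Convert the split indices into boolean masks stored as "paper" node features.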
        train_mask = th.zeros((hg.num_nodes("paper"),), dtype=th.bool)
        train_mask[train_idx] = True
        val_mask = th.zeros((hg.num_nodes("paper"),), dtype=th.bool)
        val_mask[val_idx] = True
        test_mask = th.zeros((hg.num_nodes("paper"),), dtype=th.bool)
        test_mask[test_idx] = True
        hg.nodes["paper"].data["train_mask"] = train_mask
        hg.nodes["paper"].data["val_mask"] = val_mask
        hg.nodes["paper"].data["test_mask"] = test_mask

        hg.nodes["paper"].data["labels"] = paper_labels
        return hg
    else:
        raise ("Do not support other ogbn datasets.")


if __name__ == "__main__":
    argparser = argparse.ArgumentParser(description="Partition builtin graphs")
    argparser.add_argument(
        "--dataset", type=str, default="ogbn-mag", help="datasets: ogbn-mag"
    )
    argparser.add_argument(
        "--num_parts", type=int, default=4, help="number of partitions"
    )
    argparser.add_argument(
        "--part_method", type=str, default="metis", help="the partition method"
    )
    argparser.add_argument(
        "--balance_train",
        action="store_true",
        help="balance the training size in each partition.",
    )
    argparser.add_argument(
        "--undirected",
        action="store_true",
        help="turn the graph into an undirected graph.",
    )
    argparser.add_argument(
        "--balance_edges",
        action="store_true",
        help="balance the number of edges in each partition.",
    )
    argparser.add_argument(
        "--num_trainers_per_machine",
        type=int,
        default=1,
        help="the number of trainers per machine. The trainer ids are stored\
                                in the node feature 'trainer_id'",
    )
    argparser.add_argument(
        "--output",
        type=str,
        default="data",
        help="Output path of partitioned graph.",
    )
    argparser.add_argument(
        "--use_graphbolt",
        action="store_true",
        help="Use GraphBolt for distributed train.",
    )

    args = argparser.parse_args()

    start = time.time()
    g = load_ogb(args.dataset)

    print(
        "load {} takes {:.3f} seconds".format(args.dataset, time.time() - start)
    )
    print("|V|={}, |E|={}".format(g.num_nodes(), g.num_edges()))
    print(
        "train: {}, valid: {}, test: {}".format(
            th.sum(g.nodes["paper"].data["train_mask"]),
            th.sum(g.nodes["paper"].data["val_mask"]),
            th.sum(g.nodes["paper"].data["test_mask"]),
        )
    )

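    # Use the training mask to balance the number of training nodes across partitions.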
    if args.balance_train:
        balance_ntypes = {"paper": g.nodes["paper"].data["train_mask"]}
    else:
        balance_ntypes = None

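    # Partition the graph and save the resulting partitions under args.output.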
    dgl.distributed.partition_graph(
        g,
        args.dataset,
        args.num_parts,
        args.output,
        part_method=args.part_method,
        balance_ntypes=balance_ntypes,
        balance_edges=args.balance_edges,
        num_trainers_per_machine=args.num_trainers_per_machine,
        use_graphbolt=args.use_graphbolt,
    )