"""Partition built-in graph datasets (Reddit / OGB) for distributed DGL training.

Loads one of the supported datasets, optionally symmetrizes it, and writes a
partitioned graph (METIS by default) to the output directory via
``dgl.distributed.partition_graph``.
"""
import argparse
import os
import sys
import time

import dgl
import numpy as np
import torch as th

# Allow importing the sibling `load_graph` helper module from the parent dir.
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from load_graph import load_ogb, load_reddit

if __name__ == "__main__":
    argparser = argparse.ArgumentParser("Partition builtin graphs")
    argparser.add_argument(
        "--dataset",
        type=str,
        default="reddit",
        help="datasets: reddit, ogb-product, ogb-paper100M",
    )
    argparser.add_argument(
        "--num_parts", type=int, default=4, help="number of partitions"
    )
    argparser.add_argument(
        "--part_method", type=str, default="metis", help="the partition method"
    )
    argparser.add_argument(
        "--balance_train",
        action="store_true",
        help="balance the training size in each partition.",
    )
    argparser.add_argument(
        "--undirected",
        action="store_true",
        help="turn the graph into an undirected graph.",
    )
    argparser.add_argument(
        "--balance_edges",
        action="store_true",
        help="balance the number of edges in each partition.",
    )
    argparser.add_argument(
        "--num_trainers_per_machine",
        type=int,
        default=1,
        help="the number of trainers per machine. The trainer ids are stored\
                                in the node feature 'trainer_id'",
    )
    argparser.add_argument(
        "--output",
        type=str,
        default="data",
        help="Output path of partitioned graph.",
    )
    args = argparser.parse_args()

    start = time.time()
    if args.dataset == "reddit":
        g, _ = load_reddit()
    elif args.dataset == "ogb-product":
        g, _ = load_ogb("ogbn-products")
    elif args.dataset == "ogb-paper100M":
        g, _ = load_ogb("ogbn-papers100M")
    else:
        # Fail fast with a clear message instead of a NameError on `g` below.
        raise ValueError(
            "unknown dataset {}; expected one of: "
            "reddit, ogb-product, ogb-paper100M".format(args.dataset)
        )
    print(
        "load {} takes {:.3f} seconds".format(args.dataset, time.time() - start)
    )
    print("|V|={}, |E|={}".format(g.num_nodes(), g.num_edges()))
    print(
        "train: {}, valid: {}, test: {}".format(
            th.sum(g.ndata["train_mask"]),
            th.sum(g.ndata["val_mask"]),
            th.sum(g.ndata["test_mask"]),
        )
    )

    # When requested, balance partitions by the number of training nodes.
    if args.balance_train:
        balance_ntypes = g.ndata["train_mask"]
    else:
        balance_ntypes = None

    if args.undirected:
        # Symmetrize the graph; node features must be copied onto the new
        # graph because to_bidirected does not carry them over.
        sym_g = dgl.to_bidirected(g, readonly=True)
        for key in g.ndata:
            sym_g.ndata[key] = g.ndata[key]
        g = sym_g

    start = time.time()
    dgl.distributed.partition_graph(
        g,
        args.dataset,
        args.num_parts,
        args.output,
        part_method=args.part_method,
        balance_ntypes=balance_ntypes,
        balance_edges=args.balance_edges,
        num_trainers_per_machine=args.num_trainers_per_machine,
    )
    print(
        "partition {} takes {:.3f} seconds".format(
            args.dataset, time.time() - start
        )
    )