"""Partition a built-in graph dataset for DGL distributed training.

Loads one of the supported datasets (reddit, ogb-product, ogb-paper100M),
optionally turns it into an undirected graph, and writes ``--num_parts``
partitions to ``--output`` via ``dgl.distributed.partition_graph``.
"""
import argparse
import os
import sys
import time

import dgl
import numpy as np
import torch as th

# Make the sibling-directory `load_graph` module importable when this file
# is executed directly as a script.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from load_graph import load_reddit, load_ogb

if __name__ == '__main__':
    argparser = argparse.ArgumentParser("Partition builtin graphs")
    argparser.add_argument('--dataset', type=str, default='reddit',
                           help='datasets: reddit, ogb-product, ogb-paper100M')
    argparser.add_argument('--num_parts', type=int, default=4,
                           help='number of partitions')
    argparser.add_argument('--part_method', type=str, default='metis',
                           help='the partition method')
    argparser.add_argument('--balance_train', action='store_true',
                           help='balance the training size in each partition.')
    argparser.add_argument('--undirected', action='store_true',
                           help='turn the graph into an undirected graph.')
    argparser.add_argument('--balance_edges', action='store_true',
                           help='balance the number of edges in each partition.')
    argparser.add_argument('--num_trainers_per_machine', type=int, default=1,
                           help='the number of trainers per machine. The trainer ids are stored\
                                in the node feature \'trainer_id\'')
    argparser.add_argument('--output', type=str, default='data',
                           help='Output path of partitioned graph.')
    args = argparser.parse_args()

    # Load the requested dataset and report timing/size statistics.
    start = time.time()
    if args.dataset == 'reddit':
        g, _ = load_reddit()
    elif args.dataset == 'ogb-product':
        g, _ = load_ogb('ogbn-products')
    elif args.dataset == 'ogb-paper100M':
        g, _ = load_ogb('ogbn-papers100M')
    else:
        # Fail fast with a clear message instead of hitting a NameError on
        # `g` in the print statements below.
        raise ValueError('unknown dataset: {}'.format(args.dataset))
    print('load {} takes {:.3f} seconds'.format(args.dataset, time.time() - start))
    print('|V|={}, |E|={}'.format(g.number_of_nodes(), g.number_of_edges()))
    print('train: {}, valid: {}, test: {}'.format(th.sum(g.ndata['train_mask']),
                                                  th.sum(g.ndata['val_mask']),
                                                  th.sum(g.ndata['test_mask'])))

    # Optionally balance partitions by the number of training nodes.
    if args.balance_train:
        balance_ntypes = g.ndata['train_mask']
    else:
        balance_ntypes = None

    if args.undirected:
        # Symmetrize the graph. Node features are copied over explicitly
        # because to_bidirected does not carry them to the new graph.
        # NOTE(review): the `readonly` argument was removed in newer DGL
        # releases — confirm against the pinned DGL version before upgrading.
        sym_g = dgl.to_bidirected(g, readonly=True)
        for key in g.ndata:
            sym_g.ndata[key] = g.ndata[key]
        g = sym_g

    dgl.distributed.partition_graph(g, args.dataset, args.num_parts, args.output,
                                    part_method=args.part_method,
                                    balance_ntypes=balance_ntypes,
                                    balance_edges=args.balance_edges,
                                    num_trainers_per_machine=args.num_trainers_per_machine)