# partition_graph.py
r"""
Copyright (c) 2021 Intel Corporation
 \file Graph partitioning
 \brief Calls Libra - Vertex-cut based graph partitioner for distributed training
 \author Vasimuddin Md <vasimuddin.md@intel.com>,
         Guixiang Ma <guixiang.ma@intel.com>
         Sanchit Misra <sanchit.misra@intel.com>,
         Ramanarayan Mohanty <ramanarayan.mohanty@intel.com>,
         Sasikanth Avancha <sasikanth.avancha@intel.com>
         Nesreen K. Ahmed <nesreen.k.ahmed@intel.com>
"""


14
import argparse
15
import csv
16
import os
17
import random
18
import sys
19
import time
20
21
22
23
24
from statistics import mean

import numpy as np

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
25
from load_graph import load_ogb
26

27
import dgl
28
from dgl.base import DGLError
29
30
31
32
33
34
from dgl.data import load_data
from dgl.distgnn.partition import partition_graph
from dgl.distgnn.tools import load_proteins

if __name__ == "__main__":
    # Command-line driver: parse the options, load the requested dataset as a
    # graph, and pass it to the Libra partitioner (`partition_graph`).
    #
    # Options:
    #   --dataset    name of the dataset to load (default "cora")
    #   --num-parts  number of partitions requested (default 2)
    #   --out-dir    parent directory of the result directory (default "./")
    #
    # Raises DGLError when the fallback loader fails for the given dataset.
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--dataset", type=str, default="cora")
    argparser.add_argument("--num-parts", type=int, default=2)
    argparser.add_argument("--out-dir", type=str, default="./")
    args = argparser.parse_args()

    dataset = args.dataset
    num_community = args.num_parts
    out_dir = "Libra_result_" + dataset  ## "Libra_result_" prefix is mandatory
    resultdir = os.path.join(args.out_dir, out_dir)

    print("Input dataset for partitioning: ", dataset)
    if dataset in ("ogbn-products", "ogbn-papers100M", "ogbn-arxiv"):
        # All OGB datasets go through the same loader; only the name differs.
        print("Loading " + dataset)
        G, _ = load_ogb(dataset)
    elif dataset == "proteins":
        G = load_proteins("proteins")
    else:
        # Any other name is delegated to dgl.data.load_data, which receives
        # the whole argument namespace; the first item of what it returns is
        # used as the graph.
        try:
            G = load_data(args)[0]
        except Exception as err:
            # `Exception` rather than a bare `except:` so that Ctrl-C and
            # SystemExit are not misreported as a missing dataset; the
            # original failure is kept as the explicit cause.
            raise DGLError(
                "Error: Dataset {} not found !!!".format(dataset)
            ) from err

    print("Done loading the graph.", flush=True)

    partition_graph(num_community, G, resultdir)