# Requires setting PYTHONPATH=${GITROOT}/tools
import argparse
3
4
5
6
import json
import logging
import os

7
import numpy as np
8
from base import dump_partition_meta, PartitionMeta
9
10
from distpartitioning import array_readwriter
from files import setdir
11

12

13
def _random_partition(metadata, num_parts):
14
    num_nodes_per_type = metadata["num_nodes_per_type"]
15
    ntypes = metadata["node_type"]
16
    for ntype, n in zip(ntypes, num_nodes_per_type):
17
        logging.info("Generating partition for node type %s" % ntype)
18
        parts = np.random.randint(0, num_parts, (n,))
19
20
21
22
        array_readwriter.get_array_parser(name="csv").write(
            ntype + ".txt", parts
        )

23
24
25
26
27
28
29
30
31
32

def random_partition(metadata, num_parts, output_path):
    """
    Randomly assign every node of the graph described by ``metadata`` to a
    partition and store the result under :attr:`output_path`.

    All work happens inside ``setdir(output_path)``.  A directory will be
    created at :attr:`output_path` holding one partition ID mapping file per
    node type, named "<node-type>.txt" (e.g. "author.txt", "paper.txt" and
    "institution.txt" for OGB-MAG240M).  Each file has one line per node giving
    the ID of the partition that node belongs to.

    A "partition_meta.json" file recording the version, the number of
    partitions and the algorithm name ("random") is dumped alongside them.
    """
    with setdir(output_path):
        # Per-node-type partition ID files.
        _random_partition(metadata, num_parts)

        # Description of how the partition was produced.
        meta_fields = {
            "version": "1.0.0",
            "num_parts": num_parts,
            "algo_name": "random",
        }
        dump_partition_meta(PartitionMeta(**meta_fields), "partition_meta.json")

42
43
44

# Run with PYTHONPATH=${GIT_ROOT_DIR}/tools
# where ${GIT_ROOT_DIR} is the root directory of the DGL git repository.
if __name__ == "__main__":
    # Command-line entry point: read "metadata.json" from --in_dir and write a
    # random partition with --num_partitions parts into --out_dir.
    parser = argparse.ArgumentParser()
    # Every option is needed below, so let argparse report a missing one as a
    # usage error instead of failing later on a None value.
    parser.add_argument(
        "--in_dir",
        type=str,
        required=True,
        help="input directory that contains the metadata file",
    )
    parser.add_argument(
        "--out_dir", type=str, required=True, help="output directory"
    )
    parser.add_argument(
        "--num_partitions",
        type=int,
        required=True,
        help="number of partitions",
    )
    logging.basicConfig(level="INFO")
    args = parser.parse_args()
    # Read JSON as UTF-8 explicitly rather than relying on the platform's
    # default text encoding.
    with open(
        os.path.join(args.in_dir, "metadata.json"), encoding="utf-8"
    ) as f:
        metadata = json.load(f)
    num_parts = args.num_partitions
    random_partition(metadata, num_parts, args.out_dir)