# random_partition.py
1
# Requires setting PYTHONPATH=${GIT_ROOT_DIR}/tools, where ${GIT_ROOT_DIR} is
# the root directory of the DGL git repository.
2
import argparse
3
4
5
import json
import logging
import os
6
import sys
7

8
import numpy as np
9
from base import PartitionMeta, dump_partition_meta
10
11
from distpartitioning import array_readwriter
from files import setdir
12
13

def _random_partition(metadata, num_parts):
    """Write a random partition ID for every node of every node type.

    For each node type in ``metadata["node_type"]`` a file named
    ``"<node-type>.txt"`` (a relative path) is written with the CSV array
    parser from ``array_readwriter``. It holds one partition ID per node,
    drawn uniformly at random from ``[0, num_parts)``.

    Parameters
    ----------
    metadata : dict
        Graph metadata. Reads ``"node_type"`` (node type names) and
        ``"num_nodes_per_chunk"`` (one sequence of per-chunk node counts per
        node type, in the same order as ``"node_type"``).
    num_parts : int
        Number of partitions; must be at least 1.

    Raises
    ------
    ValueError
        If ``num_parts`` is ``None`` or less than 1, or if ``"node_type"``
        and ``"num_nodes_per_chunk"`` have different lengths.
    """
    if num_parts is None or num_parts < 1:
        raise ValueError(
            "num_parts must be a positive integer, got %r" % (num_parts,)
        )

    ntypes = metadata["node_type"]
    num_nodes_per_type = [
        sum(chunk_sizes) for chunk_sizes in metadata["num_nodes_per_chunk"]
    ]
    # zip() below would silently drop node types if these lists disagreed,
    # leaving some node types without a partition file.
    if len(ntypes) != len(num_nodes_per_type):
        raise ValueError(
            "metadata['node_type'] and metadata['num_nodes_per_chunk'] have "
            "different lengths (%d vs %d)"
            % (len(ntypes), len(num_nodes_per_type))
        )

    # The writer does not depend on the node type, so build it only once.
    csv_writer = array_readwriter.get_array_parser(name="csv")
    for ntype, num_nodes in zip(ntypes, num_nodes_per_type):
        logging.info("Generating partition for node type %s", ntype)
        # One uniformly random partition ID in [0, num_parts) per node.
        parts = np.random.randint(0, num_parts, (num_nodes,))
        csv_writer.write(ntype + ".txt", parts)


def random_partition(metadata, num_parts, output_path):
    """
    Randomly partition the graph described in metadata and generate partition
    ID mapping in :attr:`output_path`.

    A directory will be created at :attr:`output_path` containing the
    partition ID mapping files named "<node-type>.txt" (e.g. "author.txt",
    "paper.txt" and "institution.txt" for OGB-MAG240M).  Each file contains
    one line per node representing the partition ID the node belongs to.

    In addition, a partition metadata file named "partition_meta.json" is
    dumped alongside the mapping files. It records the format version
    ("1.0.0"), the number of partitions and the algorithm name ("random").

    Parameters
    ----------
    metadata : dict
        Graph metadata; its "node_type" and "num_nodes_per_chunk" entries are
        used to size the per-type mapping files.
    num_parts : int
        Number of partitions to assign nodes to.
    output_path : str
        Directory that receives the mapping files and "partition_meta.json".
    """
    # Both the mapping files and the metadata file are written with relative
    # paths, so they land wherever setdir() points for the duration of the
    # block.
    with setdir(output_path):
        _random_partition(metadata, num_parts)
        part_meta = PartitionMeta(
            version="1.0.0", num_parts=num_parts, algo_name="random"
        )
        dump_partition_meta(part_meta, "partition_meta.json")


# Run with PYTHONPATH=${GIT_ROOT_DIR}/tools
# where ${GIT_ROOT_DIR} is the root directory of the DGL git repository.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # All arguments are mandatory: without them the script would fail later
    # with an unrelated TypeError/ValueError instead of a usage message.
    parser.add_argument(
        "--in_dir",
        type=str,
        required=True,
        help="input directory that contains the metadata file",
    )
    parser.add_argument(
        "--out_dir", type=str, required=True, help="output directory"
    )
    parser.add_argument(
        "--num_partitions",
        type=int,
        required=True,
        help="number of partitions",
    )
    logging.basicConfig(level="INFO")
    args = parser.parse_args()
    if args.num_partitions < 1:
        parser.error("--num_partitions must be at least 1")

    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(
        os.path.join(args.in_dir, "metadata.json"), encoding="utf-8"
    ) as f:
        metadata = json.load(f)
    num_parts = args.num_partitions
    random_partition(metadata, num_parts, args.out_dir)