# Requires setting PYTHONPATH=${GITROOT}/tools
import json
import logging
import sys
import os
import numpy as np
import argparse

from utils import setdir
from utils import array_readwriter
from base import PartitionMeta, dump_partition_meta

def _random_partition(metadata, num_parts):
    num_nodes_per_type = [sum(_) for _ in metadata['num_nodes_per_chunk']]
    ntypes = metadata['node_type']
    for ntype, n in zip(ntypes, num_nodes_per_type):
        logging.info('Generating partition for node type %s' % ntype)
        parts = np.random.randint(0, num_parts, (n,))
        array_readwriter.get_array_parser(name='csv').write(ntype + '.txt', parts)

def random_partition(metadata, num_parts, output_path):
    """
    Randomly partition the graph described in metadata and generate partition ID mapping
    in :attr:`output_path`.

    A directory will be created at :attr:`output_path` containing the partition ID
    mapping files named "<node-type>.txt" (e.g. "author.txt", "paper.txt" and
    "institution.txt" for OGB-MAG240M).  Each file contains one line per node representing
    the partition ID the node belongs to.

    In addition, partition metadata (format version "1.0.0", the number of partitions
    and the algorithm name "random") is dumped to "partition_meta.json" in the same
    directory.

    Parameters
    ----------
    metadata : dict
        Parsed graph metadata; must provide 'node_type' and 'num_nodes_per_chunk'.
    num_parts : int
        Number of partitions to assign nodes to.
    output_path : str
        Directory that receives the mapping files and "partition_meta.json".
    """
    # Both the mapping files and the metadata are written with relative names,
    # so they all land inside output_path while setdir is active.
    with setdir(output_path):
        _random_partition(metadata, num_parts)
        part_meta = PartitionMeta(version='1.0.0', num_parts=num_parts, algo_name='random')
        dump_partition_meta(part_meta, 'partition_meta.json')

# Run with PYTHONPATH=${GIT_ROOT_DIR}/tools
# where ${GIT_ROOT_DIR} is the root directory of the DGL git repository.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # All three options are mandatory: without them the script would only fail
    # later with an opaque TypeError from os.path.join or the partitioner.
    parser.add_argument(
        '--in_dir', type=str, required=True,
        help='input directory that contains the metadata file')
    parser.add_argument(
        '--out_dir', type=str, required=True, help='output directory')
    parser.add_argument(
        '--num_partitions', type=int, required=True, help='number of partitions')

    logging.basicConfig(level='INFO')
    args = parser.parse_args()
    if args.num_partitions <= 0:
        parser.error('--num_partitions must be a positive integer')

    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(os.path.join(args.in_dir, 'metadata.json'), encoding='utf-8') as f:
        metadata = json.load(f)
    random_partition(metadata, args.num_partitions, args.out_dir)