"""Launching distributed graph partitioning pipeline """
import os
import sys
import argparse
import logging
import json

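# Helper scripts, resolved relative to this file: a generic launcher that
# starts processes on the machines listed in the IP config, and the
# per-worker data-processing pipeline.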
INSTALL_DIR = os.path.abspath(os.path.join(__file__, '..'))
LAUNCH_SCRIPT = "distgraphlaunch.py"
PIPELINE_SCRIPT = "distpartitioning/data_proc_pipeline.py"

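# Flag names of the data-processing pipeline command (the "UDF" command that
# the launcher runs on every worker).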
UDF_WORLD_SIZE = "world-size"
UDF_PART_DIR = "partitions-dir"
UDF_INPUT_DIR = "input-dir"
UDF_GRAPH_NAME = "graph-name"
UDF_SCHEMA = "schema"
UDF_NUM_PARTS = "num-parts"
UDF_OUT_DIR = "output"
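UDF_PROC_GROUP_TIMEOUT = "process-group-timeout"
UDF_LOG_LEVEL = "log-level"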

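# Flag names understood by the launch script itself.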
LARG_PROCS_MACHINE = "num_proc_per_machine"
LARG_IPCONF = "ip_config"
LARG_MASTER_PORT = "master_port"
LARG_SSH_PORT = "ssh_port"

def get_launch_cmd(args) -> str:
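    """Build the command line that invokes the launch script.

    The result has the shape (values are illustrative):
    ``python distgraphlaunch.py --ssh_port 22 --num_proc_per_machine 1
    --ip_config ip_config.txt --master_port 12345``
    """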
    cmd = sys.executable + " " + os.path.join(INSTALL_DIR, LAUNCH_SCRIPT)
    cmd = f"{cmd} --{LARG_SSH_PORT} {args.ssh_port} "
    cmd = f"{cmd} --{LARG_PROCS_MACHINE} 1 "
    cmd = f"{cmd} --{LARG_IPCONF} {args.ip_config} "
    cmd = f"{cmd} --{LARG_MASTER_PORT} {args.master_port} "

    return cmd


def submit_jobs(args) -> None:
    """Read the graph metadata, build the pipeline command and launch it."""

    # Read the metadata JSON file to get the remaining arguments.
    schema_path = "metadata.json"
    with open(os.path.join(args.in_dir, schema_path)) as schema:
        schema_map = json.load(schema)

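    # The number of chunks recorded for the first node type in the metadata
    # determines the number of partitions.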
    num_parts = len(schema_map["num_nodes_per_chunk"][0])
    graph_name = schema_map["graph_name"]

    argslist = ""
    argslist += "--world-size {} ".format(num_parts)
48
49
    argslist += "--partitions-dir {} ".format(os.path.abspath(args.partitions_dir))
    argslist += "--input-dir {} ".format(os.path.abspath(args.in_dir))
50
51
52
    argslist += "--graph-name {} ".format(graph_name)
    argslist += "--schema {} ".format(schema_path)
    argslist += "--num-parts {} ".format(num_parts)
53
    argslist += "--output {} ".format(os.path.abspath(args.out_dir))
54
    argslist += "--process-group-timeout {} ".format(args.process_group_timeout)
55
    argslist += "--log-level {} ".format(args.log_level)
56

    # (BarclayII) Is it safe to assume all the workers have the Python executable at the same path?
    pipeline_cmd = os.path.join(INSTALL_DIR, PIPELINE_SCRIPT)
    udf_cmd = f"{args.python_path} {pipeline_cmd} {argslist}"

    launch_cmd = get_launch_cmd(args)
    launch_cmd += '"' + udf_cmd + '"'

    print(launch_cmd)
    os.system(launch_cmd)

def main():
    parser = argparse.ArgumentParser(
        description='Dispatch edge index and data to partitions',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--in-dir', type=str, help='Location of the input directory where the dataset is located')
    parser.add_argument('--partitions-dir', type=str, help='Location of the partition-id mapping files which define node-ids and their respective partition-ids, relative to the input directory')
    parser.add_argument('--out-dir', type=str, help='Location of the output directory where the graph partitions will be created by this pipeline')
    parser.add_argument('--ip-config', type=str, help='File location of IP configuration for server processes')
    parser.add_argument('--master-port', type=int, default=12345, help='Port used by the Gloo process group to create a rendezvous point')
    parser.add_argument('--log-level', type=str, default="info", help='Logging level. Available options: critical, error, warning, info, debug, notset')
    parser.add_argument('--python-path', type=str, default=sys.executable, help='Path to the Python executable on all workers')
    parser.add_argument('--ssh-port', type=int, default=22, help='SSH port.')
    parser.add_argument('--process-group-timeout', type=int, default=1800,
                        help='Timeout (in seconds) for operations executed against the process group')

    args, _ = parser.parse_known_args()

    assert os.path.isdir(args.in_dir)
    assert os.path.isdir(args.partitions_dir)
    assert os.path.isfile(args.ip_config)
    assert isinstance(args.log_level, str)
    assert isinstance(args.master_port, int)

    submit_jobs(args)

if __name__ == '__main__':
    fmt = '%(asctime)s %(levelname)s %(message)s'
    logging.basicConfig(format=fmt, level=logging.INFO)
    main()