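r"""Driver script for the distributed graph-partitioning data pipeline.

Parses the command-line arguments described below, configures logging, and
hands the parsed parameters to ``multi_machine_run`` from ``data_shuffle``.

A minimal example invocation; every path and value here is an illustrative
placeholder, not a shipped default:

    python data_proc_pipeline.py \
        --input-dir /data/partition_results \
        --graph-name my_graph \
        --schema /data/metadata.json \
        --num-parts 4 \
        --output /data/partitioned \
        --world-size 4 \
        --process-group-timeout 1800 \
        --log-level info
"""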
import argparse
import logging
import platform

from data_shuffle import multi_machine_run, single_machine_run


def log_params(params):
    """Print all the command line arguments for debugging purposes.
14
15
16
17
18
19

    Parameters:
    -----------
    params: argparse object
        Argument Parser structure listing all the pre-defined parameters
    """
    print("Input Dir: ", params.input_dir)
    print("Graph Name: ", params.graph_name)
    print("Schema File: ", params.schema)
    print("No. partitions: ", params.num_parts)
    print("Output Dir: ", params.output)
    print("WorldSize: ", params.world_size)
    print("Metis partitions: ", params.partitions_dir)


if __name__ == "__main__":
    """
    Start of execution from this point.
    Invoke the appropriate function to begin execution.
    """
    # Arguments already needed by the existing implementation of convert_partition.py.
    parser = argparse.ArgumentParser(description="Construct graph partitions")
    parser.add_argument(
        "--input-dir",
        required=True,
        type=str,
        help="The directory path that contains the partition results.",
    )
    parser.add_argument(
        "--graph-name", required=True, type=str, help="The graph name"
    )
    parser.add_argument(
        "--schema", required=True, type=str, help="The schema of the graph"
    )
    parser.add_argument(
        "--num-parts", required=True, type=int, help="The number of partitions"
    )
    parser.add_argument(
        "--output",
        required=True,
        type=str,
        help="The output directory of the partitioned results",
    )
    parser.add_argument(
        "--partitions-dir",
        help="directory of the partition-ids for each node type",
        default=None,
        type=str,
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="info",
        help="Log level for debugging purposes. Available options: "
        "(Critical, Error, Warning, Info, Debug, Notset); default: Info.",
    )

    # Arguments needed for the distributed implementation.
    parser.add_argument(
        "--world-size",
        help="Number of processes to spawn.",
        type=int,
        required=True,
    )
    parser.add_argument(
        "--process-group-timeout",
        required=True,
        type=int,
        help="timeout[seconds] for operations executed against the process group "
        "(see torch.distributed.init_process_group)",
    )
    parser.add_argument(
        "--save-orig-nids",
        action="store_true",
        help="Save original node IDs into files",
    )
    parser.add_argument(
        "--save-orig-eids",
        action="store_true",
        help="Save original edge IDs into files",
    )
    parser.add_argument(
        "--graph-formats",
        default=None,
        type=str,
        help="Save partitions in specified formats.",
    )
    params = parser.parse_args()

    # Map the log-level string to the numeric constant defined by the logging
    # module; getattr() returns None for unrecognized names, so fail fast
    # instead of silently leaving the root logger at its default level.
    numeric_level = getattr(logging, params.log_level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f"Invalid log level: {params.log_level}")
    logging.basicConfig(
        level=numeric_level,
        format=f"[{platform.node()} %(levelname)s %(asctime)s PID:%(process)d] %(message)s",
    )
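    # Each record produced with this format looks roughly like the following
    # line (hostname, timestamp, and PID are illustrative):
    #   [host01 INFO 2024-05-01 12:00:00,000 PID:4242] starting pipeline

    # Invoke the pipeline function.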
    multi_machine_run(params)