"""Copy the partitions to a cluster of machines."""
import argparse
import copy
import json
import logging
import os
import shlex
import signal
import stat
import subprocess
import sys


def copy_file(file_name, ip, workspace, param=""):
    """Copy a local file or directory into a directory on a remote host with scp.

    Args:
        file_name: Local path to copy.
        ip: Remote host; used as the host part of the scp target.
        workspace: Remote directory. The scp target is ``ip:workspace/``.
        param: Extra scp command-line options given as one string
            (e.g. ``"-r"``). It is tokenized with shell-like quoting rules,
            so an empty string contributes no arguments.

    Raises:
        subprocess.CalledProcessError: If scp exits with a non-zero status.
    """
    target = "{}:{}/".format(ip, workspace)
    print("copy {} to {}".format(file_name, target))
    # Pass an argument list instead of a shell string so that paths containing
    # spaces or shell metacharacters reach scp as a single, literal argument.
    # NOTE: the local shell is no longer involved, so globs, ``~`` and ``$VAR``
    # inside file_name are not expanded locally.
    cmd = ["scp", *shlex.split(param), file_name, target]
    subprocess.check_call(cmd)


def exec_cmd(ip, cmd):
    """Run a command line on a remote host through ssh.

    Host key checking is disabled (``-o StrictHostKeyChecking=no``), as in
    the original command string.

    Args:
        ip: Remote host to connect to.
        cmd: Command line to run remotely. It is handed to ssh as one
            argument, so quotes inside it are preserved instead of being
            consumed by a local shell.

    Raises:
        subprocess.CalledProcessError: If ssh exits with a non-zero status.
    """
    # An argument list avoids the local shell entirely; the previous
    # hand-written single-quote wrapping broke for any cmd containing "'".
    subprocess.check_call(["ssh", "-o", "StrictHostKeyChecking=no", ip, cmd])


def main():
    """Copy partition data, configs and user scripts to every host.

    Steps:
      1. Parse the command line.
      2. Read the host list from ``--ip_config`` (first whitespace-separated
         field of every non-empty line).
      3. Load ``--part_config`` and build a copy whose paths refer to the
         remote layout under ``--workspace``/``--rel_data_path``; write that
         copy to ``/tmp/<graph_name>.json``.
      4. For the i-th host: create the remote data directory, copy the IP
         config, the rewritten partition config, any node/edge map files,
         the three files of partition i, and the script folder.

    Raises:
        ValueError: If the number of partitions differs from the number of
            hosts, or a non-dict node/edge map is not a ``.npy`` path.
        subprocess.CalledProcessError: Propagated from ``copy_file`` /
            ``exec_cmd`` when scp or ssh fails.
    """
    parser = argparse.ArgumentParser(description="Copy data to the servers.")
    parser.add_argument(
        "--workspace",
        type=str,
        required=True,
        help="Path of user directory of distributed tasks. "
        "This is used to specify a destination location where "
        "data are copied to on remote machines.",
    )
    parser.add_argument(
        "--rel_data_path",
        type=str,
        required=True,
        help="Relative path in workspace to store the partition data.",
    )
    parser.add_argument(
        "--part_config",
        type=str,
        required=True,
        help="The partition config file. The path is on the local machine.",
    )
    parser.add_argument(
        "--script_folder",
        type=str,
        required=True,
        help="The folder contains all the user code scripts.",
    )
    parser.add_argument(
        "--ip_config",
        type=str,
        required=True,
        help="The file of IP configuration for servers. "
        "The path is on the local machine.",
    )
    args = parser.parse_args()

    # One host per line; only the first field (the address) is used. Blank
    # lines are skipped so they do not turn into an empty host name.
    hosts = []
    with open(args.ip_config) as f:
        for line in f:
            fields = line.split()
            if fields:
                hosts.append(fields[0])

    remote_data_dir = "{}/{}".format(args.workspace, args.rel_data_path)

    # We need to update the partition config file so that the paths are
    # relative to the workspace in the remote machines.
    with open(args.part_config) as conf_f:
        part_metadata = json.load(conf_f)
    tmp_part_metadata = copy.deepcopy(part_metadata)

    num_parts = part_metadata["num_parts"]
    # Explicit raise instead of assert: asserts are stripped under ``-O``.
    if num_parts != len(hosts):
        raise ValueError(
            "The number of partitions needs to be the same as the number of hosts."
        )
    graph_name = part_metadata["graph_name"]

    # Node/edge maps are either dicts stored inline in the config or paths to
    # .npy files that must be shipped alongside the config.
    local_map_files = []
    for key in ("node_map", "edge_map"):
        local_map = part_metadata[key]
        if isinstance(local_map, dict):
            continue
        if not (isinstance(local_map, str) and local_map.endswith(".npy")):
            raise ValueError(
                "{} should be stored in a NumPy array.".format(
                    key.replace("_", " ")
                )
            )
        # scp keeps the source file name, so record exactly where the file
        # will land in the remote data directory.
        tmp_part_metadata[key] = "{}/{}".format(
            remote_data_dir, os.path.basename(local_map)
        )
        local_map_files.append(local_map)

    for part_id in range(num_parts):
        part_files = tmp_part_metadata["part-{}".format(part_id)]
        part_files["edge_feats"] = "{}/part{}/edge_feat.dgl".format(
            args.rel_data_path, part_id
        )
        part_files["node_feats"] = "{}/part{}/node_feat.dgl".format(
            args.rel_data_path, part_id
        )
        part_files["part_graph"] = "{}/part{}/graph.dgl".format(
            args.rel_data_path, part_id
        )

    tmp_part_config = "/tmp/{}.json".format(graph_name)
    with open(tmp_part_config, "w") as outfile:
        json.dump(tmp_part_metadata, outfile, sort_keys=True, indent=4)

    # Ship everything host by host; host i receives partition i.
    for part_id, ip in enumerate(hosts):
        exec_cmd(ip, "mkdir -p {}".format(remote_data_dir))

        copy_file(args.ip_config, ip, args.workspace)
        copy_file(tmp_part_config, ip, remote_data_dir)
        # The maps go INTO the data directory. Passing the remote file path
        # as the destination directory would make scp target
        # ".../node_map.npy/", which is not a directory.
        for local_map in local_map_files:
            copy_file(local_map, ip, remote_data_dir)

        remote_part_dir = "{}/part{}".format(remote_data_dir, part_id)
        exec_cmd(ip, "mkdir -p {}".format(remote_part_dir))

        # Source paths come from the ORIGINAL config (local locations).
        part_files = part_metadata["part-{}".format(part_id)]
        copy_file(part_files["node_feats"], ip, remote_part_dir)
        copy_file(part_files["edge_feats"], ip, remote_part_dir)
        copy_file(part_files["part_graph"], ip, remote_part_dir)
        # copy script folder
        copy_file(args.script_folder, ip, args.workspace, "-r")


def signal_handler(signal, frame):
    """Log that copying was stopped and exit the process with status 0.

    The script entry point registers this function for SIGINT.

    Args:
        signal: Number of the received signal (unused). The name shadows the
            ``signal`` module inside this function only; it is kept so the
            signature stays unchanged.
        frame: Stack frame passed by the signal machinery (unused).

    Raises:
        SystemExit: Always, with exit code 0.
    """
    logging.info("Stop copying")
    sys.exit(0)


if __name__ == "__main__":
    # Timestamped, INFO-level logging for the whole script run.
    fmt = "%(asctime)s %(levelname)s %(message)s"
    logging.basicConfig(format=fmt, level=logging.INFO)
    # Route SIGINT (Ctrl-C) to signal_handler, which logs and exits.
    signal.signal(signal.SIGINT, signal_handler)
    main()