"vscode:/vscode.git/clone" did not exist on "f6e2f05b16c1d13f0e017450cad975244099de80"
copy_files.py 4.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
"""Copy the partitions to a cluster of machines."""
import os
import stat
import sys
import subprocess
import argparse
import signal
import logging
import json
import copy

12
def copy_file(file_name, ip, workspace, param=''):
    """Copy a local file or directory into a directory on a remote host via scp.

    Args:
        file_name: Local path to copy. A leading ``~`` is expanded, matching
            what the shell did when this command was built as a string.
        ip: Remote host, in any form scp accepts (e.g. ``host`` or
            ``user@host``).
        workspace: Destination directory on the remote host. A trailing '/'
            is appended, so scp treats the destination as a directory.
        param: Extra scp options as a single string (e.g. ``'-r'`` to copy a
            directory). It is split with shell-like rules.

    Raises:
        subprocess.CalledProcessError: if scp exits with a non-zero status.
    """
    import shlex

    destination = '{}:{}/'.format(ip, workspace)
    print('copy {} to {}'.format(file_name, destination))
    # Pass an argument list instead of a shell string, so local paths that
    # contain spaces or shell metacharacters reach scp unchanged.
    cmd = (['scp'] + shlex.split(param)
           + [os.path.expanduser(file_name), destination])
    subprocess.check_call(cmd)

def exec_cmd(ip, cmd):
    """Run a command line on a remote host over ssh.

    Host-key checking is disabled (``StrictHostKeyChecking=no``), so a first
    connection to a host does not stop at an interactive confirmation prompt.

    Args:
        ip: Remote host to connect to.
        cmd: Command line for the remote side to execute.

    Raises:
        subprocess.CalledProcessError: if ssh or the remote command fails.
    """
    # ssh forwards `cmd` to the remote side unchanged. Passing it as a single
    # argv element (no local shell) avoids the broken quoting that the old
    # "'" + cmd + "'" wrapping produced when cmd contained a single quote.
    subprocess.check_call(['ssh', '-o', 'StrictHostKeyChecking=no', ip, cmd])

def _read_hosts(ip_config):
    """Return the host of each non-blank line of the IP configuration file.

    The host is the first whitespace-separated field of a line. Any further
    fields on the line are ignored here.
    """
    hosts = []
    with open(ip_config) as f:
        for line in f:
            fields = line.split()
            # Without this check, blank lines (e.g. a trailing empty line)
            # would become empty host names and break the host count check.
            if fields:
                hosts.append(fields[0])
    return hosts


def _make_remote_part_config(part_metadata, workspace, rel_data_path):
    """Return a deep copy of ``part_metadata`` with paths rewritten for the remote side.

    File-backed node/edge maps (a path rather than a list) are placed at
    ``<workspace>/<rel_data_path>/node_map.npy`` and
    ``<workspace>/<rel_data_path>/edge_map.npy``. Each partition's files are
    given workspace-relative paths under ``<rel_data_path>/part<id>/``.

    Raises:
        AssertionError: if a file-backed node/edge map path does not end in
            ``.npy``.
    """
    remote_metadata = copy.deepcopy(part_metadata)
    for key in ('node_map', 'edge_map'):
        id_map = part_metadata[key]
        if not isinstance(id_map, list):
            assert id_map.endswith('.npy'), \
                '{} should be stored in a NumPy array.'.format(key.replace('_', ' '))
            remote_metadata[key] = '{}/{}/{}.npy'.format(workspace, rel_data_path, key)
    for part_id in range(part_metadata['num_parts']):
        part_files = remote_metadata['part-{}'.format(part_id)]
        part_dir = '{}/part{}'.format(rel_data_path, part_id)
        part_files['edge_feats'] = '{}/edge_feat.dgl'.format(part_dir)
        part_files['node_feats'] = '{}/node_feat.dgl'.format(part_dir)
        part_files['part_graph'] = '{}/graph.dgl'.format(part_dir)
    return remote_metadata


def _copy_file_to_path(file_name, ip, dest_path):
    """Copy local ``file_name`` over scp to exactly ``ip:dest_path``.

    ``copy_file`` appends '/' to its destination, which makes scp treat the
    destination as a directory. A destination that names the remote file
    itself must not have that trailing slash.
    """
    print('copy {} to {}'.format(file_name, ip + ':' + dest_path))
    subprocess.check_call(['scp', file_name, '{}:{}'.format(ip, dest_path)])


def main():
    """Copy the partitioned graph, configs and user scripts to every host.

    Host ``i`` (the i-th non-blank line of ``--ip_config``) receives
    partition ``i``. Each host gets:

    * the IP config, in ``<workspace>``;
    * a partition config rewritten for the remote layout, plus any
      file-backed node/edge maps, in ``<workspace>/<rel_data_path>``;
    * its partition's files, in ``<workspace>/<rel_data_path>/part<i>``;
    * ``--script_folder``, copied recursively into ``<workspace>``.
    """
    parser = argparse.ArgumentParser(description='Copy data to the servers.')
    parser.add_argument('--workspace', type=str, required=True,
                        help='Path of user directory of distributed tasks. \
                        This is used to specify a destination location where \
                        data are copied to on remote machines.')
    parser.add_argument('--rel_data_path', type=str, required=True,
                        help='Relative path in workspace to store the partition data.')
    parser.add_argument('--part_config', type=str, required=True,
                        help='The partition config file. The path is on the local machine.')
    parser.add_argument('--script_folder', type=str, required=True,
                        help='The folder contains all the user code scripts.')
    parser.add_argument('--ip_config', type=str, required=True,
                        help='The file of IP configuration for servers. \
                        The path is on the local machine.')
    args = parser.parse_args()

    hosts = _read_hosts(args.ip_config)

    with open(args.part_config) as conf_f:
        part_metadata = json.load(conf_f)
    num_parts = part_metadata['num_parts']
    assert num_parts == len(hosts), \
        'The number of partitions needs to be the same as the number of hosts.'
    graph_name = part_metadata['graph_name']

    # Rewrite the partition config so its paths match where the files are
    # placed on the remote machines.
    tmp_part_metadata = _make_remote_part_config(part_metadata, args.workspace,
                                                 args.rel_data_path)
    tmp_part_config = '/tmp/{}.json'.format(graph_name)
    with open(tmp_part_config, 'w') as outfile:
        json.dump(tmp_part_metadata, outfile, sort_keys=True, indent=4)

    data_dir = '{}/{}'.format(args.workspace, args.rel_data_path)
    node_map = part_metadata['node_map']
    edge_map = part_metadata['edge_map']
    for part_id, ip in enumerate(hosts):
        # Files shared by all partitions.
        exec_cmd(ip, 'mkdir -p {}'.format(data_dir))
        copy_file(args.ip_config, ip, args.workspace)
        copy_file(tmp_part_config, ip, data_dir)
        if not isinstance(node_map, list):
            _copy_file_to_path(node_map, ip, tmp_part_metadata['node_map'])
        if not isinstance(edge_map, list):
            _copy_file_to_path(edge_map, ip, tmp_part_metadata['edge_map'])

        # Files belonging to this host's partition.
        part_dir = '{}/part{}'.format(data_dir, part_id)
        exec_cmd(ip, 'mkdir -p {}'.format(part_dir))
        part_files = part_metadata['part-{}'.format(part_id)]
        for key in ('node_feats', 'edge_feats', 'part_graph'):
            copy_file(part_files[key], ip, part_dir)

        # User code scripts.
        copy_file(args.script_folder, ip, args.workspace, '-r')


def signal_handler(signal, frame):
    """Handle an interrupt: log that copying stopped, then exit with status 0.

    Both arguments (the signal number and the current stack frame) are
    passed in by the ``signal`` machinery and are not used.
    """
    logging.info('Stop copying')
    raise SystemExit(0)

if __name__ == '__main__':
    # Timestamped INFO-level log records for the whole run.
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                        level=logging.INFO)
    # Ctrl-C goes through signal_handler, which logs and exits cleanly.
    signal.signal(signal.SIGINT, signal_handler)
    main()