"""Copy the partitions to a cluster of machines.""" import os import stat import sys import subprocess import argparse import signal import logging import json import copy def copy_file(file_name, ip, workspace): print('copy {} to {}'.format(file_name, ip + ':' + workspace + '/')) cmd = 'rsync -e \"ssh -o StrictHostKeyChecking=no\" -arvc ' + file_name + ' ' + ip + ':' + workspace + '/' subprocess.check_call(cmd, shell = True) def exec_cmd(ip, cmd): cmd = 'ssh -o StrictHostKeyChecking=no ' + ip + ' \'' + cmd + '\'' subprocess.check_call(cmd, shell = True) def main(): parser = argparse.ArgumentParser(description='Copy data to the servers.') parser.add_argument('--workspace', type=str, required=True, help='Path of user directory of distributed tasks. \ This is used to specify a destination location where \ data are copied to on remote machines.') parser.add_argument('--rel_data_path', type=str, required=True, help='Relative path in workspace to store the partition data.') parser.add_argument('--part_config', type=str, required=True, help='The partition config file. The path is on the local machine.') parser.add_argument('--script_folder', type=str, required=True, help='The folder contains all the user code scripts.') parser.add_argument('--ip_config', type=str, required=True, help='The file of IP configuration for servers. \ The path is on the local machine.') args = parser.parse_args() hosts = [] with open(args.ip_config) as f: for line in f: res = line.strip().split(' ') ip = res[0] hosts.append(ip) # We need to update the partition config file so that the paths are relative to # the workspace in the remote machines. with open(args.part_config) as conf_f: part_metadata = json.load(conf_f) tmp_part_metadata = copy.deepcopy(part_metadata) num_parts = part_metadata['num_parts'] assert num_parts == len(hosts), \ 'The number of partitions needs to be the same as the number of hosts.' graph_name = part_metadata['graph_name'] node_map = part_metadata['node_map'] edge_map = part_metadata['edge_map'] if not isinstance(node_map, list): assert node_map[-4:] == '.npy', 'node map should be stored in a NumPy array.' tmp_part_metadata['node_map'] = '{}/{}/node_map.npy'.format(args.workspace, args.rel_data_path) if not isinstance(edge_map, list): assert edge_map[-4:] == '.npy', 'edge map should be stored in a NumPy array.' tmp_part_metadata['edge_map'] = '{}/{}/edge_map.npy'.format(args.workspace, args.rel_data_path) for part_id in range(num_parts): part_files = tmp_part_metadata['part-{}'.format(part_id)] part_files['edge_feats'] = '{}/part{}/edge_feat.dgl'.format(args.rel_data_path, part_id) part_files['node_feats'] = '{}/part{}/node_feat.dgl'.format(args.rel_data_path, part_id) part_files['part_graph'] = '{}/part{}/graph.dgl'.format(args.rel_data_path, part_id) tmp_part_config = '/tmp/{}.json'.format(graph_name) with open(tmp_part_config, 'w') as outfile: json.dump(tmp_part_metadata, outfile, sort_keys=True, indent=4) # Copy ip config. for part_id, ip in enumerate(hosts): remote_path = '{}/{}'.format(args.workspace, args.rel_data_path) exec_cmd(ip, 'mkdir -p {}'.format(remote_path)) copy_file(args.ip_config, ip, args.workspace) copy_file(tmp_part_config, ip, '{}/{}'.format(args.workspace, args.rel_data_path)) node_map = part_metadata['node_map'] edge_map = part_metadata['edge_map'] if not isinstance(node_map, list): copy_file(node_map, ip, tmp_part_metadata['node_map']) if not isinstance(edge_map, list): copy_file(edge_map, ip, tmp_part_metadata['edge_map']) remote_path = '{}/{}/part{}'.format(args.workspace, args.rel_data_path, part_id) exec_cmd(ip, 'mkdir -p {}'.format(remote_path)) part_files = part_metadata['part-{}'.format(part_id)] copy_file(part_files['node_feats'], ip, remote_path) copy_file(part_files['edge_feats'], ip, remote_path) copy_file(part_files['part_graph'], ip, remote_path) # copy script folder copy_file(args.script_folder, ip, args.workspace) def signal_handler(signal, frame): logging.info('Stop copying') sys.exit(0) if __name__ == '__main__': fmt = '%(asctime)s %(levelname)s %(message)s' logging.basicConfig(format=fmt, level=logging.INFO) signal.signal(signal.SIGINT, signal_handler) main()