Unverified Commit 6e967f50 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Support saving search space in experiment load&save command (#2886)

parent 4cf67f53
...@@ -578,6 +578,7 @@ Debug mode will disable version check function in Trialkeeper. ...@@ -578,6 +578,7 @@ Debug mode will disable version check function in Trialkeeper.
|--path, -p| True| |the file path of nni package| |--path, -p| True| |the file path of nni package|
|--codeDir, -c| True| |the path of codeDir for loaded experiment, this path will also put the code in the loaded experiment package| |--codeDir, -c| True| |the path of codeDir for loaded experiment, this path will also put the code in the loaded experiment package|
|--logDir, -l| False| |the path of logDir for loaded experiment| |--logDir, -l| False| |the path of logDir for loaded experiment|
|--searchSpacePath, -s| True| |the path of search space file for loaded experiment, this path contains file name. Default in $codeDir/search_space.json|
* Examples * Examples
......
...@@ -159,6 +159,8 @@ def parse_args(): ...@@ -159,6 +159,8 @@ def parse_args():
parser_load_experiment.add_argument('--codeDir', '-c', required=True, help='the path of codeDir for loaded experiment, \ parser_load_experiment.add_argument('--codeDir', '-c', required=True, help='the path of codeDir for loaded experiment, \
this path will also put the code in the loaded experiment package') this path will also put the code in the loaded experiment package')
parser_load_experiment.add_argument('--logDir', '-l', required=False, help='the path of logDir for loaded experiment') parser_load_experiment.add_argument('--logDir', '-l', required=False, help='the path of logDir for loaded experiment')
parser_load_experiment.add_argument('--searchSpacePath', '-s', required=False, help='the path of search space file for \
loaded experiment, this path contains file name. Default in $codeDir/search_space.json')
parser_load_experiment.set_defaults(func=load_experiment) parser_load_experiment.set_defaults(func=load_experiment)
#parse platform command #parse platform command
......
...@@ -827,7 +827,18 @@ def save_experiment(args): ...@@ -827,7 +827,18 @@ def save_experiment(args):
temp_code_dir = os.path.join(temp_root_dir, 'code') temp_code_dir = os.path.join(temp_root_dir, 'code')
shutil.copytree(nni_config.get_config('experimentConfig')['trial']['codeDir'], temp_code_dir) shutil.copytree(nni_config.get_config('experimentConfig')['trial']['codeDir'], temp_code_dir)
# Step4. Archive folder # Step4. Copy searchSpace file
search_space_path = nni_config.get_config('experimentConfig').get('searchSpacePath')
if search_space_path:
if not os.path.exists(search_space_path):
print_warning('search space %s does not exist!' % search_space_path)
else:
temp_search_space_dir = os.path.join(temp_root_dir, 'searchSpace')
os.makedirs(temp_search_space_dir, exist_ok=True)
search_space_name = os.path.basename(search_space_path)
shutil.copyfile(search_space_path, os.path.join(temp_search_space_dir, search_space_name))
# Step5. Archive folder
zip_package_name = 'nni_experiment_%s' % args.id zip_package_name = 'nni_experiment_%s' % args.id
if args.path: if args.path:
os.makedirs(args.path, exist_ok=True) os.makedirs(args.path, exist_ok=True)
...@@ -844,6 +855,9 @@ def load_experiment(args): ...@@ -844,6 +855,9 @@ def load_experiment(args):
if not os.path.exists(args.path): if not os.path.exists(args.path):
print_error('file path %s does not exist!' % args.path) print_error('file path %s does not exist!' % args.path)
exit(1) exit(1)
if args.searchSpacePath and os.path.isdir(args.searchSpacePath):
print_error('search space path should be a full path with filename, not a directory!')
exit(1)
temp_root_dir = generate_temp_dir() temp_root_dir = generate_temp_dir()
shutil.unpack_archive(package_path, temp_root_dir) shutil.unpack_archive(package_path, temp_root_dir)
print_normal('Loading...') print_normal('Loading...')
...@@ -929,7 +943,32 @@ def load_experiment(args): ...@@ -929,7 +943,32 @@ def load_experiment(args):
else: else:
shutil.copy(src_path, target_path) shutil.copy(src_path, target_path)
# Step5. Create experiment metadata # Step5. Copy searchSpace file
archive_search_space_dir = os.path.join(temp_root_dir, 'searchSpace')
if args.searchSpacePath:
target_path = os.path.expanduser(args.searchSpacePath)
else:
# set default path to codeDir
target_path = os.path.join(codeDir, 'search_space.json')
if not os.path.isabs(target_path):
target_path = os.path.join(os.getcwd(), target_path)
print_normal('Expand search space path to %s' % target_path)
nnictl_exp_config['searchSpacePath'] = target_path
# if the path already has a search space file, use the original one, otherwise use archived one
if not os.path.isfile(target_path):
if len(os.listdir(archive_search_space_dir)) == 0:
print_error('Archive file does not contain search space file!')
exit(1)
else:
for file in os.listdir(archive_search_space_dir):
source_path = os.path.join(archive_search_space_dir, file)
os.makedirs(os.path.dirname(target_path), exist_ok=True)
shutil.copyfile(source_path, target_path)
break
elif not args.searchSpacePath:
print_warning('%s exist, will not load search_space file' % target_path)
# Step6. Create experiment metadata
nni_config.set_config('experimentConfig', nnictl_exp_config) nni_config.set_config('experimentConfig', nnictl_exp_config)
experiment_config.add_experiment(experiment_id, experiment_config.add_experiment(experiment_id,
experiment_metadata.get('port'), experiment_metadata.get('port'),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment