Unverified Commit 1ace9baf authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Remind users install tuner (#1078)

parent 6623dff3
......@@ -24,9 +24,9 @@ import os
import sys
import shutil
import string
from subprocess import Popen, PIPE, call, check_output, check_call
from subprocess import Popen, PIPE, call, check_output, check_call, CalledProcessError
import tempfile
from nni.constants import ModuleName
from nni.constants import ModuleName, AdvisorModuleName
from nni_annotation import *
from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick, check_response
......@@ -344,13 +344,18 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
'''follow steps to start rest server and start experiment'''
nni_config = Config(config_file_name)
# check packages for tuner
package_name, module_name = None, None
if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
tuner_name = experiment_config['tuner']['builtinTunerName']
module_name = ModuleName[tuner_name]
package_name = experiment_config['tuner']['builtinTunerName']
module_name = ModuleName.get(package_name)
elif experiment_config.get('advisor') and experiment_config['advisor'].get('builtinAdvisorName'):
package_name = experiment_config['advisor']['builtinAdvisorName']
module_name = AdvisorModuleName.get(package_name)
if package_name and module_name:
try:
check_call([sys.executable, '-c', 'import %s'%(module_name)])
except ModuleNotFoundError as e:
print_error('The tuner %s should be installed through nnictl'%(tuner_name))
check_call([sys.executable, '-c', 'import %s'%(module_name)], stdout=PIPE, stderr=PIPE)
except CalledProcessError as e:
print_error('%s should be installed through \'nnictl package install --name %s\''%(package_name, package_name))
exit(1)
log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None
log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment