Unverified Commit 6ff24a5e authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #143 from Microsoft/master

merge master
parents 5e777d2f c1e6098d
......@@ -71,7 +71,7 @@ jobs:
echo "TEST_IMG:$TEST_IMG"
cd test
python3 generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) \
python3 generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) \
--nni_docker_image $TEST_IMG --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip)
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts pai --exclude multi_phase
......
......@@ -38,6 +38,7 @@ pai:
image:
memoryMB: 8192
outputDir:
virtualCluster:
remote:
machineList:
- ip:
......
......@@ -34,6 +34,7 @@ Optional('multiPhase'): bool,
Optional('multiThread'): bool,
Optional('nniManagerIp'): str,
Optional('logDir'): os.path.isdir,
Optional('debug'): bool,
Optional('logLevel'): Or('trace', 'debug', 'info', 'warning', 'error', 'fatal'),
'useAnnotation': bool,
Optional('advisor'): Or({
......
......@@ -271,6 +271,8 @@ def set_experiment(experiment_config, mode, port, config_file_name):
request_data['tuner'] = experiment_config['tuner']
if 'assessor' in experiment_config:
request_data['assessor'] = experiment_config['assessor']
if experiment_config.get('debug') is not None:
request_data['versionCheck'] = experiment_config.get('debug')
request_data['clusterMetaData'] = []
if experiment_config['trainingServicePlatform'] == 'local':
......@@ -313,7 +315,6 @@ def set_experiment(experiment_config, mode, port, config_file_name):
def launch_experiment(args, experiment_config, mode, config_file_name, experiment_id=None):
'''follow steps to start rest server and start experiment'''
nni_config = Config(config_file_name)
# check packages for tuner
if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
tuner_name = experiment_config['tuner']['builtinTunerName']
......@@ -440,6 +441,9 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
# start a new experiment
print_normal('Starting experiment...')
# set debug configuration
if args.debug is not None:
experiment_config['debug'] = args.debug
response = set_experiment(experiment_config, mode, args.port, config_file_name)
if response:
if experiment_id is None:
......
......@@ -34,7 +34,10 @@ if os.environ.get('COVERAGE_PROCESS_START'):
def nni_info(*args):
if args[0].version:
print(pkg_resources.get_distribution('nni').version)
try:
print(pkg_resources.get_distribution('nni').version)
except pkg_resources.ResolutionError as err:
print_error('Get version failed, please use `pip3 list | grep nni` to check nni version!')
else:
print('please run "nnictl {positional argument} --help" to see nnictl guidance')
......@@ -51,14 +54,14 @@ def parse_args():
parser_start = subparsers.add_parser('create', help='create a new experiment')
parser_start.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file')
parser_start.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server')
parser_start.add_argument('--debug', '-d', action='store_true', help=' set log level to debug')
parser_start.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
parser_start.set_defaults(func=create_experiment)
# parse resume command
parser_resume = subparsers.add_parser('resume', help='resume a new experiment')
parser_resume.add_argument('id', nargs='?', help='The id of the experiment you want to resume')
parser_resume.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server')
parser_resume.add_argument('--debug', '-d', action='store_true', help=' set log level to debug')
parser_resume.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
parser_resume.set_defaults(func=resume_experiment)
# parse update command
......
......@@ -28,6 +28,7 @@ import re
import sys
import select
from pyhdfs import HdfsClient
import pkg_resources
from .constants import HOME_DIR, LOG_DIR, NNI_PLATFORM, STDOUT_FULL_PATH, STDERR_FULL_PATH
from .hdfsClientUtility import copyDirectoryToHdfs, copyHdfsDirectoryToLocal
......@@ -103,6 +104,28 @@ def main_loop(args):
def trial_keeper_help_info(*args):
print('please run --help to see guidance')
def check_version(args):
try:
trial_keeper_version = pkg_resources.get_distribution('nni').version
except pkg_resources.ResolutionError as err:
#package nni does not exist, try nni-tool package
nni_log(LogType.Warning, 'Package nni does not exist!')
try:
trial_keeper_version = pkg_resources.get_distribution('nni-tool').version
except pkg_resources.ResolutionError as err:
#package nni-tool does not exist
nni_log(LogType.Error, 'Package nni-tool does not exist!')
os._exit(1)
if not args.version:
# skip version check
nni_log(LogType.Warning, 'Skipping version check!')
elif trial_keeper_version != args.version:
nni_log(LogType.Error, 'Exit trial keeper, trial keeper version is {}, and trainingService version is {}, \
versions does not match, please check your code and image versions!'.format(trial_keeper_version, args.version))
os._exit(1)
else:
nni_log(LogType.Info, 'NNI version is {}'.format(args.version))
if __name__ == '__main__':
'''NNI Trial Keeper main function'''
PARSER = argparse.ArgumentParser()
......@@ -117,10 +140,11 @@ if __name__ == '__main__':
PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs')
PARSER.add_argument('--nni_hdfs_exp_dir', type=str, help='nni experiment directory in hdfs')
PARSER.add_argument('--webhdfs_path', type=str, help='the webhdfs path used in webhdfs URL')
PARSER.add_argument('--version', type=str, help='the nni version transmitted from trainingService')
args, unknown = PARSER.parse_known_args()
if args.trial_command is None:
exit(1)
check_version(args)
try:
main_loop(args)
except SystemExit as se:
......
import setuptools
setuptools.setup(
name = 'nnictl',
name = 'nni-tool',
version = '999.0.0-developing',
packages = setuptools.find_packages(exclude=['*test*']),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment