Unverified Commit c785655e authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #207 from microsoft/master

merge master
parents 9fae194a d6b61e2f
......@@ -14,14 +14,12 @@ tuner:
className: NaiveTuner
classArgs:
optimize_mode: maximize
gpuNum: 0
assessor:
codeDir: .
classFileName: naive_assessor.py
className: NaiveAssessor
classArgs:
optimize_mode: maximize
gpuNum: 0
trial:
command: python3 naive_trial.py
codeDir: .
......
......@@ -39,11 +39,6 @@ jobs:
displayName: 'Install nni toolkit via source code'
- script: |
python3 -m pip install scikit-learn==0.20.0 --user
python3 -m pip install torch==0.4.1 --user
python3 -m pip install torchvision==0.2.1 --user
python3 -m pip install keras==2.1.6 --user
python3 -m pip install tensorflow==1.12.0 --user
sudo apt-get install swig -y
PATH=$HOME/.local/bin:$PATH nnictl package install --name=SMAC
PATH=$HOME/.local/bin:$PATH nnictl package install --name=BOHB
......
......@@ -18,7 +18,7 @@ jobs:
displayName: 'generate config files'
- script: |
cd test
python config_test.py --ts local --local_gpu --exclude smac,bohb,multi_phase_batch,multi_phase_grid
python config_test.py --ts local --local_gpu --exclude smac,bohb
displayName: 'Examples and advanced features tests on local machine'
- script: |
cd test
......
......@@ -9,8 +9,8 @@ jobs:
displayName: 'Install nni toolkit via source code'
- script: |
python3 -m pip install scikit-learn==0.20.0 --user
python3 -m pip install torch==0.4.1 --user
python3 -m pip install torchvision==0.2.1 --user
python3 -m pip install torch==1.2.0 --user
python3 -m pip install torchvision==0.4.0 --user
python3 -m pip install keras==2.1.6 --user
python3 -m pip install tensorflow-gpu==1.12.0 --user
sudo apt-get install swig -y
......@@ -31,7 +31,7 @@ jobs:
displayName: 'Built-in tuners / assessors tests'
- script: |
cd test
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts local --local_gpu --exclude multi_phase_batch,multi_phase_grid
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts local --local_gpu
displayName: 'Examples and advanced features tests on local machine'
- script: |
cd test
......
......@@ -65,5 +65,5 @@ jobs:
python --version
python generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) --nni_docker_image $(docker_image) --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip)
python config_test.py --ts pai --exclude multi_phase,smac,bohb,multi_phase_batch,multi_phase_grid
python config_test.py --ts pai --exclude multi_phase,smac,bohb
displayName: 'Examples and advanced features tests on pai'
\ No newline at end of file
......@@ -39,11 +39,6 @@ jobs:
displayName: 'Install nni toolkit via source code'
- script: |
python3 -m pip install scikit-learn==0.20.0 --user
python3 -m pip install torch==0.4.1 --user
python3 -m pip install torchvision==0.2.1 --user
python3 -m pip install keras==2.1.6 --user
python3 -m pip install tensorflow-gpu==1.12.0 --user
sudo apt-get install swig -y
PATH=$HOME/.local/bin:$PATH nnictl package install --name=SMAC
PATH=$HOME/.local/bin:$PATH nnictl package install --name=BOHB
......@@ -76,6 +71,6 @@ jobs:
python3 generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) \
--nni_docker_image $TEST_IMG --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip)
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts pai --exclude multi_phase_batch,multi_phase_grid
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts pai
PATH=$HOME/.local/bin:$PATH python3 metrics_test.py
displayName: 'integration test'
......@@ -39,7 +39,7 @@ jobs:
cd test
python generate_ts_config.py --ts remote --remote_user $(docker_user) --remote_host $(remote_host) --remote_port $(Get-Content port) --remote_pwd $(docker_pwd) --nni_manager_ip $(nni_manager_ip)
Get-Content training_service.yml
python config_test.py --ts remote --exclude cifar10,smac,bohb,multi_phase_batch,multi_phase_grid
python config_test.py --ts remote --exclude cifar10,smac,bohb
displayName: 'integration test'
- task: SSH@0
inputs:
......
......@@ -53,7 +53,7 @@ jobs:
python3 generate_ts_config.py --ts remote --remote_user $(docker_user) --remote_host $(remote_host) \
--remote_port $(cat port) --remote_pwd $(docker_pwd) --nni_manager_ip $(nni_manager_ip)
cat training_service.yml
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts remote --exclude cifar10,multi_phase_batch,multi_phase_grid
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts remote --exclude cifar10
PATH=$HOME/.local/bin:$PATH python3 metrics_test.py
displayName: 'integration test'
- task: SSH@0
......
......@@ -22,6 +22,7 @@
import ast
import astor
# pylint: disable=unidiomatic-typecheck
def parse_annotation_mutable_layers(code, lineno, nas_mode):
......@@ -79,7 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
fields['optional_inputs'] = True
elif k.id == 'optional_input_size':
assert not fields['optional_input_size'], 'Duplicated field: optional_input_size'
assert type(value) is ast.Num or type(value) is ast.List, 'Value of optional_input_size should be a number or list'
assert type(value) is ast.Num or type(value) is ast.List, \
'Value of optional_input_size should be a number or list'
optional_input_size = value
fields['optional_input_size'] = True
elif k.id == 'layer_output':
......@@ -118,6 +120,7 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
nodes.append(node)
return nodes
def parse_annotation(code):
"""Parse an annotation string.
Return an AST Expr node.
......@@ -198,7 +201,7 @@ def convert_args_to_dict(call, with_lambda=False):
if type(arg) in [ast.Str, ast.Num]:
arg_value = arg
else:
# if arg is not a string or a number, we use its source code as the key
# if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg
......@@ -311,7 +314,6 @@ class Transformer(ast.NodeTransformer):
return self._visit_children(node)
def _visit_string(self, node):
string = node.value.s
if string.startswith('@nni.'):
......@@ -325,7 +327,7 @@ class Transformer(ast.NodeTransformer):
call_node.args.insert(0, ast.Str(s=self.nas_mode))
return expr
if string.startswith('@nni.report_intermediate_result') \
if string.startswith('@nni.report_intermediate_result') \
or string.startswith('@nni.report_final_result') \
or string.startswith('@nni.get_next_parameter'):
return parse_annotation(string[1:]) # expand annotation string to code
......@@ -341,7 +343,6 @@ class Transformer(ast.NodeTransformer):
raise AssertionError('Unexpected annotation function')
def _visit_children(self, node):
self.stack.append(None)
self.generic_visit(node)
......
......@@ -64,7 +64,6 @@ class SearchSpaceGenerator(ast.NodeTransformer):
'optional_input_size': args[6].n if isinstance(args[6], ast.Num) else [args[6].elts[0].n, args[6].elts[1].n]
}
def visit_Call(self, node): # pylint: disable=invalid-name
self.generic_visit(node)
......@@ -108,7 +107,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
else:
# arguments of other functions must be literal number
assert all(isinstance(ast.literal_eval(astor.to_source(arg)), numbers.Real) for arg in node.args), \
'Smart parameter\'s arguments must be number literals'
'Smart parameter\'s arguments must be number literals'
args = [ast.literal_eval(astor.to_source(arg)) for arg in node.args]
key = self.module_name + '/' + name + '/' + func
......
......@@ -28,6 +28,7 @@ from nni_cmd.common_utils import print_warning
para_cfg = None
prefix_name = None
def parse_annotation_mutable_layers(code, lineno):
"""Parse the string of mutable layers in annotation.
Return a list of AST Expr nodes
......@@ -102,6 +103,7 @@ def parse_annotation_mutable_layers(code, lineno):
nodes.append(node)
return nodes
def parse_annotation(code):
"""Parse an annotation string.
Return an AST Expr node.
......@@ -182,7 +184,7 @@ def convert_args_to_dict(call, with_lambda=False):
if type(arg) in [ast.Str, ast.Num]:
arg_value = arg
else:
# if arg is not a string or a number, we use its source code as the key
# if arg is not a string or a number, we use its source code as the key
arg_value = astor.to_source(arg).strip('\n"')
arg_value = ast.Str(str(arg_value))
arg = make_lambda(arg) if with_lambda else arg
......@@ -217,7 +219,7 @@ def test_variable_equal(node1, node2):
if len(node1) != len(node2):
return False
return all(test_variable_equal(n1, n2) for n1, n2 in zip(node1, node2))
return node1 == node2
......@@ -294,7 +296,6 @@ class Transformer(ast.NodeTransformer):
return self._visit_children(node)
def _visit_string(self, node):
string = node.value.s
if string.startswith('@nni.'):
......@@ -303,19 +304,27 @@ class Transformer(ast.NodeTransformer):
return node # not an annotation, ignore it
if string.startswith('@nni.get_next_parameter'):
deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code."
deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. " \
"Please remove this line in the trial code."
print_warning(deprecated_message)
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='Get next parameter here...')], keywords=[]))
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Get next parameter here...')], keywords=[]))
if string.startswith('@nni.training_update'):
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='Training update here...')], keywords=[]))
if string.startswith('@nni.report_intermediate_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[]))
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[]))
if string.startswith('@nni.report_final_result'):
module = ast.parse(string[1:])
arg = module.body[0].value.args[0]
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[]))
return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[]))
if string.startswith('@nni.mutable_layers'):
return parse_annotation_mutable_layers(string[1:], node.lineno)
......@@ -327,7 +336,6 @@ class Transformer(ast.NodeTransformer):
raise AssertionError('Unexpected annotation function')
def _visit_children(self, node):
self.stack.append(None)
self.generic_visit(node)
......
......@@ -39,17 +39,18 @@ class AnnotationTestCase(TestCase):
shutil.rmtree('_generated')
def test_search_space_generator(self):
search_space = generate_search_space('testcase/annotated')
shutil.copytree('testcase/annotated', '_generated/annotated')
search_space = generate_search_space('_generated/annotated')
with open('testcase/searchspace.json') as f:
self.assertEqual(search_space, json.load(f))
def test_code_generator(self):
code_dir = expand_annotations('testcase/usercode', '_generated', nas_mode='classic_mode')
self.assertEqual(code_dir, '_generated')
self._assert_source_equal('testcase/annotated/nas.py', '_generated/nas.py')
self._assert_source_equal('testcase/annotated/mnist.py', '_generated/mnist.py')
self._assert_source_equal('testcase/annotated/dir/simple.py', '_generated/dir/simple.py')
with open('testcase/usercode/nonpy.txt') as src, open('_generated/nonpy.txt') as dst:
code_dir = expand_annotations('testcase/usercode', '_generated/usercode', nas_mode='classic_mode')
self.assertEqual(code_dir, '_generated/usercode')
self._assert_source_equal('testcase/annotated/nas.py', '_generated/usercode/nas.py')
self._assert_source_equal('testcase/annotated/mnist.py', '_generated/usercode/mnist.py')
self._assert_source_equal('testcase/annotated/dir/simple.py', '_generated/usercode/dir/simple.py')
with open('testcase/usercode/nonpy.txt') as src, open('_generated/usercode/nonpy.txt') as dst:
assert src.read() == dst.read()
def test_annotation_detecting(self):
......
......@@ -76,7 +76,7 @@ tuner_schema_dict = {
'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
},
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
},
('Evolution'): {
'builtinTunerName': setChoice('builtinTunerName', 'Evolution'),
......@@ -85,12 +85,12 @@ tuner_schema_dict = {
Optional('population_size'): setNumberRange('population_size', int, 0, 99999),
},
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
},
('BatchTuner', 'GridSearch', 'Random'): {
'builtinTunerName': setChoice('builtinTunerName', 'BatchTuner', 'GridSearch', 'Random'),
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
},
'TPE': {
'builtinTunerName': 'TPE',
......@@ -100,7 +100,7 @@ tuner_schema_dict = {
Optional('constant_liar_type'): setChoice('constant_liar_type', 'min', 'max', 'mean')
},
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
},
'NetworkMorphism': {
'builtinTunerName': 'NetworkMorphism',
......@@ -112,7 +112,7 @@ tuner_schema_dict = {
Optional('n_output_node'): setType('n_output_node', int),
},
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
},
'MetisTuner': {
'builtinTunerName': 'MetisTuner',
......@@ -124,7 +124,7 @@ tuner_schema_dict = {
Optional('cold_start_num'): setType('cold_start_num', int),
},
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
},
'GPTuner': {
'builtinTunerName': 'GPTuner',
......@@ -140,7 +140,25 @@ tuner_schema_dict = {
Optional('selection_num_starting_points'): setType('selection_num_starting_points', int),
},
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
},
'PPOTuner': {
'builtinTunerName': 'PPOTuner',
'classArgs': {
'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('trials_per_update'): setNumberRange('trials_per_update', int, 0, 99999),
Optional('epochs_per_update'): setNumberRange('epochs_per_update', int, 0, 99999),
Optional('minibatch_size'): setNumberRange('minibatch_size', int, 0, 99999),
Optional('ent_coef'): setType('ent_coef', float),
Optional('lr'): setType('lr', float),
Optional('vf_coef'): setType('vf_coef', float),
Optional('max_grad_norm'): setType('max_grad_norm', float),
Optional('gamma'): setType('gamma', float),
Optional('lam'): setType('lam', float),
Optional('cliprange'): setType('cliprange', float),
},
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
},
'customized': {
'codeDir': setPathCheck('codeDir'),
......@@ -148,7 +166,7 @@ tuner_schema_dict = {
'className': setType('className', str),
Optional('classArgs'): dict,
Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
}
}
......@@ -160,7 +178,7 @@ advisor_schema_dict = {
Optional('R'): setType('R', int),
Optional('eta'): setType('eta', int)
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
},
'BOHB':{
'builtinAdvisorName': Or('BOHB'),
......@@ -176,14 +194,14 @@ advisor_schema_dict = {
Optional('bandwidth_factor'): setNumberRange('bandwidth_factor', float, 0, 9999),
Optional('min_bandwidth'): setNumberRange('min_bandwidth', float, 0, 9999),
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
},
'customized':{
'codeDir': setPathCheck('codeDir'),
'classFileName': setType('classFileName', str),
'className': setType('className', str),
Optional('classArgs'): dict,
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
}
}
......@@ -194,7 +212,6 @@ assessor_schema_dict = {
Optional('optimize_mode'): setChoice('optimize_mode', 'maximize', 'minimize'),
Optional('start_step'): setNumberRange('start_step', int, 0, 9999),
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
'Curvefitting': {
'builtinAssessorName': 'Curvefitting',
......@@ -205,14 +222,12 @@ assessor_schema_dict = {
Optional('threshold'): setNumberRange('threshold', float, 0, 9999),
Optional('gap'): setNumberRange('gap', int, 1, 9999),
},
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
},
'customized': {
'codeDir': setPathCheck('codeDir'),
'classFileName': setType('classFileName', str),
'className': setType('className', str),
Optional('classArgs'): dict,
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999)
}
}
......
......@@ -80,7 +80,8 @@ TRIAL_MONITOR_TAIL = '----------------------------------------------------------
PACKAGE_REQUIREMENTS = {
'SMAC': 'smac_tuner',
'BOHB': 'bohb_advisor'
'BOHB': 'bohb_advisor',
'PPOTuner': 'ppo_tuner'
}
TUNERS_SUPPORTING_IMPORT_DATA = {
......
......@@ -118,12 +118,17 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
node_command = 'node'
if sys.platform == 'win32':
node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe')
cmds = [node_command, entry_file, '--port', str(port), '--mode', platform, '--start_mode', mode]
cmds = [node_command, entry_file, '--port', str(port), '--mode', platform]
if mode == 'view':
cmds += ['--start_mode', 'resume']
cmds += ['--readonly', 'true']
else:
cmds += ['--start_mode', mode]
if log_dir is not None:
cmds += ['--log_dir', log_dir]
if log_level is not None:
cmds += ['--log_level', log_level]
if mode == 'resume':
if mode in ['resume', 'view']:
cmds += ['--experiment_id', experiment_id]
stdout_full_path, stderr_full_path = get_log_path(config_file_name)
with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
......@@ -156,7 +161,6 @@ def set_trial_config(experiment_config, port, config_file_name):
def set_local_config(experiment_config, port, config_file_name):
'''set local configuration'''
#set machine_list
request_data = dict()
if experiment_config.get('localConfig'):
request_data['local_config'] = experiment_config['localConfig']
......@@ -177,7 +181,7 @@ def set_local_config(experiment_config, port, config_file_name):
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message
return set_trial_config(experiment_config, port, config_file_name)
return set_trial_config(experiment_config, port, config_file_name), None
def set_remote_config(experiment_config, port, config_file_name):
'''Call setClusterMetadata to pass trial'''
......@@ -296,10 +300,20 @@ def set_experiment(experiment_config, mode, port, config_file_name):
request_data['multiThread'] = experiment_config.get('multiThread')
if experiment_config.get('advisor'):
request_data['advisor'] = experiment_config['advisor']
if request_data['advisor'].get('gpuNum'):
print_error('gpuNum is deprecated, please use gpuIndices instead.')
if request_data['advisor'].get('gpuIndices') and isinstance(request_data['advisor'].get('gpuIndices'), int):
request_data['advisor']['gpuIndices'] = str(request_data['advisor'].get('gpuIndices'))
else:
request_data['tuner'] = experiment_config['tuner']
if request_data['tuner'].get('gpuNum'):
print_error('gpuNum is deprecated, please use gpuIndices instead.')
if request_data['tuner'].get('gpuIndices') and isinstance(request_data['tuner'].get('gpuIndices'), int):
request_data['tuner']['gpuIndices'] = str(request_data['tuner'].get('gpuIndices'))
if 'assessor' in experiment_config:
request_data['assessor'] = experiment_config['assessor']
if request_data['assessor'].get('gpuNum'):
print_error('gpuNum is deprecated, please remove it from your config file.')
#debug mode should disable version check
if experiment_config.get('debug') is not None:
request_data['versionCheck'] = not experiment_config.get('debug')
......@@ -335,7 +349,6 @@ def set_experiment(experiment_config, mode, port, config_file_name):
{'key': 'frameworkcontroller_config', 'value': experiment_config['frameworkcontrollerConfig']})
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True)
if check_response(response):
return response
......@@ -347,6 +360,33 @@ def set_experiment(experiment_config, mode, port, config_file_name):
print_error('Setting experiment error, error message is {}'.format(response.text))
return None
def set_platform_config(platform, experiment_config, port, config_file_name, rest_process):
    '''Send the cluster metadata for the given training service platform.

    Looks up the ``set_*_config`` helper that matches ``platform`` and calls it
    with ``(experiment_config, port, config_file_name)``. Each helper is
    unpacked as a ``(config_result, err_msg)`` pair.

    Parameters:
        platform: one of 'local', 'remote', 'pai', 'kubeflow',
            'frameworkcontroller'.
        experiment_config: experiment configuration passed through to the helper.
        port: rest server port passed through to the helper.
        config_file_name: config file name passed through to the helper.
        rest_process: rest server process; its ``pid`` is killed on failure.

    Returns None when the helper reports success. When the helper reports
    failure, the error is printed, the rest server process is killed and the
    interpreter exits with status 1.

    Raises:
        Exception: if ``platform`` is not supported, or if killing the rest
            server process fails.
    '''
    print_normal('Setting {0} config...'.format(platform))
    # Dispatch table instead of an if/elif ladder; all helpers share one signature.
    platform_setters = {
        'local': set_local_config,
        'remote': set_remote_config,
        'pai': set_pai_config,
        'kubeflow': set_kubeflow_config,
        'frameworkcontroller': set_frameworkcontroller_config,
    }
    set_config = platform_setters.get(platform)
    if set_config is None:
        # The original followed this raise with an unreachable exit(1).
        raise Exception(ERROR_INFO % 'Unsupported platform!')
    config_result, err_msg = set_config(experiment_config, port, config_file_name)
    if config_result:
        print_normal('Successfully set {0} config!'.format(platform))
        return
    print_error('Failed! Error is: {}'.format(err_msg))
    try:
        kill_command(rest_process.pid)
    except Exception as err:
        # Keep the original exception type, but make the cause explicit.
        raise Exception(ERROR_INFO % 'Rest server stopped!') from err
    exit(1)
def launch_experiment(args, experiment_config, mode, config_file_name, experiment_id=None):
'''follow steps to start rest server and start experiment'''
nni_config = Config(config_file_name)
......@@ -371,8 +411,10 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
exit(1)
log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None
log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None
if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
log_level = 'debug'
# View mode does not need debug logging: viewing an experiment creates no new logs.
if mode != 'view':
if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
log_level = 'debug'
# start rest server
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], mode, config_file_name, experiment_id, log_dir, log_level)
nni_config.set_config('restServerPid', rest_process.pid)
......@@ -406,83 +448,14 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
except Exception:
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(1)
# set remote config
if experiment_config['trainingServicePlatform'] == 'remote':
print_normal('Setting remote config...')
config_result, err_msg = set_remote_config(experiment_config, args.port, config_file_name)
if config_result:
print_normal('Successfully set remote config!')
else:
print_error('Failed! Error is: {}'.format(err_msg))
try:
kill_command(rest_process.pid)
except Exception:
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(1)
# set local config
if experiment_config['trainingServicePlatform'] == 'local':
print_normal('Setting local config...')
if set_local_config(experiment_config, args.port, config_file_name):
print_normal('Successfully set local config!')
else:
print_error('Set local config failed!')
try:
kill_command(rest_process.pid)
except Exception:
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(1)
#set pai config
if experiment_config['trainingServicePlatform'] == 'pai':
print_normal('Setting pai config...')
config_result, err_msg = set_pai_config(experiment_config, args.port, config_file_name)
if config_result:
print_normal('Successfully set pai config!')
else:
if err_msg:
print_error('Failed! Error is: {}'.format(err_msg))
try:
kill_command(rest_process.pid)
except Exception:
raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(1)
#set kubeflow config
if experiment_config['trainingServicePlatform'] == 'kubeflow':
print_normal('Setting kubeflow config...')
config_result, err_msg = set_kubeflow_config(experiment_config, args.port, config_file_name)
if config_result:
print_normal('Successfully set kubeflow config!')
else:
if err_msg:
print_error('Failed! Error is: {}'.format(err_msg))
try:
kill_command(rest_process.pid)
except Exception:
raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(1)
#set frameworkcontroller config
if experiment_config['trainingServicePlatform'] == 'frameworkcontroller':
print_normal('Setting frameworkcontroller config...')
config_result, err_msg = set_frameworkcontroller_config(experiment_config, args.port, config_file_name)
if config_result:
print_normal('Successfully set frameworkcontroller config!')
else:
if err_msg:
print_error('Failed! Error is: {}'.format(err_msg))
try:
kill_command(rest_process.pid)
except Exception:
raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(1)
if mode != 'view':
# set platform configuration
set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port, config_file_name, rest_process)
# start a new experiment
print_normal('Starting experiment...')
# set debug configuration
if experiment_config.get('debug') is None:
if mode != 'view' and experiment_config.get('debug') is None:
experiment_config['debug'] = args.debug
response = set_experiment(experiment_config, mode, args.port, config_file_name)
if response:
......@@ -509,8 +482,23 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
def resume_experiment(args):
'''resume an experiment'''
def create_experiment(args):
    '''Launch a new experiment from the YAML config file given by args.config.

    A random 8-character name is generated for the experiment's Config entry.
    If the config path does not exist, an error is printed and the interpreter
    exits with status 1. Otherwise the YAML content is loaded, validated,
    stored under 'experimentConfig', and the experiment is launched in 'new'
    mode. The rest server port (args.port) is recorded after the launch.
    '''
    name_alphabet = string.ascii_letters + string.digits
    config_file_name = ''.join(random.sample(name_alphabet, 8))
    nni_config = Config(config_file_name)

    yml_path = os.path.abspath(args.config)
    if not os.path.exists(yml_path):
        print_error('Please set correct config path!')
        exit(1)

    yml_content = get_yml_content(yml_path)
    validate_all_content(yml_content, yml_path)
    nni_config.set_config('experimentConfig', yml_content)

    launch_experiment(args, yml_content, 'new', config_file_name)
    nni_config.set_config('restServerPort', args.port)
def manage_stopped_experiment(args, mode):
'''view a stopped experiment'''
update_experiment()
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
......@@ -518,38 +506,31 @@ def resume_experiment(args):
experiment_endTime = None
#find the latest stopped experiment
if not args.id:
print_error('Please set experiment id! \nYou could use \'nnictl resume {id}\' to resume a stopped experiment!\n' \
'You could use \'nnictl experiment list --all\' to show all experiments!')
print_error('Please set experiment id! \nYou could use \'nnictl {0} {id}\' to {0} a stopped experiment!\n' \
'You could use \'nnictl experiment list --all\' to show all experiments!'.format(mode))
exit(1)
else:
if experiment_dict.get(args.id) is None:
print_error('Id %s not exist!' % args.id)
exit(1)
if experiment_dict[args.id]['status'] != 'STOPPED':
print_error('Only stopped experiments can be resumed!')
print_error('Only stopped experiments can be {0}ed!'.format(mode))
exit(1)
experiment_id = args.id
print_normal('Resuming experiment %s...' % experiment_id)
print_normal('{0} experiment {1}...'.format(mode, experiment_id))
nni_config = Config(experiment_dict[experiment_id]['fileName'])
experiment_config = nni_config.get_config('experimentConfig')
experiment_id = nni_config.get_config('experimentId')
new_config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
new_nni_config = Config(new_config_file_name)
new_nni_config.set_config('experimentConfig', experiment_config)
launch_experiment(args, experiment_config, 'resume', new_config_file_name, experiment_id)
launch_experiment(args, experiment_config, mode, new_config_file_name, experiment_id)
new_nni_config.set_config('restServerPort', args.port)
def create_experiment(args):
'''start a new experiment'''
config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
nni_config = Config(config_file_name)
config_path = os.path.abspath(args.config)
if not os.path.exists(config_path):
print_error('Please set correct config path!')
exit(1)
experiment_config = get_yml_content(config_path)
validate_all_content(experiment_config, config_path)
def view_experiment(args):
    '''Delegate to manage_stopped_experiment with mode 'view'.'''
    manage_stopped_experiment(args, mode='view')
nni_config.set_config('experimentConfig', experiment_config)
launch_experiment(args, experiment_config, 'new', config_file_name)
nni_config.set_config('restServerPort', args.port)
def resume_experiment(args):
    '''Delegate to manage_stopped_experiment with mode 'resume'.'''
    manage_stopped_experiment(args, mode='resume')
\ No newline at end of file
......@@ -21,7 +21,7 @@
import argparse
import pkg_resources
from .launcher import create_experiment, resume_experiment
from .launcher import create_experiment, resume_experiment, view_experiment
from .updater import update_searchspace, update_concurrency, update_duration, update_trialnum, import_data
from .nnictl_utils import *
from .package_management import *
......@@ -66,6 +66,12 @@ def parse_args():
parser_resume.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
parser_resume.set_defaults(func=resume_experiment)
# parse view command
parser_resume = subparsers.add_parser('view', help='view a stopped experiment')
parser_resume.add_argument('id', nargs='?', help='The id of the experiment you want to view')
parser_resume.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server')
parser_resume.set_defaults(func=view_experiment)
# parse update command
parser_updater = subparsers.add_parser('update', help='update the experiment')
#add subparsers for parser_updater
......
......@@ -351,6 +351,7 @@ def log_stderr(args):
def log_trial(args):
''''get trial log path'''
trial_id_path_dict = {}
trial_id_list = []
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
......@@ -363,23 +364,27 @@ def log_trial(args):
if response and check_response(response):
content = json.loads(response.text)
for trial in content:
trial_id_path_dict[trial['id']] = trial['logPath']
trial_id_list.append(trial.get('id'))
if trial.get('logPath'):
trial_id_path_dict[trial.get('id')] = trial['logPath']
else:
print_error('Restful server is not running...')
exit(1)
if args.id:
if args.trial_id:
if trial_id_path_dict.get(args.trial_id):
print_normal('id:' + args.trial_id + ' path:' + trial_id_path_dict[args.trial_id])
else:
print_error('trial id is not valid.')
exit(1)
if args.trial_id:
if args.trial_id not in trial_id_list:
print_error('Trial id {0} not correct, please check your command!'.format(args.trial_id))
exit(1)
if trial_id_path_dict.get(args.trial_id):
print_normal('id:' + args.trial_id + ' path:' + trial_id_path_dict[args.trial_id])
else:
print_error('please specific the trial id.')
print_error('Log path is not available yet, please wait...')
exit(1)
else:
print_normal('All of trial log info:')
for key in trial_id_path_dict:
print('id:' + key + ' path:' + trial_id_path_dict[key])
print_normal('id:' + key + ' path:' + trial_id_path_dict[key])
if not trial_id_path_dict:
print_normal('None')
def get_config(args):
'''get config info'''
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment