Unverified Commit c785655e authored by SparkSnail, committed by GitHub

Merge pull request #207 from microsoft/master

merge master
parents 9fae194a d6b61e2f
@@ -14,14 +14,12 @@ tuner:
   className: NaiveTuner
   classArgs:
     optimize_mode: maximize
-  gpuNum: 0
 assessor:
   codeDir: .
   classFileName: naive_assessor.py
   className: NaiveAssessor
   classArgs:
     optimize_mode: maximize
-  gpuNum: 0
 trial:
   command: python3 naive_trial.py
   codeDir: .
...
@@ -39,11 +39,6 @@ jobs:
     displayName: 'Install nni toolkit via source code'
   - script: |
-      python3 -m pip install scikit-learn==0.20.0 --user
-      python3 -m pip install torch==0.4.1 --user
-      python3 -m pip install torchvision==0.2.1 --user
-      python3 -m pip install keras==2.1.6 --user
-      python3 -m pip install tensorflow==1.12.0 --user
       sudo apt-get install swig -y
       PATH=$HOME/.local/bin:$PATH nnictl package install --name=SMAC
       PATH=$HOME/.local/bin:$PATH nnictl package install --name=BOHB
...
@@ -18,7 +18,7 @@ jobs:
     displayName: 'generate config files'
   - script: |
       cd test
-      python config_test.py --ts local --local_gpu --exclude smac,bohb,multi_phase_batch,multi_phase_grid
+      python config_test.py --ts local --local_gpu --exclude smac,bohb
     displayName: 'Examples and advanced features tests on local machine'
   - script: |
       cd test
...
@@ -9,8 +9,8 @@ jobs:
     displayName: 'Install nni toolkit via source code'
   - script: |
       python3 -m pip install scikit-learn==0.20.0 --user
-      python3 -m pip install torch==0.4.1 --user
-      python3 -m pip install torchvision==0.2.1 --user
+      python3 -m pip install torch==1.2.0 --user
+      python3 -m pip install torchvision==0.4.0 --user
       python3 -m pip install keras==2.1.6 --user
       python3 -m pip install tensorflow-gpu==1.12.0 --user
       sudo apt-get install swig -y
@@ -31,7 +31,7 @@ jobs:
     displayName: 'Built-in tuners / assessors tests'
   - script: |
       cd test
-      PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts local --local_gpu --exclude multi_phase_batch,multi_phase_grid
+      PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts local --local_gpu
     displayName: 'Examples and advanced features tests on local machine'
   - script: |
       cd test
...
@@ -65,5 +65,5 @@ jobs:
       python --version
       python generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) --nni_docker_image $(docker_image) --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip)
-      python config_test.py --ts pai --exclude multi_phase,smac,bohb,multi_phase_batch,multi_phase_grid
+      python config_test.py --ts pai --exclude multi_phase,smac,bohb
     displayName: 'Examples and advanced features tests on pai'
\ No newline at end of file
@@ -39,11 +39,6 @@ jobs:
     displayName: 'Install nni toolkit via source code'
   - script: |
-      python3 -m pip install scikit-learn==0.20.0 --user
-      python3 -m pip install torch==0.4.1 --user
-      python3 -m pip install torchvision==0.2.1 --user
-      python3 -m pip install keras==2.1.6 --user
-      python3 -m pip install tensorflow-gpu==1.12.0 --user
       sudo apt-get install swig -y
       PATH=$HOME/.local/bin:$PATH nnictl package install --name=SMAC
       PATH=$HOME/.local/bin:$PATH nnictl package install --name=BOHB
@@ -76,6 +71,6 @@ jobs:
       python3 generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) \
        --nni_docker_image $TEST_IMG --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip)
-      PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts pai --exclude multi_phase_batch,multi_phase_grid
+      PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts pai
       PATH=$HOME/.local/bin:$PATH python3 metrics_test.py
     displayName: 'integration test'
@@ -39,7 +39,7 @@ jobs:
       cd test
       python generate_ts_config.py --ts remote --remote_user $(docker_user) --remote_host $(remote_host) --remote_port $(Get-Content port) --remote_pwd $(docker_pwd) --nni_manager_ip $(nni_manager_ip)
       Get-Content training_service.yml
-      python config_test.py --ts remote --exclude cifar10,smac,bohb,multi_phase_batch,multi_phase_grid
+      python config_test.py --ts remote --exclude cifar10,smac,bohb
     displayName: 'integration test'
   - task: SSH@0
     inputs:
...
@@ -53,7 +53,7 @@ jobs:
       python3 generate_ts_config.py --ts remote --remote_user $(docker_user) --remote_host $(remote_host) \
        --remote_port $(cat port) --remote_pwd $(docker_pwd) --nni_manager_ip $(nni_manager_ip)
       cat training_service.yml
-      PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts remote --exclude cifar10,multi_phase_batch,multi_phase_grid
+      PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts remote --exclude cifar10
       PATH=$HOME/.local/bin:$PATH python3 metrics_test.py
     displayName: 'integration test'
   - task: SSH@0
...
@@ -22,6 +22,7 @@
 import ast
 import astor
 # pylint: disable=unidiomatic-typecheck
 def parse_annotation_mutable_layers(code, lineno, nas_mode):
@@ -79,7 +80,8 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
                 fields['optional_inputs'] = True
             elif k.id == 'optional_input_size':
                 assert not fields['optional_input_size'], 'Duplicated field: optional_input_size'
-                assert type(value) is ast.Num or type(value) is ast.List, 'Value of optional_input_size should be a number or list'
+                assert type(value) is ast.Num or type(value) is ast.List, \
+                    'Value of optional_input_size should be a number or list'
                 optional_input_size = value
                 fields['optional_input_size'] = True
             elif k.id == 'layer_output':
@@ -118,6 +120,7 @@ def parse_annotation_mutable_layers(code, lineno, nas_mode):
         nodes.append(node)
     return nodes
 def parse_annotation(code):
     """Parse an annotation string.
     Return an AST Expr node.
@@ -198,7 +201,7 @@ def convert_args_to_dict(call, with_lambda=False):
         if type(arg) in [ast.Str, ast.Num]:
             arg_value = arg
         else:
             # if arg is not a string or a number, we use its source code as the key
             arg_value = astor.to_source(arg).strip('\n"')
             arg_value = ast.Str(str(arg_value))
         arg = make_lambda(arg) if with_lambda else arg
@@ -311,7 +314,6 @@ class Transformer(ast.NodeTransformer):
         return self._visit_children(node)
     def _visit_string(self, node):
         string = node.value.s
         if string.startswith('@nni.'):
@@ -325,7 +327,7 @@ class Transformer(ast.NodeTransformer):
             call_node.args.insert(0, ast.Str(s=self.nas_mode))
             return expr
         if string.startswith('@nni.report_intermediate_result') \
                 or string.startswith('@nni.report_final_result') \
                 or string.startswith('@nni.get_next_parameter'):
             return parse_annotation(string[1:])  # expand annotation string to code
@@ -341,7 +343,6 @@ class Transformer(ast.NodeTransformer):
         raise AssertionError('Unexpected annotation function')
     def _visit_children(self, node):
         self.stack.append(None)
         self.generic_visit(node)
...
@@ -64,7 +64,6 @@ class SearchSpaceGenerator(ast.NodeTransformer):
             'optional_input_size': args[6].n if isinstance(args[6], ast.Num) else [args[6].elts[0].n, args[6].elts[1].n]
         }
     def visit_Call(self, node):  # pylint: disable=invalid-name
         self.generic_visit(node)
@@ -108,7 +107,7 @@ class SearchSpaceGenerator(ast.NodeTransformer):
         else:
             # arguments of other functions must be literal number
             assert all(isinstance(ast.literal_eval(astor.to_source(arg)), numbers.Real) for arg in node.args), \
                 'Smart parameter\'s arguments must be number literals'
             args = [ast.literal_eval(astor.to_source(arg)) for arg in node.args]
         key = self.module_name + '/' + name + '/' + func
...
@@ -28,6 +28,7 @@ from nni_cmd.common_utils import print_warning
 para_cfg = None
 prefix_name = None
 def parse_annotation_mutable_layers(code, lineno):
     """Parse the string of mutable layers in annotation.
     Return a list of AST Expr nodes
@@ -102,6 +103,7 @@ def parse_annotation_mutable_layers(code, lineno):
         nodes.append(node)
     return nodes
 def parse_annotation(code):
     """Parse an annotation string.
     Return an AST Expr node.
@@ -182,7 +184,7 @@ def convert_args_to_dict(call, with_lambda=False):
         if type(arg) in [ast.Str, ast.Num]:
             arg_value = arg
         else:
             # if arg is not a string or a number, we use its source code as the key
             arg_value = astor.to_source(arg).strip('\n"')
             arg_value = ast.Str(str(arg_value))
         arg = make_lambda(arg) if with_lambda else arg
@@ -217,7 +219,7 @@ def test_variable_equal(node1, node2):
         if len(node1) != len(node2):
             return False
         return all(test_variable_equal(n1, n2) for n1, n2 in zip(node1, node2))
     return node1 == node2
@@ -294,7 +296,6 @@ class Transformer(ast.NodeTransformer):
         return self._visit_children(node)
     def _visit_string(self, node):
         string = node.value.s
         if string.startswith('@nni.'):
@@ -303,19 +304,27 @@ class Transformer(ast.NodeTransformer):
             return node  # not an annotation, ignore it
         if string.startswith('@nni.get_next_parameter'):
-            deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code."
+            deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. " \
+                                 "Please remove this line in the trial code."
             print_warning(deprecated_message)
-            return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='Get next parameter here...')], keywords=[]))
+            return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
+                                           args=[ast.Str(s='Get next parameter here...')], keywords=[]))
+        if string.startswith('@nni.training_update'):
+            return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
+                                           args=[ast.Str(s='Training update here...')], keywords=[]))
         if string.startswith('@nni.report_intermediate_result'):
             module = ast.parse(string[1:])
             arg = module.body[0].value.args[0]
-            return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[]))
+            return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
+                                           args=[ast.Str(s='nni.report_intermediate_result: '), arg], keywords=[]))
         if string.startswith('@nni.report_final_result'):
             module = ast.parse(string[1:])
             arg = module.body[0].value.args[0]
-            return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()), args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[]))
+            return ast.Expr(value=ast.Call(func=ast.Name(id='print', ctx=ast.Load()),
+                                           args=[ast.Str(s='nni.report_final_result: '), arg], keywords=[]))
         if string.startswith('@nni.mutable_layers'):
             return parse_annotation_mutable_layers(string[1:], node.lineno)
@@ -327,7 +336,6 @@ class Transformer(ast.NodeTransformer):
         raise AssertionError('Unexpected annotation function')
     def _visit_children(self, node):
         self.stack.append(None)
         self.generic_visit(node)
...
@@ -39,17 +39,18 @@ class AnnotationTestCase(TestCase):
         shutil.rmtree('_generated')
     def test_search_space_generator(self):
-        search_space = generate_search_space('testcase/annotated')
+        shutil.copytree('testcase/annotated', '_generated/annotated')
+        search_space = generate_search_space('_generated/annotated')
         with open('testcase/searchspace.json') as f:
             self.assertEqual(search_space, json.load(f))
     def test_code_generator(self):
-        code_dir = expand_annotations('testcase/usercode', '_generated', nas_mode='classic_mode')
-        self.assertEqual(code_dir, '_generated')
-        self._assert_source_equal('testcase/annotated/nas.py', '_generated/nas.py')
-        self._assert_source_equal('testcase/annotated/mnist.py', '_generated/mnist.py')
-        self._assert_source_equal('testcase/annotated/dir/simple.py', '_generated/dir/simple.py')
-        with open('testcase/usercode/nonpy.txt') as src, open('_generated/nonpy.txt') as dst:
+        code_dir = expand_annotations('testcase/usercode', '_generated/usercode', nas_mode='classic_mode')
+        self.assertEqual(code_dir, '_generated/usercode')
+        self._assert_source_equal('testcase/annotated/nas.py', '_generated/usercode/nas.py')
+        self._assert_source_equal('testcase/annotated/mnist.py', '_generated/usercode/mnist.py')
+        self._assert_source_equal('testcase/annotated/dir/simple.py', '_generated/usercode/dir/simple.py')
+        with open('testcase/usercode/nonpy.txt') as src, open('_generated/usercode/nonpy.txt') as dst:
             assert src.read() == dst.read()
     def test_annotation_detecting(self):
...
@@ -76,7 +76,7 @@ tuner_schema_dict = {
             'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
         },
         Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
-        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
+        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
     },
     ('Evolution'): {
         'builtinTunerName': setChoice('builtinTunerName', 'Evolution'),
@@ -85,12 +85,12 @@ tuner_schema_dict = {
             Optional('population_size'): setNumberRange('population_size', int, 0, 99999),
         },
         Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
-        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
+        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
     },
     ('BatchTuner', 'GridSearch', 'Random'): {
         'builtinTunerName': setChoice('builtinTunerName', 'BatchTuner', 'GridSearch', 'Random'),
         Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
-        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
+        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
     },
     'TPE': {
         'builtinTunerName': 'TPE',
@@ -100,7 +100,7 @@ tuner_schema_dict = {
             Optional('constant_liar_type'): setChoice('constant_liar_type', 'min', 'max', 'mean')
         },
         Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
-        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
+        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
     },
     'NetworkMorphism': {
         'builtinTunerName': 'NetworkMorphism',
@@ -112,7 +112,7 @@ tuner_schema_dict = {
             Optional('n_output_node'): setType('n_output_node', int),
         },
         Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
-        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
+        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
     },
     'MetisTuner': {
         'builtinTunerName': 'MetisTuner',
@@ -124,7 +124,7 @@ tuner_schema_dict = {
             Optional('cold_start_num'): setType('cold_start_num', int),
         },
         Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
-        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
+        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
     },
     'GPTuner': {
         'builtinTunerName': 'GPTuner',
@@ -140,7 +140,25 @@ tuner_schema_dict = {
             Optional('selection_num_starting_points'): setType('selection_num_starting_points', int),
         },
         Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
-        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
+        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
+    },
+    'PPOTuner': {
+        'builtinTunerName': 'PPOTuner',
+        'classArgs': {
+            'optimize_mode': setChoice('optimize_mode', 'maximize', 'minimize'),
+            Optional('trials_per_update'): setNumberRange('trials_per_update', int, 0, 99999),
+            Optional('epochs_per_update'): setNumberRange('epochs_per_update', int, 0, 99999),
+            Optional('minibatch_size'): setNumberRange('minibatch_size', int, 0, 99999),
+            Optional('ent_coef'): setType('ent_coef', float),
+            Optional('lr'): setType('lr', float),
+            Optional('vf_coef'): setType('vf_coef', float),
+            Optional('max_grad_norm'): setType('max_grad_norm', float),
+            Optional('gamma'): setType('gamma', float),
+            Optional('lam'): setType('lam', float),
+            Optional('cliprange'): setType('cliprange', float),
+        },
+        Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
+        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
     },
     'customized': {
         'codeDir': setPathCheck('codeDir'),
@@ -148,7 +166,7 @@ tuner_schema_dict = {
         'className': setType('className', str),
         Optional('classArgs'): dict,
         Optional('includeIntermediateResults'): setType('includeIntermediateResults', bool),
-        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
+        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
     }
 }
@@ -160,7 +178,7 @@ advisor_schema_dict = {
             Optional('R'): setType('R', int),
             Optional('eta'): setType('eta', int)
         },
-        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
+        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
     },
     'BOHB':{
         'builtinAdvisorName': Or('BOHB'),
@@ -176,14 +194,14 @@ advisor_schema_dict = {
             Optional('bandwidth_factor'): setNumberRange('bandwidth_factor', float, 0, 9999),
             Optional('min_bandwidth'): setNumberRange('min_bandwidth', float, 0, 9999),
         },
-        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
+        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
     },
     'customized':{
         'codeDir': setPathCheck('codeDir'),
         'classFileName': setType('classFileName', str),
         'className': setType('className', str),
         Optional('classArgs'): dict,
-        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
+        Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'),
     }
 }
@@ -194,7 +212,6 @@ assessor_schema_dict = {
             Optional('optimize_mode'): setChoice('optimize_mode', 'maximize', 'minimize'),
             Optional('start_step'): setNumberRange('start_step', int, 0, 9999),
         },
-        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
     },
     'Curvefitting': {
         'builtinAssessorName': 'Curvefitting',
@@ -205,14 +222,12 @@ assessor_schema_dict = {
             Optional('threshold'): setNumberRange('threshold', float, 0, 9999),
             Optional('gap'): setNumberRange('gap', int, 1, 9999),
         },
-        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
     },
     'customized': {
         'codeDir': setPathCheck('codeDir'),
         'classFileName': setType('classFileName', str),
         'className': setType('className', str),
         Optional('classArgs'): dict,
-        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999)
     }
 }
...
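For reference (not part of the commit): the recurring gpuIndices rule above accepts either a single integer or a comma-separated string of GPU indices. A minimal sketch using the schema package, whose Or/And/Optional helpers this file already relies on, is shown below; the variable name gpu_indices_rule is made up for the example.

from schema import And, Or, Schema, SchemaError

# Reproduces the gpuIndices rule added repeatedly in config_schema.py above.
gpu_indices_rule = Schema(
    Or(int,
       And(str, lambda x: len([int(i) for i in x.split(',')]) > 0),
       error='gpuIndex format error!'))

for value in (0, '1', '0,1,2'):
    print(value, '->', gpu_indices_rule.validate(value))  # accepted as-is

try:
    gpu_indices_rule.validate('a,b')  # non-numeric indices fail the lambda check
except SchemaError as err:
    print('rejected:', err)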
@@ -80,7 +80,8 @@ TRIAL_MONITOR_TAIL = '----------------------------------------------------------
 PACKAGE_REQUIREMENTS = {
     'SMAC': 'smac_tuner',
-    'BOHB': 'bohb_advisor'
+    'BOHB': 'bohb_advisor',
+    'PPOTuner': 'ppo_tuner'
 }
 TUNERS_SUPPORTING_IMPORT_DATA = {
...
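Note (not part of the commit): with this new PACKAGE_REQUIREMENTS entry, PPOTuner presumably becomes installable the same way as the SMAC and BOHB packages exercised in the pipeline files above, i.e. via nnictl package install --name=PPOTuner; the value 'ppo_tuner' appears to name the tuner's source directory.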
@@ -118,12 +118,17 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
     node_command = 'node'
     if sys.platform == 'win32':
         node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe')
-    cmds = [node_command, entry_file, '--port', str(port), '--mode', platform, '--start_mode', mode]
+    cmds = [node_command, entry_file, '--port', str(port), '--mode', platform]
+    if mode == 'view':
+        cmds += ['--start_mode', 'resume']
+        cmds += ['--readonly', 'true']
+    else:
+        cmds += ['--start_mode', mode]
     if log_dir is not None:
         cmds += ['--log_dir', log_dir]
     if log_level is not None:
         cmds += ['--log_level', log_level]
-    if mode == 'resume':
+    if mode in ['resume', 'view']:
         cmds += ['--experiment_id', experiment_id]
     stdout_full_path, stderr_full_path = get_log_path(config_file_name)
     with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
@@ -156,7 +161,6 @@ def set_trial_config(experiment_config, port, config_file_name):
 def set_local_config(experiment_config, port, config_file_name):
     '''set local configuration'''
-    #set machine_list
     request_data = dict()
     if experiment_config.get('localConfig'):
         request_data['local_config'] = experiment_config['localConfig']
@@ -177,7 +181,7 @@ def set_local_config(experiment_config, port, config_file_name):
             fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
         return False, err_message
-    return set_trial_config(experiment_config, port, config_file_name)
+    return set_trial_config(experiment_config, port, config_file_name), None
 def set_remote_config(experiment_config, port, config_file_name):
     '''Call setClusterMetadata to pass trial'''
@@ -296,10 +300,20 @@ def set_experiment(experiment_config, mode, port, config_file_name):
         request_data['multiThread'] = experiment_config.get('multiThread')
     if experiment_config.get('advisor'):
         request_data['advisor'] = experiment_config['advisor']
+        if request_data['advisor'].get('gpuNum'):
+            print_error('gpuNum is deprecated, please use gpuIndices instead.')
+        if request_data['advisor'].get('gpuIndices') and isinstance(request_data['advisor'].get('gpuIndices'), int):
+            request_data['advisor']['gpuIndices'] = str(request_data['advisor'].get('gpuIndices'))
     else:
         request_data['tuner'] = experiment_config['tuner']
+        if request_data['tuner'].get('gpuNum'):
+            print_error('gpuNum is deprecated, please use gpuIndices instead.')
+        if request_data['tuner'].get('gpuIndices') and isinstance(request_data['tuner'].get('gpuIndices'), int):
+            request_data['tuner']['gpuIndices'] = str(request_data['tuner'].get('gpuIndices'))
         if 'assessor' in experiment_config:
             request_data['assessor'] = experiment_config['assessor']
+            if request_data['assessor'].get('gpuNum'):
+                print_error('gpuNum is deprecated, please remove it from your config file.')
     #debug mode should disable version check
     if experiment_config.get('debug') is not None:
         request_data['versionCheck'] = not experiment_config.get('debug')
@@ -335,7 +349,6 @@ def set_experiment(experiment_config, mode, port, config_file_name):
             {'key': 'frameworkcontroller_config', 'value': experiment_config['frameworkcontrollerConfig']})
         request_data['clusterMetaData'].append(
             {'key': 'trial_config', 'value': experiment_config['trial']})
     response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True)
     if check_response(response):
         return response
@@ -347,6 +360,33 @@ def set_experiment(experiment_config, mode, port, config_file_name):
         print_error('Setting experiment error, error message is {}'.format(response.text))
     return None
+def set_platform_config(platform, experiment_config, port, config_file_name, rest_process):
+    '''call set_cluster_metadata for specific platform'''
+    print_normal('Setting {0} config...'.format(platform))
+    config_result, err_msg = None, None
+    if platform == 'local':
+        config_result, err_msg = set_local_config(experiment_config, port, config_file_name)
+    elif platform == 'remote':
+        config_result, err_msg = set_remote_config(experiment_config, port, config_file_name)
+    elif platform == 'pai':
+        config_result, err_msg = set_pai_config(experiment_config, port, config_file_name)
+    elif platform == 'kubeflow':
+        config_result, err_msg = set_kubeflow_config(experiment_config, port, config_file_name)
+    elif platform == 'frameworkcontroller':
+        config_result, err_msg = set_frameworkcontroller_config(experiment_config, port, config_file_name)
+    else:
+        raise Exception(ERROR_INFO % 'Unsupported platform!')
+        exit(1)
+    if config_result:
+        print_normal('Successfully set {0} config!'.format(platform))
+    else:
+        print_error('Failed! Error is: {}'.format(err_msg))
+        try:
+            kill_command(rest_process.pid)
+        except Exception:
+            raise Exception(ERROR_INFO % 'Rest server stopped!')
+        exit(1)
 def launch_experiment(args, experiment_config, mode, config_file_name, experiment_id=None):
     '''follow steps to start rest server and start experiment'''
     nni_config = Config(config_file_name)
@@ -371,8 +411,10 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
         exit(1)
     log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None
     log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None
-    if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
-        log_level = 'debug'
+    #view experiment mode do not need debug function, when view an experiment, there will be no new logs created
+    if mode != 'view':
+        if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
+            log_level = 'debug'
     # start rest server
     rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], mode, config_file_name, experiment_id, log_dir, log_level)
     nni_config.set_config('restServerPid', rest_process.pid)
@@ -406,83 +448,14 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
         except Exception:
             raise Exception(ERROR_INFO % 'Rest server stopped!')
         exit(1)
-    # set remote config
-    if experiment_config['trainingServicePlatform'] == 'remote':
-        print_normal('Setting remote config...')
-        config_result, err_msg = set_remote_config(experiment_config, args.port, config_file_name)
-        if config_result:
-            print_normal('Successfully set remote config!')
-        else:
-            print_error('Failed! Error is: {}'.format(err_msg))
-            try:
-                kill_command(rest_process.pid)
-            except Exception:
-                raise Exception(ERROR_INFO % 'Rest server stopped!')
-            exit(1)
-    # set local config
-    if experiment_config['trainingServicePlatform'] == 'local':
-        print_normal('Setting local config...')
-        if set_local_config(experiment_config, args.port, config_file_name):
-            print_normal('Successfully set local config!')
-        else:
-            print_error('Set local config failed!')
-            try:
-                kill_command(rest_process.pid)
-            except Exception:
-                raise Exception(ERROR_INFO % 'Rest server stopped!')
-            exit(1)
-    #set pai config
-    if experiment_config['trainingServicePlatform'] == 'pai':
-        print_normal('Setting pai config...')
-        config_result, err_msg = set_pai_config(experiment_config, args.port, config_file_name)
-        if config_result:
-            print_normal('Successfully set pai config!')
-        else:
-            if err_msg:
-                print_error('Failed! Error is: {}'.format(err_msg))
-            try:
-                kill_command(rest_process.pid)
-            except Exception:
-                raise Exception(ERROR_INFO % 'Restful server stopped!')
-            exit(1)
-    #set kubeflow config
-    if experiment_config['trainingServicePlatform'] == 'kubeflow':
-        print_normal('Setting kubeflow config...')
-        config_result, err_msg = set_kubeflow_config(experiment_config, args.port, config_file_name)
-        if config_result:
-            print_normal('Successfully set kubeflow config!')
-        else:
-            if err_msg:
-                print_error('Failed! Error is: {}'.format(err_msg))
-            try:
-                kill_command(rest_process.pid)
-            except Exception:
-                raise Exception(ERROR_INFO % 'Restful server stopped!')
-            exit(1)
-    #set frameworkcontroller config
-    if experiment_config['trainingServicePlatform'] == 'frameworkcontroller':
-        print_normal('Setting frameworkcontroller config...')
-        config_result, err_msg = set_frameworkcontroller_config(experiment_config, args.port, config_file_name)
-        if config_result:
-            print_normal('Successfully set frameworkcontroller config!')
-        else:
-            if err_msg:
-                print_error('Failed! Error is: {}'.format(err_msg))
-            try:
-                kill_command(rest_process.pid)
-            except Exception:
-                raise Exception(ERROR_INFO % 'Restful server stopped!')
-            exit(1)
+    if mode != 'view':
+        # set platform configuration
+        set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port, config_file_name, rest_process)
     # start a new experiment
     print_normal('Starting experiment...')
     # set debug configuration
-    if experiment_config.get('debug') is None:
+    if mode != 'view' and experiment_config.get('debug') is None:
         experiment_config['debug'] = args.debug
     response = set_experiment(experiment_config, mode, args.port, config_file_name)
     if response:
@@ -509,8 +482,23 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
     print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
-def resume_experiment(args):
-    '''resume an experiment'''
+def create_experiment(args):
+    '''start a new experiment'''
+    config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
+    nni_config = Config(config_file_name)
+    config_path = os.path.abspath(args.config)
+    if not os.path.exists(config_path):
+        print_error('Please set correct config path!')
+        exit(1)
+    experiment_config = get_yml_content(config_path)
+    validate_all_content(experiment_config, config_path)
+    nni_config.set_config('experimentConfig', experiment_config)
+    launch_experiment(args, experiment_config, 'new', config_file_name)
+    nni_config.set_config('restServerPort', args.port)
+def manage_stopped_experiment(args, mode):
+    '''view a stopped experiment'''
     update_experiment()
     experiment_config = Experiments()
     experiment_dict = experiment_config.get_all_experiments()
@@ -518,38 +506,31 @@ def resume_experiment(args):
     experiment_endTime = None
     #find the latest stopped experiment
     if not args.id:
-        print_error('Please set experiment id! \nYou could use \'nnictl resume {id}\' to resume a stopped experiment!\n' \
-                    'You could use \'nnictl experiment list --all\' to show all experiments!')
+        print_error('Please set experiment id! \nYou could use \'nnictl {0} {id}\' to {0} a stopped experiment!\n' \
+                    'You could use \'nnictl experiment list --all\' to show all experiments!'.format(mode))
         exit(1)
     else:
         if experiment_dict.get(args.id) is None:
             print_error('Id %s not exist!' % args.id)
             exit(1)
         if experiment_dict[args.id]['status'] != 'STOPPED':
-            print_error('Only stopped experiments can be resumed!')
+            print_error('Only stopped experiments can be {0}ed!'.format(mode))
            exit(1)
         experiment_id = args.id
-    print_normal('Resuming experiment %s...' % experiment_id)
+    print_normal('{0} experiment {1}...'.format(mode, experiment_id))
     nni_config = Config(experiment_dict[experiment_id]['fileName'])
     experiment_config = nni_config.get_config('experimentConfig')
     experiment_id = nni_config.get_config('experimentId')
     new_config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
     new_nni_config = Config(new_config_file_name)
     new_nni_config.set_config('experimentConfig', experiment_config)
-    launch_experiment(args, experiment_config, 'resume', new_config_file_name, experiment_id)
+    launch_experiment(args, experiment_config, mode, new_config_file_name, experiment_id)
     new_nni_config.set_config('restServerPort', args.port)
-def create_experiment(args):
-    '''start a new experiment'''
-    config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
-    nni_config = Config(config_file_name)
-    config_path = os.path.abspath(args.config)
-    if not os.path.exists(config_path):
-        print_error('Please set correct config path!')
-        exit(1)
-    experiment_config = get_yml_content(config_path)
-    validate_all_content(experiment_config, config_path)
-    nni_config.set_config('experimentConfig', experiment_config)
-    launch_experiment(args, experiment_config, 'new', config_file_name)
-    nni_config.set_config('restServerPort', args.port)
+def view_experiment(args):
+    '''view a stopped experiment'''
+    manage_stopped_experiment(args, 'view')
+def resume_experiment(args):
+    '''resume an experiment'''
+    manage_stopped_experiment(args, 'resume')
\ No newline at end of file
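Taken together, these launcher changes appear to route the new 'view' mode through the resume path: start_rest_server launches the NNI manager with '--start_mode resume' plus '--readonly true', platform configuration and debug handling are skipped for view mode, and both resume_experiment and view_experiment now delegate to the shared manage_stopped_experiment helper.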
@@ -21,7 +21,7 @@
 import argparse
 import pkg_resources
-from .launcher import create_experiment, resume_experiment
+from .launcher import create_experiment, resume_experiment, view_experiment
 from .updater import update_searchspace, update_concurrency, update_duration, update_trialnum, import_data
 from .nnictl_utils import *
 from .package_management import *
@@ -66,6 +66,12 @@ def parse_args():
     parser_resume.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
     parser_resume.set_defaults(func=resume_experiment)
+    # parse view command
+    parser_resume = subparsers.add_parser('view', help='view a stopped experiment')
+    parser_resume.add_argument('id', nargs='?', help='The id of the experiment you want to view')
+    parser_resume.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server')
+    parser_resume.set_defaults(func=view_experiment)
     # parse update command
     parser_updater = subparsers.add_parser('update', help='update the experiment')
     #add subparsers for parser_updater
...
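Based on the parser additions above, the new subcommand should be usable roughly as 'nnictl view <experiment-id>' (optionally with '--port <rest-port>'), mirroring 'nnictl resume' for stopped experiments.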
@@ -351,6 +351,7 @@ def log_stderr(args):
 def log_trial(args):
     ''''get trial log path'''
     trial_id_path_dict = {}
+    trial_id_list = []
     nni_config = Config(get_config_filename(args))
     rest_port = nni_config.get_config('restServerPort')
     rest_pid = nni_config.get_config('restServerPid')
@@ -363,23 +364,27 @@ def log_trial(args):
     if response and check_response(response):
         content = json.loads(response.text)
         for trial in content:
-            trial_id_path_dict[trial['id']] = trial['logPath']
+            trial_id_list.append(trial.get('id'))
+            if trial.get('logPath'):
+                trial_id_path_dict[trial.get('id')] = trial['logPath']
     else:
         print_error('Restful server is not running...')
         exit(1)
-    if args.id:
-        if args.trial_id:
-            if trial_id_path_dict.get(args.trial_id):
-                print_normal('id:' + args.trial_id + ' path:' + trial_id_path_dict[args.trial_id])
-            else:
-                print_error('trial id is not valid.')
-                exit(1)
+    if args.trial_id:
+        if args.trial_id not in trial_id_list:
+            print_error('Trial id {0} not correct, please check your command!'.format(args.trial_id))
+            exit(1)
+        if trial_id_path_dict.get(args.trial_id):
+            print_normal('id:' + args.trial_id + ' path:' + trial_id_path_dict[args.trial_id])
         else:
-            print_error('please specific the trial id.')
+            print_error('Log path is not available yet, please wait...')
             exit(1)
     else:
+        print_normal('All of trial log info:')
         for key in trial_id_path_dict:
-            print('id:' + key + ' path:' + trial_id_path_dict[key])
+            print_normal('id:' + key + ' path:' + trial_id_path_dict[key])
+        if not trial_id_path_dict:
+            print_normal('None')
 def get_config(args):
     '''get config info'''
...
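The reworked log_trial path suggests usage along the lines of 'nnictl log trial --trial_id <trial-id>' to print one trial's log path, or 'nnictl log trial' without a trial id to list every known log path; the exact flag names come from the existing argument parser, which is not shown in this diff.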