Unverified Commit a5764016 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Install builtin tuners (#2439)

parent 0f7f9460
......@@ -9,8 +9,8 @@ from subprocess import call, Popen
from .rest_utils import rest_get, check_rest_server_quick, check_response
from .config_utils import Config, Experiments
from .url_utils import trial_jobs_url, get_local_urls
from .constants import COLOR_GREEN_FORMAT, REST_TIME_OUT
from .common_utils import print_normal, print_error, detect_process, detect_port, check_tensorboard_version
from .constants import REST_TIME_OUT
from .common_utils import print_normal, print_error, print_green, detect_process, detect_port, check_tensorboard_version
from .nnictl_utils import check_experiment_id, check_experiment_id
from .ssh_utils import create_ssh_sftp_client, copy_remote_directory_to_local
......@@ -81,7 +81,8 @@ def start_tensorboard_process(args, nni_config, path_list, temp_nni_path):
cmds = ['tensorboard', log_dir_cmd, format_tensorboard_log_path(path_list), '--port', str(args.port)]
tensorboard_process = Popen(cmds, stdout=stdout_file, stderr=stderr_file)
url_list = get_local_urls(args.port)
print_normal(COLOR_GREEN_FORMAT % 'Start tensorboard success!\n' + 'Tensorboard urls: ' + ' '.join(url_list))
print_green('Start tensorboard success!')
print_normal('Tensorboard urls: ' + ' '.join(url_list))
tensorboard_process_pid_list = nni_config.get_config('tensorboardPidList')
if tensorboard_process_pid_list is None:
tensorboard_process_pid_list = [tensorboard_process.pid]
......
authorName: nni
experimentName: default_test
maxExecDuration: 15m
maxTrialNum: 2
trialConcurrency: 2
searchSpacePath: ./search_space.json
# error: no className
tuner:
codeDir: ./
classFileName: mytuner.py
assessor:
builtinAssessorName: Medianstop
classArgs:
optimize_mode: maximize
trial:
codeDir: ./
command: python3 main.py
useAnnotation: false
multiPhase: false
multiThread: false
trainingServicePlatform: local
authorName: nni
experimentName: default_test
maxExecDuration: 15m
maxTrialNum: 2
trialConcurrency: 2
searchSpacePath: ./search_space.json
# error: builtinTunerName conflicts with custom tuner settings
tuner:
codeDir: ./
classFileName: mytuner.py
className: MyTuner
builtinTunerName: Random
assessor:
builtinAssessorName: Medianstop
classArgs:
optimize_mode: maximize
trial:
codeDir: ./
command: python3 main.py
useAnnotation: false
multiPhase: false
multiThread: false
trainingServicePlatform: local
from nni import Tuner
class MyTuner(Tuner):
def __init__(self):
pass
authorName: nni
experimentName: default_test
maxExecDuration: 15m
maxTrialNum: 2
trialConcurrency: 2
searchSpacePath: ./search_space.json
# error: no tuner or advisor
assessor:
builtinAssessorName: Medianstop
classArgs:
optimize_mode: maximize
trial:
codeDir: ./
command: python3 main.py
useAnnotation: false
multiPhase: false
multiThread: false
trainingServicePlatform: local
{
"batch_size": {"_type":"choice", "_value": [16, 32, 64, 128]},
"hidden_size":{"_type":"choice","_value":[128, 256, 512, 1024]},
"lr":{"_type":"choice","_value":[0.0001, 0.001, 0.01, 0.1]},
"momentum":{"_type":"uniform","_value":[0, 1]}
}
authorName: nni
experimentName: default_test
maxExecDuration: 15m
maxTrialNum: 2
trialConcurrency: 2
# error: searchSpacePath can not be found
searchSpacePath: ./wrong_search_space.json
tuner:
builtinTunerName: Random
assessor:
builtinAssessorName: Medianstop
classArgs:
optimize_mode: maximize
trial:
codeDir: ./
command: python3 mnist.py --epochs 1 --batch_num 10
useAnnotation: false
multiPhase: false
multiThread: false
trainingServicePlatform: local
authorName: nni
experimentName: default_test
maxExecDuration: 15m
maxTrialNum: 2
trialConcurrency: 2
searchSpacePath: ./search_space.json
tuner:
# error: wrong key
wrongTunerKey: abc
assessor:
builtinAssessorName: Medianstop
classArgs:
optimize_mode: maximize
trial:
codeDir: ./
command: python3 main.py
useAnnotation: false
multiPhase: false
multiThread: false
trainingServicePlatform: local
authorName: nni
experimentName: default_test
maxExecDuration: 15m
maxTrialNum: 2
trialConcurrency: 2
searchSpacePath: ./search_space.json
tuner:
builtinTunerName: Random
assessor:
builtinAssessorName: Medianstop
classArgs:
# wrong class args, should be detected by assessor validator
optimize_mode: aaaaaa
trial:
codeDir: ./
command: python3 main.py
useAnnotation: false
multiPhase: false
multiThread: false
trainingServicePlatform: local
authorName: nni
experimentName: default_test
maxExecDuration: 15m
maxTrialNum: 2
trialConcurrency: 2
searchSpacePath: ./search_space.json
tuner:
builtinTunerName: Random
assessor:
builtinAssessorName: Medianstop
classArgs:
optimize_mode: maximize
trial:
codeDir: ./
command: python3 main.py
useAnnotation: false
multiPhase: false
multiThread: false
# error: wrong training service name
trainingServicePlatform: local222
{
"batch_size": {"_type":"choice", "_value": [16, 32, 64, 128]},
"hidden_size":{"_type":"choice","_value":[128, 256, 512, 1024]},
"lr":{"_type":"choice","_value":[0.0001, 0.001, 0.01, 0.1]},
"momentum":{"_type":"uniform","_value":[0, 1]}
}
authorName: nni
experimentName: default_test
maxExecDuration: 15m
maxTrialNum: 2
trialConcurrency: 2
searchSpacePath: ./search_space.json
tuner:
builtinTunerName: Random
assessor:
builtinAssessorName: Medianstop
classArgs:
optimize_mode: maximize
trial:
codeDir: ./
command: python3 main.py
useAnnotation: false
multiPhase: false
multiThread: false
trainingServicePlatform: local
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import glob
from unittest import TestCase, main
from schema import SchemaError
from nni_cmd.launcher_utils import validate_all_content
from nni_cmd.nnictl_utils import get_yml_content
from nni_cmd.common_utils import print_error, print_green
class ConfigValidationTestCase(TestCase):
def test_valid_config(self):
file_names = glob.glob('./config_files/valid/*.yml')
for fn in file_names:
experiment_config = get_yml_content(fn)
validate_all_content(experiment_config, fn)
print_green('config file:', fn, 'validation success!')
def test_invalid_config(self):
file_names = glob.glob('./config_files/invalid/*.yml')
for fn in file_names:
experiment_config = get_yml_content(fn)
try:
validate_all_content(experiment_config, fn)
print_error('config file:', fn,'Schema error should be raised for invalid config file!')
assert False
except SchemaError as e:
print_green('config file:', fn, 'Expected error catched:', e)
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment