"driver/src/conv_driver.cpp" did not exist on "9657baec325227d0d64424bffb394afbd6d37a60"
Unverified Commit f04d423a authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Support hybrid and V2 config pipeline (#3648)

parent 35c3d169
......@@ -87,6 +87,26 @@ remote:
port:
username:
trainingServicePlatform: remote
hybrid:
maxExecDuration: 15m
nniManagerIp:
maxTrialNum: 2
trialConcurrency: 2
trial:
gpuNum: 0
trainingServicePlatform: hybrid
hybridConfig:
# TODO: Add more platforms
trainingServicePlatforms:
- remote
- local
machineList:
- ip:
passwd:
port:
username:
remoteConfig:
reuse: true
adl:
maxExecDuration: 15m
nniManagerIp:
......
hybrid:
trainingService:
- platform: remote
machineList:
- host:
user:
password:
port:
- platform: local
\ No newline at end of file
experimentName: default_test
searchSpaceFile: seach_space_classic_nas.json
trialCommand: python3 mnist.py --epochs 1
trialCodeDirectory: ../../../examples/nas/legacy/classic_nas
trialGpuNumber: 0
trialConcurrency: 1
maxExperimentDuration: 15m
maxTrialNumber: 1
tuner:
name: RegularizedEvolutionTuner
classArgs:
optimize_mode: maximize
trainingService:
platform: local
......@@ -8,10 +8,11 @@ import argparse
from utils import get_yml_content, dump_yml_content
TRAINING_SERVICE_FILE = os.path.join('config', 'training_service.yml')
TRAINING_SERVICE_FILE_V2 = os.path.join('config', 'training_service_v2.yml')
def update_training_service_config(args):
config = get_yml_content(TRAINING_SERVICE_FILE)
if args.nni_manager_ip is not None:
if args.nni_manager_ip is not None and args.config_version == 'v1':
config[args.ts]['nniManagerIp'] = args.nni_manager_ip
if args.ts == 'pai':
if args.pai_user is not None:
......@@ -99,13 +100,22 @@ def update_training_service_config(args):
config[args.ts]['amlConfig']['workspaceName'] = args.workspace_name
if args.compute_target is not None:
config[args.ts]['amlConfig']['computeTarget'] = args.compute_target
dump_yml_content(TRAINING_SERVICE_FILE, config)
if args.ts == 'hybrid':
config = get_yml_content(TRAINING_SERVICE_FILE_V2)
config[args.ts]['trainingService'][0]['machineList'][0]['user'] = args.remote_user
config[args.ts]['trainingService'][0]['machineList'][0]['host'] = args.remote_host
config[args.ts]['trainingService'][0]['machineList'][0]['password'] = args.remote_pwd
config[args.ts]['trainingService'][0]['machineList'][0]['port'] = args.remote_port
config[args.ts]['nni_manager_ip'] = args.nni_manager_ip
dump_yml_content(TRAINING_SERVICE_FILE_V2, config)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--ts", type=str, choices=['pai', 'kubeflow', 'remote', 'local', 'frameworkcontroller', 'adl', 'aml'], default='pai')
parser.add_argument("--ts", type=str, choices=['pai', 'kubeflow', 'remote', 'local', 'frameworkcontroller', 'adl', 'aml', 'hybrid'], default='pai')
parser.add_argument("--config_version", type=str, choices=['v1', 'v2'], default='v1')
parser.add_argument("--nni_docker_image", type=str)
parser.add_argument("--nni_manager_ip", type=str)
# args for PAI
......
......@@ -52,8 +52,11 @@ def update_training_service_config(config, training_service, config_file_path):
containerCodeDir = config['trial']['codeDir'].replace('../../../', '/')
it_ts_config[training_service]['trial']['codeDir'] = containerCodeDir
it_ts_config[training_service]['trial']['command'] = 'cd {0} && {1}'.format(containerCodeDir, config['trial']['command'])
deep_update(config, it_ts_config['all'])
if training_service == 'hybrid':
it_ts_config = get_yml_content(os.path.join('config', 'training_service_v2.yml'))
else:
deep_update(config, it_ts_config['all'])
deep_update(config, it_ts_config[training_service])
......@@ -123,7 +126,10 @@ def invoke_validator(test_case_config, nni_source_dir, training_service):
def get_max_values(config_file):
    """Return the (max duration, max trial number) pair declared in an experiment config.

    The YAML file at ``config_file`` is loaded with ``get_yml_content``. Two key
    spellings are supported:

    * ``maxExecDuration`` / ``maxTrialNum``
    * ``maxExperimentDuration`` / ``maxTrialNumber``

    The duration string is converted by ``parse_max_duration_time``; the trial
    count is returned exactly as stored in the file.

    Raises ``KeyError`` if neither spelling of the keys is present.
    """
    experiment_config = get_yml_content(config_file)
    # Previously an unconditional return on the 'maxExecDuration' keys sat here,
    # which made the fallback below unreachable and raised KeyError for configs
    # that only define 'maxExperimentDuration' / 'maxTrialNumber'.
    if experiment_config.get('maxExecDuration'):
        return parse_max_duration_time(experiment_config['maxExecDuration']), experiment_config['maxTrialNum']
    return parse_max_duration_time(experiment_config['maxExperimentDuration']), experiment_config['maxTrialNumber']
def get_command(test_case_config, commandKey):
......@@ -259,7 +265,7 @@ def run(args):
name, args.ts, test_case_config['trainingService']))
continue
# remote mode need more time to cleanup
if args.ts == 'remote':
if args.ts == 'remote' or args.ts == 'hybrid':
wait_for_port_available(8080, 240)
else:
wait_for_port_available(8080, 60)
......@@ -281,7 +287,7 @@ if __name__ == '__main__':
parser.add_argument("--cases", type=str, default=None)
parser.add_argument("--exclude", type=str, default=None)
parser.add_argument("--ts", type=str, choices=['local', 'remote', 'pai',
'kubeflow', 'frameworkcontroller', 'adl', 'aml'], default='local')
'kubeflow', 'frameworkcontroller', 'adl', 'aml', 'hybrid'], default='local')
args = parser.parse_args()
run(args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment