Unverified Commit 67e23897 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Fix config v2 bugs (#3540)

parent 1418a366
...@@ -123,7 +123,7 @@ class Experiments: ...@@ -123,7 +123,7 @@ class Experiments:
self.experiments[expId]['tag'] = tag self.experiments[expId]['tag'] = tag
self.experiments[expId]['pid'] = pid self.experiments[expId]['pid'] = pid
self.experiments[expId]['webuiUrl'] = webuiUrl self.experiments[expId]['webuiUrl'] = webuiUrl
self.experiments[expId]['logDir'] = logDir self.experiments[expId]['logDir'] = str(logDir)
self.write_file() self.write_file()
def update_experiment(self, expId, key, value): def update_experiment(self, expId, key, value):
......
...@@ -411,6 +411,21 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi ...@@ -411,6 +411,21 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
kill_command(rest_process.pid) kill_command(rest_process.pid)
print_normal('Stopping experiment...') print_normal('Stopping experiment...')
def _validate_v1(config, path):
    '''Validate a V1-format experiment config; terminate the process on failure.

    config: the parsed config content, passed through to validate_all_content.
    path: the path of the config file, passed through to validate_all_content.

    Returns None when validation succeeds. Any exception raised by the
    validator is reported with print_error and the process exits with status 1.
    '''
    try:
        validate_all_content(config, path)
    except Exception as e:
        message = f'Config V1 validation failed: {repr(e)}'
        print_error(message)
        exit(1)
    else:
        return None
def _validate_v2(config, path):
    '''Validate a V2-format experiment config and return it in JSON form.

    config: the parsed config content; its items are passed to ExperimentConfig
        as keyword arguments.
    path: the path of the config file; its parent directory is handed to
        ExperimentConfig as `_base_path`.

    Returns the result of `ExperimentConfig.json()` on success.
    On any validation error the error is reported with print_error and the
    process exits with status 1, matching _validate_v1. Without the exit the
    function would fall through and return None, and the caller would go on to
    launch an experiment with no config.
    '''
    base_path = Path(path).parent
    try:
        conf = ExperimentConfig(_base_path=base_path, **config)
        return conf.json()
    except Exception as e:
        print_error(f'Config V2 validation failed: {repr(e)}')
        exit(1)
def create_experiment(args): def create_experiment(args):
'''start a new experiment''' '''start a new experiment'''
experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8)) experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8))
...@@ -420,23 +435,23 @@ def create_experiment(args): ...@@ -420,23 +435,23 @@ def create_experiment(args):
exit(1) exit(1)
config_yml = get_yml_content(config_path) config_yml = get_yml_content(config_path)
try: if 'trainingServicePlatform' in config_yml:
config = ExperimentConfig(_base_path=Path(config_path).parent, **config_yml) _validate_v1(config_yml, config_path)
config_v2 = config.json() platform = config_yml['trainingServicePlatform']
except Exception as error_v2: if platform in k8s_training_services:
print_warning('Validation with V2 schema failed. Trying to convert from V1 format...') schema = 1
try: config_v1 = config_yml
validate_all_content(config_yml, config_path) else:
except Exception as error_v1: schema = 2
print_error(f'Convert from v1 format failed: {repr(error_v1)}') from nni.experiment.config import convert
print_error(f'Config in v2 format validation failed: {repr(error_v2)}') config_v2 = convert.to_v2(config_yml).json()
exit(1) else:
from nni.experiment.config import convert config_v2 = _validate_v2(config_yml, config_path)
config_v2 = convert.to_v2(config_yml).json() schema = 2
try: try:
if getattr(config_v2['trainingService'], 'platform', None) in k8s_training_services: if schema == 1:
launch_experiment(args, config_yml, 'new', experiment_id, 1) launch_experiment(args, config_v1, 'new', experiment_id, 1)
else: else:
launch_experiment(args, config_v2, 'new', experiment_id, 2) launch_experiment(args, config_v2, 'new', experiment_id, 2)
except Exception as exception: except Exception as exception:
......
...@@ -13,7 +13,6 @@ from functools import cmp_to_key ...@@ -13,7 +13,6 @@ from functools import cmp_to_key
import traceback import traceback
from datetime import datetime, timezone from datetime import datetime, timezone
from subprocess import Popen from subprocess import Popen
from pyhdfs import HdfsClient
from nni.tools.annotation import expand_annotations from nni.tools.annotation import expand_annotations
import nni_node # pylint: disable=import-error import nni_node # pylint: disable=import-error
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
...@@ -501,30 +500,6 @@ def remote_clean(machine_list, experiment_id=None): ...@@ -501,30 +500,6 @@ def remote_clean(machine_list, experiment_id=None):
print_normal('removing folder {0}'.format(host + ':' + str(port) + remote_dir)) print_normal('removing folder {0}'.format(host + ':' + str(port) + remote_dir))
remove_remote_directory(sftp, remote_dir) remove_remote_directory(sftp, remote_dir)
def hdfs_clean(host, user_name, output_dir, experiment_id=None):
    '''Clean up hdfs data.

    host: HDFS host; the WebHDFS client connects to `<host>:80`.
    user_name: HDFS user; experiment data lives under `/<user_name>/nni/experiments`.
    output_dir: optional `hdfs://<ipv4>[:port][/path]` URI of the trial output
        directory. It is only cleaned when its host equals `host`.
    experiment_id: when given, only that experiment's folders are removed;
        otherwise the whole experiments folder (and the whole output dir) is removed.
    '''
    hdfs_client = HdfsClient(hosts='{0}:80'.format(host), user_name=user_name, webhdfs_path='/webhdfs/api/v1', timeout=5)
    path_parts = [user_name, 'nni', 'experiments']
    if experiment_id:
        path_parts.append(experiment_id)
    full_path = '/' + '/'.join(path_parts)
    print_normal('removing folder {0} in hdfs'.format(full_path))
    hdfs_client.delete(full_path, recursive=True)
    if not output_dir:
        return
    # Raw string with escaped dots: the unescaped `.` accepted any character
    # between the octets of the IPv4 address.
    pattern = re.compile(r'hdfs://(?P<host>([0-9]{1,3}\.){3}[0-9]{1,3})(:[0-9]{2,5})?(?P<baseDir>/.*)?')
    match_result = pattern.match(output_dir)
    if not match_result:
        return
    output_host = match_result.group('host')
    base_dir = match_result.group('baseDir')
    # check if the host is valid
    if output_host != host:
        # Report the URI the user supplied, not the extracted path, so the
        # message actually shows the mismatching host.
        print_warning('The host in {0} is not consistent with {1}'.format(output_dir, host))
        return
    if base_dir is None:
        # The URI carries no path component, so there is no directory to remove.
        print_warning('No directory found in {0}, skip cleaning output directory'.format(output_dir))
        return
    if experiment_id:
        base_dir = base_dir + '/' + experiment_id
    print_normal('removing folder {0} in hdfs'.format(base_dir))
    hdfs_client.delete(base_dir, recursive=True)
def experiment_clean(args): def experiment_clean(args):
'''clean up the experiment data''' '''clean up the experiment data'''
experiment_id_list = [] experiment_id_list = []
...@@ -556,11 +531,6 @@ def experiment_clean(args): ...@@ -556,11 +531,6 @@ def experiment_clean(args):
if platform == 'remote': if platform == 'remote':
machine_list = experiment_config.get('machineList') machine_list = experiment_config.get('machineList')
remote_clean(machine_list, experiment_id) remote_clean(machine_list, experiment_id)
elif platform == 'pai':
host = experiment_config.get('paiConfig').get('host')
user_name = experiment_config.get('paiConfig').get('userName')
output_dir = experiment_config.get('trial').get('outputDir')
hdfs_clean(host, user_name, output_dir, experiment_id)
elif platform != 'local': elif platform != 'local':
# TODO: support all platforms # TODO: support all platforms
print_warning('platform {0} clean up not supported yet.'.format(platform)) print_warning('platform {0} clean up not supported yet.'.format(platform))
...@@ -632,11 +602,6 @@ def platform_clean(args): ...@@ -632,11 +602,6 @@ def platform_clean(args):
if platform == 'remote': if platform == 'remote':
machine_list = config_content.get('machineList') machine_list = config_content.get('machineList')
remote_clean(machine_list) remote_clean(machine_list)
elif platform == 'pai':
host = config_content.get('paiConfig').get('host')
user_name = config_content.get('paiConfig').get('userName')
output_dir = config_content.get('trial').get('outputDir')
hdfs_clean(host, user_name, output_dir)
print_normal('Done.') print_normal('Done.')
def experiment_list(args): def experiment_list(args):
......
...@@ -254,12 +254,15 @@ class NNIManager implements Manager { ...@@ -254,12 +254,15 @@ class NNIManager implements Manager {
return this.dataStore.getTrialJob(trialJobId); return this.dataStore.getTrialJob(trialJobId);
} }
public async setClusterMetadata(_key: string, _value: string): Promise<void> { public async setClusterMetadata(key: string, value: string): Promise<void> {
throw new Error('Calling removed API setClusterMetadata'); while (this.trainingService === undefined) {
await delay(1000);
}
this.trainingService.setClusterMetadata(key, value);
} }
public getClusterMetadata(_key: string): Promise<string> { public getClusterMetadata(key: string): Promise<string> {
throw new Error('Calling removed API getClusterMetadata'); return this.trainingService.getClusterMetadata(key);
} }
public async getTrialJobStatistics(): Promise<TrialJobStatistics[]> { public async getTrialJobStatistics(): Promise<TrialJobStatistics[]> {
......
...@@ -128,6 +128,10 @@ export class EnvironmentInformation { ...@@ -128,6 +128,10 @@ export class EnvironmentInformation {
export abstract class EnvironmentService { export abstract class EnvironmentService {
/**
 * Initialization hook of an environment service.
 * The base implementation does nothing and resolves immediately.
 */
public async init(): Promise<void> {
    // intentionally empty
}
public abstract get hasStorageService(): boolean; public abstract get hasStorageService(): boolean;
public abstract refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void>; public abstract refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void>;
public abstract stopEnvironment(environment: EnvironmentInformation): Promise<void>; public abstract stopEnvironment(environment: EnvironmentInformation): Promise<void>;
......
...@@ -27,7 +27,7 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -27,7 +27,7 @@ export class RemoteEnvironmentService extends EnvironmentService {
private readonly environmentExecutorManagerMap: Map<string, ExecutorManager>; private readonly environmentExecutorManagerMap: Map<string, ExecutorManager>;
private readonly remoteMachineMetaOccupiedMap: Map<RemoteMachineConfig, boolean>; private readonly remoteMachineMetaOccupiedMap: Map<RemoteMachineConfig, boolean>;
private readonly log: Logger; private readonly log: Logger;
private sshConnectionPromises: any[]; private sshConnectionPromises: Promise<void[]>;
private experimentRootDir: string; private experimentRootDir: string;
private remoteExperimentRootDir: string = ""; private remoteExperimentRootDir: string = "";
private experimentId: string; private experimentId: string;
...@@ -39,7 +39,6 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -39,7 +39,6 @@ export class RemoteEnvironmentService extends EnvironmentService {
this.environmentExecutorManagerMap = new Map<string, ExecutorManager>(); this.environmentExecutorManagerMap = new Map<string, ExecutorManager>();
this.machineExecutorManagerMap = new Map<RemoteMachineConfig, ExecutorManager>(); this.machineExecutorManagerMap = new Map<RemoteMachineConfig, ExecutorManager>();
this.remoteMachineMetaOccupiedMap = new Map<RemoteMachineConfig, boolean>(); this.remoteMachineMetaOccupiedMap = new Map<RemoteMachineConfig, boolean>();
this.sshConnectionPromises = [];
this.experimentRootDir = getExperimentRootDir(); this.experimentRootDir = getExperimentRootDir();
this.experimentId = getExperimentId(); this.experimentId = getExperimentId();
this.log = getLogger(); this.log = getLogger();
...@@ -50,9 +49,18 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -50,9 +49,18 @@ export class RemoteEnvironmentService extends EnvironmentService {
throw new Error(`codeDir ${this.config.trialCodeDirectory} is not a directory`); throw new Error(`codeDir ${this.config.trialCodeDirectory} is not a directory`);
} }
this.sshConnectionPromises = this.config.machineList.map( this.sshConnectionPromises = Promise.all(this.config.machineList.map(
machine => this.initRemoteMachineOnConnected(machine) machine => this.initRemoteMachineOnConnected(machine)
); ));
}
/**
 * Wait until the promise stored in sshConnectionPromises settles, then
 * register every machine found in machineExecutorManagerMap as not occupied.
 */
public async init(): Promise<void> {
    await this.sshConnectionPromises;
    this.log.info('ssh connection initialized!');
    const machines = Array.from(this.machineExecutorManagerMap.keys());
    for (const machine of machines) {
        // initialize remoteMachineMetaOccupiedMap, false means not occupied
        this.remoteMachineMetaOccupiedMap.set(machine, false);
    }
}
public get prefetchedEnvironmentCount(): number { public get prefetchedEnvironmentCount(): number {
...@@ -204,16 +212,6 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -204,16 +212,6 @@ export class RemoteEnvironmentService extends EnvironmentService {
} }
public async startEnvironment(environment: EnvironmentInformation): Promise<void> { public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
if (this.sshConnectionPromises.length > 0) {
await Promise.all(this.sshConnectionPromises);
this.log.info('ssh connection initialized!');
// set sshConnectionPromises to [] to avoid log information duplicated
this.sshConnectionPromises = [];
Array.from(this.machineExecutorManagerMap.keys()).forEach(rmMeta => {
// initialize remoteMachineMetaOccupiedMap, false means not occupied
this.remoteMachineMetaOccupiedMap.set(rmMeta, false);
});
}
const remoteEnvironment: RemoteMachineEnvironmentInformation = environment as RemoteMachineEnvironmentInformation; const remoteEnvironment: RemoteMachineEnvironmentInformation = environment as RemoteMachineEnvironmentInformation;
remoteEnvironment.status = 'WAITING'; remoteEnvironment.status = 'WAITING';
// schedule machine for environment, generate command // schedule machine for environment, generate command
......
...@@ -122,7 +122,6 @@ class TrialDispatcher implements TrainingService { ...@@ -122,7 +122,6 @@ class TrialDispatcher implements TrainingService {
this.environmentServiceList.push(env); this.environmentServiceList.push(env);
} }
// FIXME: max?
this.environmentMaintenceLoopInterval = Math.max( this.environmentMaintenceLoopInterval = Math.max(
...this.environmentServiceList.map((env) => env.environmentMaintenceLoopInterval) ...this.environmentServiceList.map((env) => env.environmentMaintenceLoopInterval)
); );
...@@ -211,6 +210,7 @@ class TrialDispatcher implements TrainingService { ...@@ -211,6 +210,7 @@ class TrialDispatcher implements TrainingService {
} }
public async run(): Promise<void> { public async run(): Promise<void> {
await Promise.all(this.environmentServiceList.map(env => env.init()));
for(const environmentService of this.environmentServiceList) { for(const environmentService of this.environmentServiceList) {
const runnerSettings: RunnerSettings = new RunnerSettings(); const runnerSettings: RunnerSettings = new RunnerSettings();
...@@ -497,9 +497,10 @@ class TrialDispatcher implements TrainingService { ...@@ -497,9 +497,10 @@ class TrialDispatcher implements TrainingService {
liveEnvironmentsCount++; liveEnvironmentsCount++;
if (environment.status === "RUNNING" && environment.isRunnerReady) { if (environment.status === "RUNNING" && environment.isRunnerReady) {
// if environment is not reusable and used, stop and not count as idle; // if environment is not reusable and used, stop and not count as idle;
const reuseMode = Array.isArray(this.config.trainingService) || (this.config.trainingService as any).reuseMode;
if ( if (
0 === environment.runningTrialCount && 0 === environment.runningTrialCount &&
!(this.config as any).reuseMode && !reuseMode &&
environment.assignedTrialCount > 0 environment.assignedTrialCount > 0
) { ) {
if (environment.environmentService === undefined) { if (environment.environmentService === undefined) {
......
...@@ -101,13 +101,7 @@ export const EditExperimentParam = (): any => { ...@@ -101,13 +101,7 @@ export const EditExperimentParam = (): any => {
} }
if (isMaxDuration) { if (isMaxDuration) {
const maxDura = JSON.parse(editInputVal); const maxDura = JSON.parse(editInputVal);
if (unit === 'm') { newProfile.params[field] = `${maxDura}${unit}`;
newProfile.params[field] = maxDura * 60;
} else if (unit === 'h') {
newProfile.params[field] = maxDura * 3600;
} else {
newProfile.params[field] = maxDura * 24 * 60 * 60;
}
} else { } else {
newProfile.params[field] = parseInt(editInputVal, 10); newProfile.params[field] = parseInt(editInputVal, 10);
} }
...@@ -162,7 +156,7 @@ export const EditExperimentParam = (): any => { ...@@ -162,7 +156,7 @@ export const EditExperimentParam = (): any => {
<EditExpeParamContext.Consumer> <EditExpeParamContext.Consumer>
{(value): React.ReactNode => { {(value): React.ReactNode => {
let editClassName = ''; let editClassName = '';
if (value.field === 'maxExecDuration') { if (value.field === 'maxExperimentDuration') {
editClassName = isShowPencil ? 'noEditDuration' : 'editDuration'; editClassName = isShowPencil ? 'noEditDuration' : 'editDuration';
} }
return ( return (
......
...@@ -50,7 +50,7 @@ export const ExpDuration = (): any => ( ...@@ -50,7 +50,7 @@ export const ExpDuration = (): any => (
<EditExpeParamContext.Provider <EditExpeParamContext.Provider
value={{ value={{
editType: CONTROLTYPE[0], editType: CONTROLTYPE[0],
field: 'maxExecDuration', field: 'maxExperimentDuration',
title: 'Max duration', title: 'Max duration',
maxExecDuration: maxExecDurationStr, maxExecDuration: maxExecDurationStr,
maxTrialNum: EXPERIMENT.maxTrialNumber, maxTrialNum: EXPERIMENT.maxTrialNumber,
......
...@@ -89,7 +89,7 @@ export const TrialCount = (): any => { ...@@ -89,7 +89,7 @@ export const TrialCount = (): any => {
<EditExpeParamContext.Provider <EditExpeParamContext.Provider
value={{ value={{
title: MAX_TRIAL_NUMBERS, title: MAX_TRIAL_NUMBERS,
field: 'maxTrialNum', field: 'maxTrialNumber',
editType: CONTROLTYPE[1], editType: CONTROLTYPE[1],
maxExecDuration: '', maxExecDuration: '',
maxTrialNum: EXPERIMENT.maxTrialNumber, maxTrialNum: EXPERIMENT.maxTrialNumber,
......
...@@ -73,9 +73,9 @@ class TrialConfigPanel extends React.Component<LogDrawerProps, LogDrawerState> { ...@@ -73,9 +73,9 @@ class TrialConfigPanel extends React.Component<LogDrawerProps, LogDrawerState> {
<AppContext.Consumer> <AppContext.Consumer>
{(value): React.ReactNode => { {(value): React.ReactNode => {
const unit = value.maxDurationUnit; const unit = value.maxDurationUnit;
profile.params.maxExecDuration = `${convertTimeAsUnit( profile.params.maxExperimentDuration = `${convertTimeAsUnit(
unit, unit,
profile.params.maxExecDuration profile.params.maxExperimentDuration
)}${unit}`; )}${unit}`;
const showProfile = JSON.stringify(profile, filter, 2); const showProfile = JSON.stringify(profile, filter, 2);
return ( return (
......
...@@ -152,7 +152,10 @@ export interface ExperimentConfig { ...@@ -152,7 +152,10 @@ export interface ExperimentConfig {
const timeUnits = { d: 24 * 3600, h: 3600, m: 60, s: 1 }; const timeUnits = { d: 24 * 3600, h: 3600, m: 60, s: 1 };
export function toSeconds(time: string): number { export function toSeconds(time: string | number): number {
if (typeof time === 'number') {
return time;
}
for (const [unit, factor] of Object.entries(timeUnits)) { for (const [unit, factor] of Object.entries(timeUnits)) {
if (time.endsWith(unit)) { if (time.endsWith(unit)) {
const digits = time.slice(0, -1); const digits = time.slice(0, -1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment