"src/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "feb6f3b84ddcf2444afa29d5ad098fff8c08d7b1"
Commit c8ef4141 authored by Zejun Lin's avatar Zejun Lin Committed by SparkSnail
Browse files

Implement API for user to import data and export data of type `json` or `csv` (#980)

parent 1d9b0a99
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
import { ExperimentProfile, TrialJobStatistics } from './manager'; import { ExperimentProfile, TrialJobStatistics } from './manager';
import { TrialJobDetail, TrialJobStatus } from './trainingService'; import { TrialJobDetail, TrialJobStatus } from './trainingService';
type TrialJobEvent = TrialJobStatus | 'USER_TO_CANCEL' | 'ADD_CUSTOMIZED' | 'ADD_HYPERPARAMETER'; type TrialJobEvent = TrialJobStatus | 'USER_TO_CANCEL' | 'ADD_CUSTOMIZED' | 'ADD_HYPERPARAMETER' | 'IMPORT_DATA';
type MetricType = 'PERIODICAL' | 'FINAL' | 'CUSTOM' | 'REQUEST_PARAMETER'; type MetricType = 'PERIODICAL' | 'FINAL' | 'CUSTOM' | 'REQUEST_PARAMETER';
interface ExperimentProfileRecord { interface ExperimentProfileRecord {
......
...@@ -99,6 +99,7 @@ abstract class Manager { ...@@ -99,6 +99,7 @@ abstract class Manager {
public abstract stopExperiment(): Promise<void>; public abstract stopExperiment(): Promise<void>;
public abstract getExperimentProfile(): Promise<ExperimentProfile>; public abstract getExperimentProfile(): Promise<ExperimentProfile>;
public abstract updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void>; public abstract updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void>;
public abstract importData(data: string): Promise<void>;
public abstract addCustomizedTrialJob(hyperParams: string): Promise<void>; public abstract addCustomizedTrialJob(hyperParams: string): Promise<void>;
public abstract cancelTrialJobByUser(trialJobId: string): Promise<void>; public abstract cancelTrialJobByUser(trialJobId: string): Promise<void>;
......
...@@ -22,6 +22,7 @@ const INITIALIZE = 'IN'; ...@@ -22,6 +22,7 @@ const INITIALIZE = 'IN';
const REQUEST_TRIAL_JOBS = 'GE'; const REQUEST_TRIAL_JOBS = 'GE';
const REPORT_METRIC_DATA = 'ME'; const REPORT_METRIC_DATA = 'ME';
const UPDATE_SEARCH_SPACE = 'SS'; const UPDATE_SEARCH_SPACE = 'SS';
const IMPORT_DATA = 'FD'
const ADD_CUSTOMIZED_TRIAL_JOB = 'AD'; const ADD_CUSTOMIZED_TRIAL_JOB = 'AD';
const TRIAL_END = 'EN'; const TRIAL_END = 'EN';
const TERMINATE = 'TE'; const TERMINATE = 'TE';
...@@ -38,6 +39,7 @@ const TUNER_COMMANDS: Set<string> = new Set([ ...@@ -38,6 +39,7 @@ const TUNER_COMMANDS: Set<string> = new Set([
REQUEST_TRIAL_JOBS, REQUEST_TRIAL_JOBS,
REPORT_METRIC_DATA, REPORT_METRIC_DATA,
UPDATE_SEARCH_SPACE, UPDATE_SEARCH_SPACE,
IMPORT_DATA,
ADD_CUSTOMIZED_TRIAL_JOB, ADD_CUSTOMIZED_TRIAL_JOB,
TERMINATE, TERMINATE,
PING, PING,
...@@ -62,6 +64,7 @@ export { ...@@ -62,6 +64,7 @@ export {
REQUEST_TRIAL_JOBS, REQUEST_TRIAL_JOBS,
REPORT_METRIC_DATA, REPORT_METRIC_DATA,
UPDATE_SEARCH_SPACE, UPDATE_SEARCH_SPACE,
IMPORT_DATA,
ADD_CUSTOMIZED_TRIAL_JOB, ADD_CUSTOMIZED_TRIAL_JOB,
TRIAL_END, TRIAL_END,
TERMINATE, TERMINATE,
......
...@@ -38,7 +38,7 @@ import { ...@@ -38,7 +38,7 @@ import {
import { delay, getCheckpointDir, getExperimentRootDir, getLogDir, getMsgDispatcherCommand, mkDirP, getLogLevel } from '../common/utils'; import { delay, getCheckpointDir, getExperimentRootDir, getLogDir, getMsgDispatcherCommand, mkDirP, getLogLevel } from '../common/utils';
import { import {
ADD_CUSTOMIZED_TRIAL_JOB, INITIALIZE, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, PING, ADD_CUSTOMIZED_TRIAL_JOB, INITIALIZE, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, PING,
REPORT_METRIC_DATA, REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE REPORT_METRIC_DATA, REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE, IMPORT_DATA
} from './commands'; } from './commands';
import { createDispatcherInterface, IpcInterface } from './ipcInterface'; import { createDispatcherInterface, IpcInterface } from './ipcInterface';
...@@ -99,6 +99,17 @@ class NNIManager implements Manager { ...@@ -99,6 +99,17 @@ class NNIManager implements Manager {
return this.storeExperimentProfile(); return this.storeExperimentProfile();
} }
public importData(data: string): Promise<void> {
if (this.dispatcher === undefined) {
return Promise.reject(
new Error('tuner has not been setup')
);
}
this.dispatcher.sendCommand(IMPORT_DATA, data);
return this.dataStore.storeTrialJobEvent('IMPORT_DATA', '', data);
}
public addCustomizedTrialJob(hyperParams: string): Promise<void> { public addCustomizedTrialJob(hyperParams: string): Promise<void> {
if (this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) { if (this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) {
return Promise.reject( return Promise.reject(
......
...@@ -63,6 +63,7 @@ class NNIRestHandler { ...@@ -63,6 +63,7 @@ class NNIRestHandler {
this.checkStatus(router); this.checkStatus(router);
this.getExperimentProfile(router); this.getExperimentProfile(router);
this.updateExperimentProfile(router); this.updateExperimentProfile(router);
this.importData(router);
this.startExperiment(router); this.startExperiment(router);
this.getTrialJobStatistics(router); this.getTrialJobStatistics(router);
this.setClusterMetaData(router); this.setClusterMetaData(router);
...@@ -144,6 +145,16 @@ class NNIRestHandler { ...@@ -144,6 +145,16 @@ class NNIRestHandler {
}); });
}); });
} }
private importData(router: Router): void {
router.post('/experiment/import-data', (req: Request, res: Response) => {
this.nniManager.importData(JSON.stringify(req.body)).then(() => {
res.send();
}).catch((err: Error) => {
this.handle_error(err, res);
});
});
}
private startExperiment(router: Router): void { private startExperiment(router: Router): void {
router.post('/experiment', expressJoi(ValidationSchemas.STARTEXPERIMENT), (req: Request, res: Response) => { router.post('/experiment', expressJoi(ValidationSchemas.STARTEXPERIMENT), (req: Request, res: Response) => {
......
...@@ -46,6 +46,9 @@ export class MockedNNIManager extends Manager { ...@@ -46,6 +46,9 @@ export class MockedNNIManager extends Manager {
public updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void> { public updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void> {
return Promise.resolve(); return Promise.resolve();
} }
public importData(data: string): Promise<void> {
return Promise.resolve();
}
public getTrialJobStatistics(): Promise<TrialJobStatistics[]> { public getTrialJobStatistics(): Promise<TrialJobStatistics[]> {
const deferred: Deferred<TrialJobStatistics[]> = new Deferred<TrialJobStatistics[]>(); const deferred: Deferred<TrialJobStatistics[]> = new Deferred<TrialJobStatistics[]>();
deferred.resolve([{ deferred.resolve([{
......
...@@ -109,18 +109,24 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -109,18 +109,24 @@ class MsgDispatcher(MsgDispatcherBase):
def handle_update_search_space(self, data): def handle_update_search_space(self, data):
self.tuner.update_search_space(data) self.tuner.update_search_space(data)
def handle_import_data(self, data):
"""Import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
self.tuner.import_data(data)
def handle_add_customized_trial(self, data): def handle_add_customized_trial(self, data):
# data: parameters # data: parameters
id_ = _create_parameter_id() id_ = _create_parameter_id()
_customized_parameter_ids.add(id_) _customized_parameter_ids.add(id_)
send(CommandType.NewTrialJob, _pack_parameter(id_, data, customized=True)) send(CommandType.NewTrialJob, _pack_parameter(id_, data, customized=True))
def handle_report_metric_data(self, data): def handle_report_metric_data(self, data):
""" """
:param data: a dict received from nni_manager, which contains: data: a dict received from nni_manager, which contains:
- 'parameter_id': id of the trial - 'parameter_id': id of the trial
- 'value': metric value reported by nni.report_final_result() - 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL'} - 'type': report type, support {'FINAL', 'PERIODICAL'}
""" """
if data['type'] == 'FINAL': if data['type'] == 'FINAL':
self._handle_final_metric_data(data) self._handle_final_metric_data(data)
...@@ -135,9 +141,9 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -135,9 +141,9 @@ class MsgDispatcher(MsgDispatcherBase):
def handle_trial_end(self, data): def handle_trial_end(self, data):
""" """
data: it has three keys: trial_job_id, event, hyper_params data: it has three keys: trial_job_id, event, hyper_params
trial_job_id: the id generated by training service - trial_job_id: the id generated by training service
event: the job's state - event: the job's state
hyper_params: the hyperparameters generated and returned by tuner - hyper_params: the hyperparameters generated and returned by tuner
""" """
trial_job_id = data['trial_job_id'] trial_job_id = data['trial_job_id']
_ended_trials.add(trial_job_id) _ended_trials.add(trial_job_id)
......
...@@ -144,6 +144,7 @@ class MsgDispatcherBase(Recoverable): ...@@ -144,6 +144,7 @@ class MsgDispatcherBase(Recoverable):
CommandType.Initialize: self.handle_initialize, CommandType.Initialize: self.handle_initialize,
CommandType.RequestTrialJobs: self.handle_request_trial_jobs, CommandType.RequestTrialJobs: self.handle_request_trial_jobs,
CommandType.UpdateSearchSpace: self.handle_update_search_space, CommandType.UpdateSearchSpace: self.handle_update_search_space,
CommandType.ImportData: self.handle_import_data,
CommandType.AddCustomizedTrialJob: self.handle_add_customized_trial, CommandType.AddCustomizedTrialJob: self.handle_add_customized_trial,
# Tunner/Assessor commands: # Tunner/Assessor commands:
...@@ -168,6 +169,9 @@ class MsgDispatcherBase(Recoverable): ...@@ -168,6 +169,9 @@ class MsgDispatcherBase(Recoverable):
def handle_update_search_space(self, data): def handle_update_search_space(self, data):
raise NotImplementedError('handle_update_search_space not implemented') raise NotImplementedError('handle_update_search_space not implemented')
def handle_import_data(self, data):
raise NotImplementedError('handle_import_data not implemented')
def handle_add_customized_trial(self, data): def handle_add_customized_trial(self, data):
raise NotImplementedError('handle_add_customized_trial not implemented') raise NotImplementedError('handle_add_customized_trial not implemented')
......
...@@ -112,6 +112,13 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase): ...@@ -112,6 +112,13 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase):
self.tuner.update_search_space(data) self.tuner.update_search_space(data)
return True return True
def handle_import_data(self, data):
"""import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
self.tuner.import_data(data)
return True
def handle_add_customized_trial(self, data): def handle_add_customized_trial(self, data):
# data: parameters # data: parameters
id_ = _create_parameter_id() id_ = _create_parameter_id()
......
...@@ -76,6 +76,12 @@ class MultiPhaseTuner(Recoverable): ...@@ -76,6 +76,12 @@ class MultiPhaseTuner(Recoverable):
""" """
raise NotImplementedError('Tuner: update_search_space not implemented') raise NotImplementedError('Tuner: update_search_space not implemented')
def import_data(self, data):
"""Import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
pass
def load_checkpoint(self): def load_checkpoint(self):
"""Load the checkpoint of tuner. """Load the checkpoint of tuner.
path: checkpoint directory for tuner path: checkpoint directory for tuner
......
...@@ -30,6 +30,7 @@ class CommandType(Enum): ...@@ -30,6 +30,7 @@ class CommandType(Enum):
RequestTrialJobs = b'GE' RequestTrialJobs = b'GE'
ReportMetricData = b'ME' ReportMetricData = b'ME'
UpdateSearchSpace = b'SS' UpdateSearchSpace = b'SS'
ImportData = b'FD'
AddCustomizedTrialJob = b'AD' AddCustomizedTrialJob = b'AD'
TrialEnd = b'EN' TrialEnd = b'EN'
Terminate = b'TE' Terminate = b'TE'
......
...@@ -98,6 +98,12 @@ class Tuner(Recoverable): ...@@ -98,6 +98,12 @@ class Tuner(Recoverable):
checkpoin_path = self.get_checkpoint_path() checkpoin_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by tuner, checkpoint path: %s' % checkpoin_path) _logger.info('Save checkpoint ignored by tuner, checkpoint path: %s' % checkpoin_path)
def import_data(self, data):
"""Import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
pass
def _on_exit(self): def _on_exit(self):
pass pass
......
...@@ -80,6 +80,20 @@ PACKAGE_REQUIREMENTS = { ...@@ -80,6 +80,20 @@ PACKAGE_REQUIREMENTS = {
'BOHB': 'bohb_advisor' 'BOHB': 'bohb_advisor'
} }
TUNERS_SUPPORTING_IMPORT_DATA = {
'TPE',
'Anneal',
'GridSearch',
'MetisTuner',
'BOHB'
}
TUNERS_NO_NEED_TO_IMPORT_DATA = {
'Random',
'Batch_tuner',
'Hyperband'
}
COLOR_RED_FORMAT = '\033[1;31;31m%s\033[0m' COLOR_RED_FORMAT = '\033[1;31;31m%s\033[0m'
COLOR_GREEN_FORMAT = '\033[1;32;32m%s\033[0m' COLOR_GREEN_FORMAT = '\033[1;32;32m%s\033[0m'
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
import argparse import argparse
import pkg_resources import pkg_resources
from .launcher import create_experiment, resume_experiment from .launcher import create_experiment, resume_experiment
from .updater import update_searchspace, update_concurrency, update_duration, update_trialnum from .updater import update_searchspace, update_concurrency, update_duration, update_trialnum, import_data
from .nnictl_utils import * from .nnictl_utils import *
from .package_management import * from .package_management import *
from .constants import * from .constants import *
...@@ -101,10 +101,6 @@ def parse_args(): ...@@ -101,10 +101,6 @@ def parse_args():
parser_trial_kill.add_argument('id', nargs='?', help='the id of experiment') parser_trial_kill.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_kill.add_argument('--trial_id', '-T', required=True, dest='trial_id', help='the id of trial to be killed') parser_trial_kill.add_argument('--trial_id', '-T', required=True, dest='trial_id', help='the id of trial to be killed')
parser_trial_kill.set_defaults(func=trial_kill) parser_trial_kill.set_defaults(func=trial_kill)
parser_trial_export = parser_trial_subparsers.add_parser('export', help='export trial job results to csv')
parser_trial_export.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_export.add_argument('--file', '-f', required=True, dest='csv_path', help='target csv file path')
parser_trial_export.set_defaults(func=export_trials_data)
#parse experiment command #parse experiment command
parser_experiment = subparsers.add_parser('experiment', help='get experiment information') parser_experiment = subparsers.add_parser('experiment', help='get experiment information')
...@@ -119,6 +115,17 @@ def parse_args(): ...@@ -119,6 +115,17 @@ def parse_args():
parser_experiment_list = parser_experiment_subparsers.add_parser('list', help='list all of running experiment ids') parser_experiment_list = parser_experiment_subparsers.add_parser('list', help='list all of running experiment ids')
parser_experiment_list.add_argument('all', nargs='?', help='list all of experiments') parser_experiment_list.add_argument('all', nargs='?', help='list all of experiments')
parser_experiment_list.set_defaults(func=experiment_list) parser_experiment_list.set_defaults(func=experiment_list)
#import tuning data
parser_import_data = parser_experiment_subparsers.add_parser('import', help='import additional data')
parser_import_data.add_argument('id', nargs='?', help='the id of experiment')
parser_import_data.add_argument('--filename', '-f', required=True)
parser_import_data.set_defaults(func=import_data)
#export trial data
parser_trial_export = parser_experiment_subparsers.add_parser('export', help='export trial job results to csv or json')
parser_trial_export.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_export.add_argument('--type', '-t', choices=['json', 'csv'], required=True, dest='type', help='target file type')
parser_trial_export.add_argument('--filename', '-f', required=True, dest='path', help='target file path')
parser_trial_export.set_defaults(func=export_trials_data)
#TODO:finish webui function #TODO:finish webui function
#parse board command #parse board command
......
...@@ -505,10 +505,19 @@ def export_trials_data(args): ...@@ -505,10 +505,19 @@ def export_trials_data(args):
# dframe = pd.DataFrame.from_records([parse_trial_data(t_data) for t_data in content]) # dframe = pd.DataFrame.from_records([parse_trial_data(t_data) for t_data in content])
# dframe.to_csv(args.csv_path, sep='\t') # dframe.to_csv(args.csv_path, sep='\t')
records = parse_trial_data(content) records = parse_trial_data(content)
with open(args.csv_path, 'w') as f_csv: if args.type == 'json':
writer = csv.DictWriter(f_csv, set.union(*[set(r.keys()) for r in records])) json_records = []
writer.writeheader() for trial in records:
writer.writerows(records) value = trial.pop('reward', None)
trial_id = trial.pop('id', None)
json_records.append({'parameter': trial, 'value': value, 'id': trial_id})
with open(args.path, 'w') as file:
if args.type == 'csv':
writer = csv.DictWriter(file, set.union(*[set(r.keys()) for r in records]))
writer.writeheader()
writer.writerows(records)
else:
json.dump(json_records, file)
else: else:
print_error('Export failed...') print_error('Export failed...')
else: else:
......
...@@ -21,13 +21,13 @@ ...@@ -21,13 +21,13 @@
import json import json
import os import os
from .rest_utils import rest_put, rest_get, check_rest_server_quick, check_response from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, check_response
from .url_utils import experiment_url from .url_utils import experiment_url, import_data_url
from .config_utils import Config from .config_utils import Config
from .common_utils import get_json_content from .common_utils import get_json_content, print_normal, print_error, print_warning
from .nnictl_utils import check_experiment_id, get_experiment_port, get_config_filename from .nnictl_utils import check_experiment_id, get_experiment_port, get_config_filename
from .launcher_utils import parse_time from .launcher_utils import parse_time
from .constants import REST_TIME_OUT from .constants import REST_TIME_OUT, TUNERS_SUPPORTING_IMPORT_DATA, TUNERS_NO_NEED_TO_IMPORT_DATA
def validate_digit(value, start, end): def validate_digit(value, start, end):
'''validate if a digit is valid''' '''validate if a digit is valid'''
...@@ -39,6 +39,23 @@ def validate_file(path): ...@@ -39,6 +39,23 @@ def validate_file(path):
if not os.path.exists(path): if not os.path.exists(path):
raise FileNotFoundError('%s is not a valid file path' % path) raise FileNotFoundError('%s is not a valid file path' % path)
def validate_dispatcher(args):
'''validate if the dispatcher of the experiment supports importing data'''
nni_config = Config(get_config_filename(args)).get_config('experimentConfig')
if nni_config.get('tuner') and nni_config['tuner'].get('builtinTunerName'):
dispatcher_name = nni_config['tuner']['builtinTunerName']
elif nni_config.get('advisor') and nni_config['advisor'].get('builtinAdvisorName'):
dispatcher_name = nni_config['advisor']['builtinAdvisorName']
else: # otherwise it should be a customized one
return
if dispatcher_name not in TUNERS_SUPPORTING_IMPORT_DATA:
if dispatcher_name in TUNERS_NO_NEED_TO_IMPORT_DATA:
print_warning("There is no need to import data for %s" % dispatcher_name)
exit(0)
else:
print_error("%s does not support importing addtional data" % dispatcher_name)
exit(1)
def load_search_space(path): def load_search_space(path):
'''load search space content''' '''load search space content'''
content = json.dumps(get_json_content(path)) content = json.dumps(get_json_content(path))
...@@ -71,7 +88,7 @@ def update_experiment_profile(args, key, value): ...@@ -71,7 +88,7 @@ def update_experiment_profile(args, key, value):
if response and check_response(response): if response and check_response(response):
return response return response
else: else:
print('ERROR: restful server is not running...') print_error('Restful server is not running...')
return None return None
def update_searchspace(args): def update_searchspace(args):
...@@ -80,18 +97,19 @@ def update_searchspace(args): ...@@ -80,18 +97,19 @@ def update_searchspace(args):
args.port = get_experiment_port(args) args.port = get_experiment_port(args)
if args.port is not None: if args.port is not None:
if update_experiment_profile(args, 'searchSpace', content): if update_experiment_profile(args, 'searchSpace', content):
print('INFO: update %s success!' % 'searchSpace') print_normal('Update %s success!' % 'searchSpace')
else: else:
print('ERROR: update %s failed!' % 'searchSpace') print_error('Update %s failed!' % 'searchSpace')
def update_concurrency(args): def update_concurrency(args):
validate_digit(args.value, 1, 1000) validate_digit(args.value, 1, 1000)
args.port = get_experiment_port(args) args.port = get_experiment_port(args)
if args.port is not None: if args.port is not None:
if update_experiment_profile(args, 'trialConcurrency', int(args.value)): if update_experiment_profile(args, 'trialConcurrency', int(args.value)):
print('INFO: update %s success!' % 'concurrency') print_normal('Update %s success!' % 'concurrency')
else: else:
print('ERROR: update %s failed!' % 'concurrency') print_error('Update %s failed!' % 'concurrency')
def update_duration(args): def update_duration(args):
#parse time, change time unit to seconds #parse time, change time unit to seconds
...@@ -99,13 +117,38 @@ def update_duration(args): ...@@ -99,13 +117,38 @@ def update_duration(args):
args.port = get_experiment_port(args) args.port = get_experiment_port(args)
if args.port is not None: if args.port is not None:
if update_experiment_profile(args, 'maxExecDuration', int(args.value)): if update_experiment_profile(args, 'maxExecDuration', int(args.value)):
print('INFO: update %s success!' % 'duration') print_normal('Update %s success!' % 'duration')
else: else:
print('ERROR: update %s failed!' % 'duration') print_error('Update %s failed!' % 'duration')
def update_trialnum(args): def update_trialnum(args):
validate_digit(args.value, 1, 999999999) validate_digit(args.value, 1, 999999999)
if update_experiment_profile(args, 'maxTrialNum', int(args.value)): if update_experiment_profile(args, 'maxTrialNum', int(args.value)):
print('INFO: update %s success!' % 'trialnum') print_normal('Update %s success!' % 'trialnum')
else: else:
print('ERROR: update %s failed!' % 'trialnum') print_error('Update %s failed!' % 'trialnum')
\ No newline at end of file
def import_data(args):
'''import additional data to the experiment'''
validate_file(args.filename)
validate_dispatcher(args)
content = load_search_space(args.filename)
args.port = get_experiment_port(args)
if args.port is not None:
if import_data_to_restful_server(args, content):
print_normal('Import data success!')
else:
print_error('Import data failed!')
def import_data_to_restful_server(args, content):
'''call restful server to import data to the experiment'''
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_post(import_data_url(rest_port), content, REST_TIME_OUT)
if response and check_response(response):
return response
else:
print_error('Restful server is not running...')
return None
...@@ -29,6 +29,8 @@ EXPERIMENT_API = '/experiment' ...@@ -29,6 +29,8 @@ EXPERIMENT_API = '/experiment'
CLUSTER_METADATA_API = '/experiment/cluster-metadata' CLUSTER_METADATA_API = '/experiment/cluster-metadata'
IMPORT_DATA_API = '/experiment/import-data'
CHECK_STATUS_API = '/check-status' CHECK_STATUS_API = '/check-status'
TRIAL_JOBS_API = '/trial-jobs' TRIAL_JOBS_API = '/trial-jobs'
...@@ -46,6 +48,11 @@ def cluster_metadata_url(port): ...@@ -46,6 +48,11 @@ def cluster_metadata_url(port):
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CLUSTER_METADATA_API) return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CLUSTER_METADATA_API)
def import_data_url(port):
'''get import_data_url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, IMPORT_DATA_API)
def experiment_url(port): def experiment_url(port):
'''get experiment_url''' '''get experiment_url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, EXPERIMENT_API) return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, EXPERIMENT_API)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment