Commit c8ef4141 authored by Zejun Lin, committed by SparkSnail

Implement API for users to import data and to export data in `json` or `csv` format (#980)

parent 1d9b0a99
@@ -22,7 +22,7 @@
 import { ExperimentProfile, TrialJobStatistics } from './manager';
 import { TrialJobDetail, TrialJobStatus } from './trainingService';
 
-type TrialJobEvent = TrialJobStatus | 'USER_TO_CANCEL' | 'ADD_CUSTOMIZED' | 'ADD_HYPERPARAMETER';
+type TrialJobEvent = TrialJobStatus | 'USER_TO_CANCEL' | 'ADD_CUSTOMIZED' | 'ADD_HYPERPARAMETER' | 'IMPORT_DATA';
 type MetricType = 'PERIODICAL' | 'FINAL' | 'CUSTOM' | 'REQUEST_PARAMETER';
 
 interface ExperimentProfileRecord {
......
@@ -99,6 +99,7 @@ abstract class Manager {
     public abstract stopExperiment(): Promise<void>;
     public abstract getExperimentProfile(): Promise<ExperimentProfile>;
     public abstract updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void>;
+    public abstract importData(data: string): Promise<void>;
 
     public abstract addCustomizedTrialJob(hyperParams: string): Promise<void>;
     public abstract cancelTrialJobByUser(trialJobId: string): Promise<void>;
......
@@ -22,6 +22,7 @@ const INITIALIZE = 'IN';
 const REQUEST_TRIAL_JOBS = 'GE';
 const REPORT_METRIC_DATA = 'ME';
 const UPDATE_SEARCH_SPACE = 'SS';
+const IMPORT_DATA = 'FD';
 const ADD_CUSTOMIZED_TRIAL_JOB = 'AD';
 const TRIAL_END = 'EN';
 const TERMINATE = 'TE';
@@ -38,6 +39,7 @@ const TUNER_COMMANDS: Set<string> = new Set([
     REQUEST_TRIAL_JOBS,
     REPORT_METRIC_DATA,
     UPDATE_SEARCH_SPACE,
+    IMPORT_DATA,
     ADD_CUSTOMIZED_TRIAL_JOB,
     TERMINATE,
     PING,
@@ -62,6 +64,7 @@ export {
     REQUEST_TRIAL_JOBS,
     REPORT_METRIC_DATA,
     UPDATE_SEARCH_SPACE,
+    IMPORT_DATA,
     ADD_CUSTOMIZED_TRIAL_JOB,
     TRIAL_END,
     TERMINATE,
......
@@ -38,7 +38,7 @@ import {
 import { delay, getCheckpointDir, getExperimentRootDir, getLogDir, getMsgDispatcherCommand, mkDirP, getLogLevel } from '../common/utils';
 import {
     ADD_CUSTOMIZED_TRIAL_JOB, INITIALIZE, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, PING,
-    REPORT_METRIC_DATA, REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE
+    REPORT_METRIC_DATA, REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE, IMPORT_DATA
 } from './commands';
 import { createDispatcherInterface, IpcInterface } from './ipcInterface';
@@ -99,6 +99,17 @@ class NNIManager implements Manager {
         return this.storeExperimentProfile();
     }
 
+    public importData(data: string): Promise<void> {
+        if (this.dispatcher === undefined) {
+            return Promise.reject(
+                new Error('tuner has not been setup')
+            );
+        }
+        this.dispatcher.sendCommand(IMPORT_DATA, data);
+
+        return this.dataStore.storeTrialJobEvent('IMPORT_DATA', '', data);
+    }
+
     public addCustomizedTrialJob(hyperParams: string): Promise<void> {
         if (this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) {
             return Promise.reject(
......
@@ -63,6 +63,7 @@ class NNIRestHandler {
         this.checkStatus(router);
         this.getExperimentProfile(router);
         this.updateExperimentProfile(router);
+        this.importData(router);
         this.startExperiment(router);
         this.getTrialJobStatistics(router);
         this.setClusterMetaData(router);
@@ -144,6 +145,16 @@ class NNIRestHandler {
             });
         });
     }
 
+    private importData(router: Router): void {
+        router.post('/experiment/import-data', (req: Request, res: Response) => {
+            this.nniManager.importData(JSON.stringify(req.body)).then(() => {
+                res.send();
+            }).catch((err: Error) => {
+                this.handle_error(err, res);
+            });
+        });
+    }
+
     private startExperiment(router: Router): void {
         router.post('/experiment', expressJoi(ValidationSchemas.STARTEXPERIMENT), (req: Request, res: Response) => {
......
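The new route takes the imported trials as the JSON request body and forwards them, serialized, to `NNIManager.importData`. A minimal sketch of exercising the endpoint directly, assuming the REST server listens on `localhost:8080` and uses `/api/v1/nni` as its API root (both values are assumptions here, not taken from this diff):

```python
# Hypothetical direct call to the new endpoint; the port, the API root and the
# example trials below are invented for illustration.
import requests

trials = [
    {'parameter': {'learning_rate': 0.01, 'batch_size': 32}, 'value': 0.93},
    {'parameter': {'learning_rate': 0.001, 'batch_size': 64}, 'value': 0.88},
]
resp = requests.post('http://localhost:8080/api/v1/nni/experiment/import-data',
                     json=trials, timeout=20)
resp.raise_for_status()  # the handler replies with an empty body on success
```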
@@ -46,6 +46,9 @@ export class MockedNNIManager extends Manager {
     public updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void> {
         return Promise.resolve();
     }
+    public importData(data: string): Promise<void> {
+        return Promise.resolve();
+    }
     public getTrialJobStatistics(): Promise<TrialJobStatistics[]> {
         const deferred: Deferred<TrialJobStatistics[]> = new Deferred<TrialJobStatistics[]>();
         deferred.resolve([{
......
@@ -109,18 +109,24 @@ class MsgDispatcher(MsgDispatcherBase):
     def handle_update_search_space(self, data):
         self.tuner.update_search_space(data)
 
+    def handle_import_data(self, data):
+        """Import additional data for tuning
+        data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
+        """
+        self.tuner.import_data(data)
+
     def handle_add_customized_trial(self, data):
         # data: parameters
         id_ = _create_parameter_id()
         _customized_parameter_ids.add(id_)
         send(CommandType.NewTrialJob, _pack_parameter(id_, data, customized=True))
 
     def handle_report_metric_data(self, data):
         """
-        :param data: a dict received from nni_manager, which contains:
+        data: a dict received from nni_manager, which contains:
            - 'parameter_id': id of the trial
            - 'value': metric value reported by nni.report_final_result()
            - 'type': report type, support {'FINAL', 'PERIODICAL'}
         """
         if data['type'] == 'FINAL':
             self._handle_final_metric_data(data)
@@ -135,9 +141,9 @@ class MsgDispatcher(MsgDispatcherBase):
     def handle_trial_end(self, data):
         """
         data: it has three keys: trial_job_id, event, hyper_params
-            trial_job_id: the id generated by training service
-            event: the job's state
-            hyper_params: the hyperparameters generated and returned by tuner
+            - trial_job_id: the id generated by training service
+            - event: the job's state
+            - hyper_params: the hyperparameters generated and returned by tuner
         """
         trial_job_id = data['trial_job_id']
         _ended_trials.add(trial_job_id)
......
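`handle_import_data` simply hands the payload to the tuner. Per its docstring, the payload is a list of dictionaries carrying at least `'parameter'` and `'value'`; a minimal illustration of that shape (the hyperparameter names and metric values are invented):

```python
# Example of the structure handle_import_data forwards to the tuner;
# the concrete keys inside 'parameter' and the values are made up.
imported_trials = [
    {'parameter': {'dropout_rate': 0.5, 'conv_size': 3}, 'value': 0.91},
    {'parameter': {'dropout_rate': 0.3, 'conv_size': 5}, 'value': 0.87},
]
# Equivalent to what the dispatcher does with the incoming IMPORT_DATA command:
# self.tuner.import_data(imported_trials)
```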
@@ -144,6 +144,7 @@ class MsgDispatcherBase(Recoverable):
            CommandType.Initialize: self.handle_initialize,
            CommandType.RequestTrialJobs: self.handle_request_trial_jobs,
            CommandType.UpdateSearchSpace: self.handle_update_search_space,
+           CommandType.ImportData: self.handle_import_data,
            CommandType.AddCustomizedTrialJob: self.handle_add_customized_trial,
 
            # Tunner/Assessor commands:
@@ -168,6 +169,9 @@ class MsgDispatcherBase(Recoverable):
     def handle_update_search_space(self, data):
         raise NotImplementedError('handle_update_search_space not implemented')
 
+    def handle_import_data(self, data):
+        raise NotImplementedError('handle_import_data not implemented')
+
     def handle_add_customized_trial(self, data):
         raise NotImplementedError('handle_add_customized_trial not implemented')
......
@@ -112,6 +112,13 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase):
         self.tuner.update_search_space(data)
         return True
 
+    def handle_import_data(self, data):
+        """Import additional data for tuning
+        data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
+        """
+        self.tuner.import_data(data)
+        return True
+
     def handle_add_customized_trial(self, data):
         # data: parameters
         id_ = _create_parameter_id()
......
@@ -76,6 +76,12 @@ class MultiPhaseTuner(Recoverable):
         """
         raise NotImplementedError('Tuner: update_search_space not implemented')
 
+    def import_data(self, data):
+        """Import additional data for tuning
+        data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
+        """
+        pass
+
     def load_checkpoint(self):
         """Load the checkpoint of tuner.
         path: checkpoint directory for tuner
......
@@ -30,6 +30,7 @@ class CommandType(Enum):
     RequestTrialJobs = b'GE'
     ReportMetricData = b'ME'
     UpdateSearchSpace = b'SS'
+    ImportData = b'FD'
     AddCustomizedTrialJob = b'AD'
     TrialEnd = b'EN'
     Terminate = b'TE'
......
@@ -98,6 +98,12 @@ class Tuner(Recoverable):
         checkpoin_path = self.get_checkpoint_path()
         _logger.info('Save checkpoint ignored by tuner, checkpoint path: %s' % checkpoin_path)
 
+    def import_data(self, data):
+        """Import additional data for tuning
+        data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
+        """
+        pass
+
     def _on_exit(self):
         pass
......
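Because the base class makes `import_data` a no-op, tuners opt in by overriding it. A hypothetical sketch of a custom tuner that records the imported trials (the class name, its attributes, the `nni.tuner` import path, and the `'_type'`/`'_value'` search-space layout are assumptions for illustration, not part of this commit):

```python
import random

from nni.tuner import Tuner  # assumed import path for the Tuner base class


class HistoryAwareRandomTuner(Tuner):
    """Hypothetical tuner that remembers imported trials so later
    suggestions can take previously evaluated points into account."""

    def __init__(self):
        self.search_space = {}
        self.seen = []  # (parameters, value) pairs, including imported ones

    def update_search_space(self, search_space):
        self.search_space = search_space

    def import_data(self, data):
        # data: a list of dicts with at least 'parameter' and 'value'
        for trial in data:
            self.seen.append((trial['parameter'], trial['value']))

    def generate_parameters(self, parameter_id):
        # naive random sampling over 'choice' entries; a real tuner would
        # consult self.seen to bias or de-duplicate its suggestions
        return {name: random.choice(spec['_value'])
                for name, spec in self.search_space.items()
                if spec.get('_type') == 'choice'}

    def receive_trial_result(self, parameter_id, parameters, value):
        self.seen.append((parameters, value))
```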
@@ -80,6 +80,20 @@ PACKAGE_REQUIREMENTS = {
     'BOHB': 'bohb_advisor'
 }
 
+TUNERS_SUPPORTING_IMPORT_DATA = {
+    'TPE',
+    'Anneal',
+    'GridSearch',
+    'MetisTuner',
+    'BOHB'
+}
+
+TUNERS_NO_NEED_TO_IMPORT_DATA = {
+    'Random',
+    'Batch_tuner',
+    'Hyperband'
+}
+
 COLOR_RED_FORMAT = '\033[1;31;31m%s\033[0m'
 COLOR_GREEN_FORMAT = '\033[1;32;32m%s\033[0m'
......
@@ -22,7 +22,7 @@
 import argparse
 import pkg_resources
 from .launcher import create_experiment, resume_experiment
-from .updater import update_searchspace, update_concurrency, update_duration, update_trialnum
+from .updater import update_searchspace, update_concurrency, update_duration, update_trialnum, import_data
 from .nnictl_utils import *
 from .package_management import *
 from .constants import *
@@ -101,10 +101,6 @@ def parse_args():
     parser_trial_kill.add_argument('id', nargs='?', help='the id of experiment')
     parser_trial_kill.add_argument('--trial_id', '-T', required=True, dest='trial_id', help='the id of trial to be killed')
     parser_trial_kill.set_defaults(func=trial_kill)
-    parser_trial_export = parser_trial_subparsers.add_parser('export', help='export trial job results to csv')
-    parser_trial_export.add_argument('id', nargs='?', help='the id of experiment')
-    parser_trial_export.add_argument('--file', '-f', required=True, dest='csv_path', help='target csv file path')
-    parser_trial_export.set_defaults(func=export_trials_data)
 
     #parse experiment command
     parser_experiment = subparsers.add_parser('experiment', help='get experiment information')
@@ -119,6 +115,17 @@ def parse_args():
     parser_experiment_list = parser_experiment_subparsers.add_parser('list', help='list all of running experiment ids')
     parser_experiment_list.add_argument('all', nargs='?', help='list all of experiments')
     parser_experiment_list.set_defaults(func=experiment_list)
+    #import tuning data
+    parser_import_data = parser_experiment_subparsers.add_parser('import', help='import additional data')
+    parser_import_data.add_argument('id', nargs='?', help='the id of experiment')
+    parser_import_data.add_argument('--filename', '-f', required=True)
+    parser_import_data.set_defaults(func=import_data)
+    #export trial data
+    parser_trial_export = parser_experiment_subparsers.add_parser('export', help='export trial job results to csv or json')
+    parser_trial_export.add_argument('id', nargs='?', help='the id of experiment')
+    parser_trial_export.add_argument('--type', '-t', choices=['json', 'csv'], required=True, dest='type', help='target file type')
+    parser_trial_export.add_argument('--filename', '-f', required=True, dest='path', help='target file path')
+    parser_trial_export.set_defaults(func=export_trials_data)
 
     #TODO:finish webui function
     #parse board command
......
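With both subcommands now registered under `nnictl experiment`, a common workflow is to export the finished trials of one experiment as JSON and feed them into a new one. A sketch of that round trip, driving the CLI from Python (the experiment ids and the file name are placeholders, and `nnictl` is assumed to be on `PATH`):

```python
# Hypothetical export/import round trip; 'OLD_ID', 'NEW_ID' and the file name
# are placeholders invented for illustration.
import subprocess

subprocess.run(['nnictl', 'experiment', 'export', 'OLD_ID',
                '--type', 'json', '--filename', 'previous_trials.json'], check=True)
subprocess.run(['nnictl', 'experiment', 'import', 'NEW_ID',
                '--filename', 'previous_trials.json'], check=True)
```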
@@ -505,10 +505,19 @@ def export_trials_data(args):
             # dframe = pd.DataFrame.from_records([parse_trial_data(t_data) for t_data in content])
             # dframe.to_csv(args.csv_path, sep='\t')
             records = parse_trial_data(content)
-            with open(args.csv_path, 'w') as f_csv:
-                writer = csv.DictWriter(f_csv, set.union(*[set(r.keys()) for r in records]))
-                writer.writeheader()
-                writer.writerows(records)
+            if args.type == 'json':
+                json_records = []
+                for trial in records:
+                    value = trial.pop('reward', None)
+                    trial_id = trial.pop('id', None)
+                    json_records.append({'parameter': trial, 'value': value, 'id': trial_id})
+            with open(args.path, 'w') as file:
+                if args.type == 'csv':
+                    writer = csv.DictWriter(file, set.union(*[set(r.keys()) for r in records]))
+                    writer.writeheader()
+                    writer.writerows(records)
+                else:
+                    json.dump(json_records, file)
         else:
             print_error('Export failed...')
     else:
......
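In the JSON branch, each trial record is reshaped so that `'reward'` becomes `'value'` and the remaining fields become the `'parameter'` dictionary, which lines up with the list-of-dicts format the import path expects. A small illustration of that transformation with an invented record:

```python
# Mirrors the pop/append logic in the JSON branch above, using made-up values.
trial = {'id': 'AbCdE', 'reward': 0.93, 'learning_rate': 0.01, 'batch_size': 32}

value = trial.pop('reward', None)
trial_id = trial.pop('id', None)
record = {'parameter': trial, 'value': value, 'id': trial_id}

assert record == {
    'parameter': {'learning_rate': 0.01, 'batch_size': 32},
    'value': 0.93,
    'id': 'AbCdE',
}
```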
@@ -21,13 +21,13 @@
 import json
 import os
-from .rest_utils import rest_put, rest_get, check_rest_server_quick, check_response
-from .url_utils import experiment_url
+from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, check_response
+from .url_utils import experiment_url, import_data_url
 from .config_utils import Config
-from .common_utils import get_json_content
+from .common_utils import get_json_content, print_normal, print_error, print_warning
 from .nnictl_utils import check_experiment_id, get_experiment_port, get_config_filename
 from .launcher_utils import parse_time
-from .constants import REST_TIME_OUT
+from .constants import REST_TIME_OUT, TUNERS_SUPPORTING_IMPORT_DATA, TUNERS_NO_NEED_TO_IMPORT_DATA
 
 def validate_digit(value, start, end):
     '''validate if a digit is valid'''
@@ -39,6 +39,23 @@ def validate_file(path):
     if not os.path.exists(path):
         raise FileNotFoundError('%s is not a valid file path' % path)
 
+def validate_dispatcher(args):
+    '''validate if the dispatcher of the experiment supports importing data'''
+    nni_config = Config(get_config_filename(args)).get_config('experimentConfig')
+    if nni_config.get('tuner') and nni_config['tuner'].get('builtinTunerName'):
+        dispatcher_name = nni_config['tuner']['builtinTunerName']
+    elif nni_config.get('advisor') and nni_config['advisor'].get('builtinAdvisorName'):
+        dispatcher_name = nni_config['advisor']['builtinAdvisorName']
+    else: # otherwise it should be a customized one
+        return
+    if dispatcher_name not in TUNERS_SUPPORTING_IMPORT_DATA:
+        if dispatcher_name in TUNERS_NO_NEED_TO_IMPORT_DATA:
+            print_warning("There is no need to import data for %s" % dispatcher_name)
+            exit(0)
+        else:
+            print_error("%s does not support importing additional data" % dispatcher_name)
+            exit(1)
+
 def load_search_space(path):
     '''load search space content'''
     content = json.dumps(get_json_content(path))
@@ -71,7 +88,7 @@ def update_experiment_profile(args, key, value):
     if response and check_response(response):
         return response
     else:
-        print('ERROR: restful server is not running...')
+        print_error('Restful server is not running...')
     return None
 
 def update_searchspace(args):
@@ -80,18 +97,19 @@ def update_searchspace(args):
     args.port = get_experiment_port(args)
     if args.port is not None:
         if update_experiment_profile(args, 'searchSpace', content):
-            print('INFO: update %s success!' % 'searchSpace')
+            print_normal('Update %s success!' % 'searchSpace')
         else:
-            print('ERROR: update %s failed!' % 'searchSpace')
+            print_error('Update %s failed!' % 'searchSpace')
 
 def update_concurrency(args):
     validate_digit(args.value, 1, 1000)
     args.port = get_experiment_port(args)
     if args.port is not None:
         if update_experiment_profile(args, 'trialConcurrency', int(args.value)):
-            print('INFO: update %s success!' % 'concurrency')
+            print_normal('Update %s success!' % 'concurrency')
         else:
-            print('ERROR: update %s failed!' % 'concurrency')
+            print_error('Update %s failed!' % 'concurrency')
 
 def update_duration(args):
     #parse time, change time unit to seconds
@@ -99,13 +117,38 @@ def update_duration(args):
     args.port = get_experiment_port(args)
     if args.port is not None:
         if update_experiment_profile(args, 'maxExecDuration', int(args.value)):
-            print('INFO: update %s success!' % 'duration')
+            print_normal('Update %s success!' % 'duration')
         else:
-            print('ERROR: update %s failed!' % 'duration')
+            print_error('Update %s failed!' % 'duration')
 
 def update_trialnum(args):
     validate_digit(args.value, 1, 999999999)
     if update_experiment_profile(args, 'maxTrialNum', int(args.value)):
-        print('INFO: update %s success!' % 'trialnum')
+        print_normal('Update %s success!' % 'trialnum')
     else:
-        print('ERROR: update %s failed!' % 'trialnum')
\ No newline at end of file
+        print_error('Update %s failed!' % 'trialnum')
+
+def import_data(args):
+    '''import additional data to the experiment'''
+    validate_file(args.filename)
+    validate_dispatcher(args)
+    content = load_search_space(args.filename)
+    args.port = get_experiment_port(args)
+    if args.port is not None:
+        if import_data_to_restful_server(args, content):
+            print_normal('Import data success!')
+        else:
+            print_error('Import data failed!')
+
+def import_data_to_restful_server(args, content):
+    '''call restful server to import data to the experiment'''
+    nni_config = Config(get_config_filename(args))
+    rest_port = nni_config.get_config('restServerPort')
+    running, _ = check_rest_server_quick(rest_port)
+    if running:
+        response = rest_post(import_data_url(rest_port), content, REST_TIME_OUT)
+        if response and check_response(response):
+            return response
+    else:
+        print_error('Restful server is not running...')
+    return None
@@ -29,6 +29,8 @@ EXPERIMENT_API = '/experiment'
 CLUSTER_METADATA_API = '/experiment/cluster-metadata'
 
+IMPORT_DATA_API = '/experiment/import-data'
+
 CHECK_STATUS_API = '/check-status'
 
 TRIAL_JOBS_API = '/trial-jobs'
@@ -46,6 +48,11 @@ def cluster_metadata_url(port):
     return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CLUSTER_METADATA_API)
 
+def import_data_url(port):
+    '''get import_data_url'''
+    return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, IMPORT_DATA_API)
+
 def experiment_url(port):
     '''get experiment_url'''
     return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, EXPERIMENT_API)
......