Unverified Commit a4802083 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

fix experiment import bug and add it cases: experiment import (#2878)

parent d1c63562
...@@ -67,7 +67,7 @@ It is easy to use NNI in your scikit-learn code, there are only a few steps. ...@@ -67,7 +67,7 @@ It is easy to use NNI in your scikit-learn code, there are only a few steps.
"kernel": {"_type":"choice","_value":["linear", "rbf", "poly", "sigmoid"]}, "kernel": {"_type":"choice","_value":["linear", "rbf", "poly", "sigmoid"]},
"degree": {"_type":"choice","_value":[1, 2, 3, 4]}, "degree": {"_type":"choice","_value":[1, 2, 3, 4]},
"gamma": {"_type":"uniform","_value":[0.01, 0.1]}, "gamma": {"_type":"uniform","_value":[0.01, 0.1]},
"coef0 ": {"_type":"uniform","_value":[0.01, 0.1]} "coef0": {"_type":"uniform","_value":[0.01, 0.1]}
} }
``` ```
......
...@@ -3,5 +3,5 @@ ...@@ -3,5 +3,5 @@
"kernel": {"_type":"choice","_value":["linear", "rbf", "poly", "sigmoid"]}, "kernel": {"_type":"choice","_value":["linear", "rbf", "poly", "sigmoid"]},
"degree": {"_type":"choice","_value":[1, 2, 3, 4]}, "degree": {"_type":"choice","_value":[1, 2, 3, 4]},
"gamma": {"_type":"uniform","_value":[0.01, 0.1]}, "gamma": {"_type":"uniform","_value":[0.01, 0.1]},
"coef0 ": {"_type":"uniform","_value":[0.01, 0.1]} "coef0": {"_type":"uniform","_value":[0.01, 0.1]}
} }
\ No newline at end of file
...@@ -87,6 +87,7 @@ abstract class Manager { ...@@ -87,6 +87,7 @@ abstract class Manager {
public abstract getExperimentProfile(): Promise<ExperimentProfile>; public abstract getExperimentProfile(): Promise<ExperimentProfile>;
public abstract updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void>; public abstract updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void>;
public abstract importData(data: string): Promise<void>; public abstract importData(data: string): Promise<void>;
public abstract getImportedData(): Promise<string[]>;
public abstract exportData(): Promise<string>; public abstract exportData(): Promise<string>;
public abstract addCustomizedTrialJob(hyperParams: string): Promise<number>; public abstract addCustomizedTrialJob(hyperParams: string): Promise<number>;
......
...@@ -108,6 +108,10 @@ class NNIManager implements Manager { ...@@ -108,6 +108,10 @@ class NNIManager implements Manager {
return this.dataStore.storeTrialJobEvent('IMPORT_DATA', '', data); return this.dataStore.storeTrialJobEvent('IMPORT_DATA', '', data);
} }
public getImportedData(): Promise<string[]> {
return this.dataStore.getImportedData();
}
public async exportData(): Promise<string> { public async exportData(): Promise<string> {
return this.dataStore.exportTrialHpConfigs(); return this.dataStore.exportTrialHpConfigs();
} }
......
...@@ -47,6 +47,7 @@ class NNIRestHandler { ...@@ -47,6 +47,7 @@ class NNIRestHandler {
this.getExperimentProfile(router); this.getExperimentProfile(router);
this.updateExperimentProfile(router); this.updateExperimentProfile(router);
this.importData(router); this.importData(router);
this.getImportedData(router);
this.startExperiment(router); this.startExperiment(router);
this.getTrialJobStatistics(router); this.getTrialJobStatistics(router);
this.setClusterMetaData(router); this.setClusterMetaData(router);
...@@ -143,6 +144,16 @@ class NNIRestHandler { ...@@ -143,6 +144,16 @@ class NNIRestHandler {
}); });
} }
private getImportedData(router: Router): void {
router.get('/experiment/imported-data', (req: Request, res: Response) => {
this.nniManager.getImportedData().then((importedData: string[]) => {
res.send(JSON.stringify(importedData));
}).catch((err: Error) => {
this.handleError(err, res);
});
});
}
private startExperiment(router: Router): void { private startExperiment(router: Router): void {
router.post('/experiment', expressJoi(ValidationSchemas.STARTEXPERIMENT), (req: Request, res: Response) => { router.post('/experiment', expressJoi(ValidationSchemas.STARTEXPERIMENT), (req: Request, res: Response) => {
if (isNewExperiment()) { if (isNewExperiment()) {
......
...@@ -33,6 +33,10 @@ export class MockedNNIManager extends Manager { ...@@ -33,6 +33,10 @@ export class MockedNNIManager extends Manager {
public importData(data: string): Promise<void> { public importData(data: string): Promise<void> {
return Promise.resolve(); return Promise.resolve();
} }
public getImportedData(): Promise<string[]> {
const ret: string[] = ["1", "2"];
return Promise.resolve(ret);
}
public async exportData(): Promise<string> { public async exportData(): Promise<string> {
const ret: string = ''; const ret: string = '';
return Promise.resolve(ret); return Promise.resolve(ret);
......
...@@ -114,6 +114,7 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -114,6 +114,7 @@ class MsgDispatcher(MsgDispatcherBase):
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value' data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
""" """
for entry in data: for entry in data:
entry['value'] = entry['value'] if type(entry['value']) is str else json_tricks.dumps(entry['value'])
entry['value'] = json_tricks.loads(entry['value']) entry['value'] = json_tricks.loads(entry['value'])
self.tuner.import_data(data) self.tuner.import_data(data)
......
...@@ -135,6 +135,13 @@ testCases: ...@@ -135,6 +135,13 @@ testCases:
validator: validator:
class: ExportValidator class: ExportValidator
- name: experiment-import
configFile: test/config/nnictl_experiment/sklearn-classification.yml
validator:
class: ImportValidator
kwargs:
import_data_file_path: config/nnictl_experiment/test_import.json
- name: nnicli - name: nnicli
configFile: test/config/examples/sklearn-regression.yml configFile: test/config/examples/sklearn-regression.yml
config: config:
......
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 4
trialConcurrency: 2
searchSpacePath: ../../../examples/trials/sklearn/classification/search_space.json
tuner:
builtinTunerName: TPE
assessor:
builtinAssessorName: Medianstop
classArgs:
optimize_mode: maximize
trial:
codeDir: ../../../examples/trials/sklearn/classification
command: python3 main.py
gpuNum: 0
useAnnotation: false
multiPhase: false
multiThread: false
trainingServicePlatform: local
[
{"parameter": {"C": 0.15940134774738896, "kernel": "sigmoid", "degree": 3, "gamma": 0.07295826917955316, "coef0": 0.0978204758732429}, "value": 0.6},
{"parameter": {"C": 0.5556430724708544, "kernel": "linear", "degree": 3, "gamma": 0.04957496655414671, "coef0": 0.08520868779907687}, "value": 0.7}
]
...@@ -24,6 +24,7 @@ EXPERIMENT_URL = API_ROOT_URL + '/experiment' ...@@ -24,6 +24,7 @@ EXPERIMENT_URL = API_ROOT_URL + '/experiment'
STATUS_URL = API_ROOT_URL + '/check-status' STATUS_URL = API_ROOT_URL + '/check-status'
TRIAL_JOBS_URL = API_ROOT_URL + '/trial-jobs' TRIAL_JOBS_URL = API_ROOT_URL + '/trial-jobs'
METRICS_URL = API_ROOT_URL + '/metric-data' METRICS_URL = API_ROOT_URL + '/metric-data'
GET_IMPORTED_DATA_URL = API_ROOT_URL + '/experiment/imported-data'
def read_last_line(file_name): def read_last_line(file_name):
'''read last line of a file and return None if file not found''' '''read last line of a file and return None if file not found'''
......
...@@ -7,7 +7,8 @@ import subprocess ...@@ -7,7 +7,8 @@ import subprocess
import json import json
import requests import requests
from nnicli import Experiment from nnicli import Experiment
from utils import METRICS_URL from nni_cmd.updater import load_search_space
from utils import METRICS_URL, GET_IMPORTED_DATA_URL
class ITValidator: class ITValidator:
...@@ -33,6 +34,17 @@ class ExportValidator(ITValidator): ...@@ -33,6 +34,17 @@ class ExportValidator(ITValidator):
print('\n\n') print('\n\n')
remove('report.json') remove('report.json')
class ImportValidator(ITValidator):
def __call__(self, rest_endpoint, experiment_dir, nni_source_dir, **kwargs):
exp_id = osp.split(experiment_dir)[-1]
import_data_file_path = kwargs.get('import_data_file_path')
proc = subprocess.run(['nnictl', 'experiment', 'import', exp_id, '-f', import_data_file_path])
assert proc.returncode == 0, \
'`nnictl experiment import {0} -f {1}` failed with code {2}'.format(exp_id, import_data_file_path, proc.returncode)
imported_data = requests.get(GET_IMPORTED_DATA_URL).json()
origin_data = load_search_space(import_data_file_path).replace(' ', '')
assert origin_data in imported_data
class MetricsValidator(ITValidator): class MetricsValidator(ITValidator):
def __call__(self, rest_endpoint, experiment_dir, nni_source_dir, **kwargs): def __call__(self, rest_endpoint, experiment_dir, nni_source_dir, **kwargs):
self.check_metrics(nni_source_dir, **kwargs) self.check_metrics(nni_source_dir, **kwargs)
......
...@@ -7,7 +7,7 @@ from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, ...@@ -7,7 +7,7 @@ from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick,
from .url_utils import experiment_url, import_data_url from .url_utils import experiment_url, import_data_url
from .config_utils import Config from .config_utils import Config
from .common_utils import get_json_content, print_normal, print_error, print_warning from .common_utils import get_json_content, print_normal, print_error, print_warning
from .nnictl_utils import get_experiment_port, get_config_filename from .nnictl_utils import get_experiment_port, get_config_filename, detect_process
from .launcher_utils import parse_time from .launcher_utils import parse_time
from .constants import REST_TIME_OUT, TUNERS_SUPPORTING_IMPORT_DATA, TUNERS_NO_NEED_TO_IMPORT_DATA from .constants import REST_TIME_OUT, TUNERS_SUPPORTING_IMPORT_DATA, TUNERS_NO_NEED_TO_IMPORT_DATA
...@@ -115,7 +115,19 @@ def import_data(args): ...@@ -115,7 +115,19 @@ def import_data(args):
validate_file(args.filename) validate_file(args.filename)
validate_dispatcher(args) validate_dispatcher(args)
content = load_search_space(args.filename) content = load_search_space(args.filename)
args.port = get_experiment_port(args)
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
running, _ = check_rest_server_quick(rest_port)
if not running:
print_error('Restful server is not running')
return
args.port = rest_port
if args.port is not None: if args.port is not None:
if import_data_to_restful_server(args, content): if import_data_to_restful_server(args, content):
pass pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment