Unverified Commit a4802083 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

fix experiment import bug and add it cases: experiment import (#2878)

parent d1c63562
......@@ -67,7 +67,7 @@ It is easy to use NNI in your scikit-learn code, there are only a few steps.
"kernel": {"_type":"choice","_value":["linear", "rbf", "poly", "sigmoid"]},
"degree": {"_type":"choice","_value":[1, 2, 3, 4]},
"gamma": {"_type":"uniform","_value":[0.01, 0.1]},
"coef0 ": {"_type":"uniform","_value":[0.01, 0.1]}
"coef0": {"_type":"uniform","_value":[0.01, 0.1]}
}
```
......
......@@ -3,5 +3,5 @@
"kernel": {"_type":"choice","_value":["linear", "rbf", "poly", "sigmoid"]},
"degree": {"_type":"choice","_value":[1, 2, 3, 4]},
"gamma": {"_type":"uniform","_value":[0.01, 0.1]},
"coef0 ": {"_type":"uniform","_value":[0.01, 0.1]}
"coef0": {"_type":"uniform","_value":[0.01, 0.1]}
}
\ No newline at end of file
......@@ -87,6 +87,7 @@ abstract class Manager {
public abstract getExperimentProfile(): Promise<ExperimentProfile>;
public abstract updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void>;
public abstract importData(data: string): Promise<void>;
public abstract getImportedData(): Promise<string[]>;
public abstract exportData(): Promise<string>;
public abstract addCustomizedTrialJob(hyperParams: string): Promise<number>;
......
......@@ -108,6 +108,10 @@ class NNIManager implements Manager {
return this.dataStore.storeTrialJobEvent('IMPORT_DATA', '', data);
}
public getImportedData(): Promise<string[]> {
return this.dataStore.getImportedData();
}
public async exportData(): Promise<string> {
return this.dataStore.exportTrialHpConfigs();
}
......
......@@ -47,6 +47,7 @@ class NNIRestHandler {
this.getExperimentProfile(router);
this.updateExperimentProfile(router);
this.importData(router);
this.getImportedData(router);
this.startExperiment(router);
this.getTrialJobStatistics(router);
this.setClusterMetaData(router);
......@@ -143,6 +144,16 @@ class NNIRestHandler {
});
}
private getImportedData(router: Router): void {
router.get('/experiment/imported-data', (req: Request, res: Response) => {
this.nniManager.getImportedData().then((importedData: string[]) => {
res.send(JSON.stringify(importedData));
}).catch((err: Error) => {
this.handleError(err, res);
});
});
}
private startExperiment(router: Router): void {
router.post('/experiment', expressJoi(ValidationSchemas.STARTEXPERIMENT), (req: Request, res: Response) => {
if (isNewExperiment()) {
......
......@@ -33,6 +33,10 @@ export class MockedNNIManager extends Manager {
public importData(data: string): Promise<void> {
return Promise.resolve();
}
public getImportedData(): Promise<string[]> {
const ret: string[] = ["1", "2"];
return Promise.resolve(ret);
}
public async exportData(): Promise<string> {
const ret: string = '';
return Promise.resolve(ret);
......
......@@ -114,6 +114,7 @@ class MsgDispatcher(MsgDispatcherBase):
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
for entry in data:
entry['value'] = entry['value'] if type(entry['value']) is str else json_tricks.dumps(entry['value'])
entry['value'] = json_tricks.loads(entry['value'])
self.tuner.import_data(data)
......
......@@ -135,6 +135,13 @@ testCases:
validator:
class: ExportValidator
- name: experiment-import
configFile: test/config/nnictl_experiment/sklearn-classification.yml
validator:
class: ImportValidator
kwargs:
import_data_file_path: config/nnictl_experiment/test_import.json
- name: nnicli
configFile: test/config/examples/sklearn-regression.yml
config:
......
authorName: nni
experimentName: default_test
maxExecDuration: 5m
maxTrialNum: 4
trialConcurrency: 2
searchSpacePath: ../../../examples/trials/sklearn/classification/search_space.json
tuner:
builtinTunerName: TPE
assessor:
builtinAssessorName: Medianstop
classArgs:
optimize_mode: maximize
trial:
codeDir: ../../../examples/trials/sklearn/classification
command: python3 main.py
gpuNum: 0
useAnnotation: false
multiPhase: false
multiThread: false
trainingServicePlatform: local
[
{"parameter": {"C": 0.15940134774738896, "kernel": "sigmoid", "degree": 3, "gamma": 0.07295826917955316, "coef0": 0.0978204758732429}, "value": 0.6},
{"parameter": {"C": 0.5556430724708544, "kernel": "linear", "degree": 3, "gamma": 0.04957496655414671, "coef0": 0.08520868779907687}, "value": 0.7}
]
......@@ -24,6 +24,7 @@ EXPERIMENT_URL = API_ROOT_URL + '/experiment'
STATUS_URL = API_ROOT_URL + '/check-status'
TRIAL_JOBS_URL = API_ROOT_URL + '/trial-jobs'
METRICS_URL = API_ROOT_URL + '/metric-data'
GET_IMPORTED_DATA_URL = API_ROOT_URL + '/experiment/imported-data'
def read_last_line(file_name):
'''read last line of a file and return None if file not found'''
......
......@@ -7,7 +7,8 @@ import subprocess
import json
import requests
from nnicli import Experiment
from utils import METRICS_URL
from nni_cmd.updater import load_search_space
from utils import METRICS_URL, GET_IMPORTED_DATA_URL
class ITValidator:
......@@ -33,6 +34,17 @@ class ExportValidator(ITValidator):
print('\n\n')
remove('report.json')
class ImportValidator(ITValidator):
def __call__(self, rest_endpoint, experiment_dir, nni_source_dir, **kwargs):
exp_id = osp.split(experiment_dir)[-1]
import_data_file_path = kwargs.get('import_data_file_path')
proc = subprocess.run(['nnictl', 'experiment', 'import', exp_id, '-f', import_data_file_path])
assert proc.returncode == 0, \
'`nnictl experiment import {0} -f {1}` failed with code {2}'.format(exp_id, import_data_file_path, proc.returncode)
imported_data = requests.get(GET_IMPORTED_DATA_URL).json()
origin_data = load_search_space(import_data_file_path).replace(' ', '')
assert origin_data in imported_data
class MetricsValidator(ITValidator):
def __call__(self, rest_endpoint, experiment_dir, nni_source_dir, **kwargs):
self.check_metrics(nni_source_dir, **kwargs)
......
......@@ -7,7 +7,7 @@ from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick,
from .url_utils import experiment_url, import_data_url
from .config_utils import Config
from .common_utils import get_json_content, print_normal, print_error, print_warning
from .nnictl_utils import get_experiment_port, get_config_filename
from .nnictl_utils import get_experiment_port, get_config_filename, detect_process
from .launcher_utils import parse_time
from .constants import REST_TIME_OUT, TUNERS_SUPPORTING_IMPORT_DATA, TUNERS_NO_NEED_TO_IMPORT_DATA
......@@ -115,7 +115,19 @@ def import_data(args):
validate_file(args.filename)
validate_dispatcher(args)
content = load_search_space(args.filename)
args.port = get_experiment_port(args)
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
running, _ = check_rest_server_quick(rest_port)
if not running:
print_error('Restful server is not running')
return
args.port = rest_port
if args.port is not None:
if import_data_to_restful_server(args, content):
pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment