Unverified Commit 75028bd7 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #235 from microsoft/master

merge master
parents 1d74ae5e 2e42d1d8
...@@ -31,7 +31,6 @@ echarts.registerTheme('my_theme', { ...@@ -31,7 +31,6 @@ echarts.registerTheme('my_theme', {
color: '#3c8dbc' color: '#3c8dbc'
}); });
interface TableListProps { interface TableListProps {
pageSize: number; pageSize: number;
tableSource: Array<TableRecord>; tableSource: Array<TableRecord>;
...@@ -142,7 +141,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -142,7 +141,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
key: 'sequenceId', key: 'sequenceId',
fieldName: 'sequenceId', fieldName: 'sequenceId',
minWidth: 80, minWidth: 80,
maxWidth: 120, maxWidth: 240,
className: 'tableHead', className: 'tableHead',
data: 'string', data: 'string',
onColumnClick: this.onColumnClick, onColumnClick: this.onColumnClick,
...@@ -166,7 +165,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -166,7 +165,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
key: 'startTime', key: 'startTime',
fieldName: 'startTime', fieldName: 'startTime',
minWidth: 150, minWidth: 150,
maxWidth: 200, maxWidth: 400,
isResizable: true, isResizable: true,
data: 'number', data: 'number',
onColumnClick: this.onColumnClick, onColumnClick: this.onColumnClick,
...@@ -179,8 +178,8 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -179,8 +178,8 @@ class TableList extends React.Component<TableListProps, TableListState> {
name: 'End Time', name: 'End Time',
key: 'endTime', key: 'endTime',
fieldName: 'endTime', fieldName: 'endTime',
minWidth: 150, minWidth: 200,
maxWidth: 200, maxWidth: 400,
isResizable: true, isResizable: true,
data: 'number', data: 'number',
onColumnClick: this.onColumnClick, onColumnClick: this.onColumnClick,
...@@ -194,7 +193,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -194,7 +193,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
key: 'duration', key: 'duration',
fieldName: 'duration', fieldName: 'duration',
minWidth: 150, minWidth: 150,
maxWidth: 200, maxWidth: 300,
isResizable: true, isResizable: true,
data: 'number', data: 'number',
onColumnClick: this.onColumnClick, onColumnClick: this.onColumnClick,
...@@ -209,7 +208,7 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -209,7 +208,7 @@ class TableList extends React.Component<TableListProps, TableListState> {
fieldName: 'status', fieldName: 'status',
className: 'tableStatus', className: 'tableStatus',
minWidth: 150, minWidth: 150,
maxWidth: 200, maxWidth: 250,
isResizable: true, isResizable: true,
data: 'string', data: 'string',
onColumnClick: this.onColumnClick, onColumnClick: this.onColumnClick,
...@@ -241,17 +240,26 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -241,17 +240,26 @@ class TableList extends React.Component<TableListProps, TableListState> {
// support intermediate result is dict because the last intermediate result is // support intermediate result is dict because the last intermediate result is
// final result in a succeed trial, it may be a dict. // final result in a succeed trial, it may be a dict.
// get intermediate result dict keys array // get intermediate result dict keys array
let otherkeys: string[] = ['default']; const { intermediateKey } = this.state;
const otherkeys: string[] = [ ];
if (res.data.length !== 0) { if (res.data.length !== 0) {
otherkeys = Object.keys(parseMetrics(res.data[0].data)); // just add type=number keys
const intermediateMetrics = parseMetrics(res.data[0].data);
for(const key in intermediateMetrics){
if(typeof intermediateMetrics[key] === 'number') {
otherkeys.push(key);
}
}
} }
// intermediateArr just store default val // intermediateArr just store default val
Object.keys(res.data).map(item => { Object.keys(res.data).map(item => {
const temp = parseMetrics(res.data[item].data); if (res.data[item].type === 'PERIODICAL') {
if (typeof temp === 'object') { const temp = parseMetrics(res.data[item].data);
intermediateArr.push(temp.default); if (typeof temp === 'object') {
} else { intermediateArr.push(temp[intermediateKey]);
intermediateArr.push(temp); } else {
intermediateArr.push(temp);
}
} }
}); });
const intermediate = intermediateGraphOption(intermediateArr, id); const intermediate = intermediateGraphOption(intermediateArr, id);
...@@ -276,11 +284,13 @@ class TableList extends React.Component<TableListProps, TableListState> { ...@@ -276,11 +284,13 @@ class TableList extends React.Component<TableListProps, TableListState> {
// just watch default key-val // just watch default key-val
if (isShowDefault === true) { if (isShowDefault === true) {
Object.keys(intermediateData).map(item => { Object.keys(intermediateData).map(item => {
const temp = parseMetrics(intermediateData[item].data); if (intermediateData[item].type === 'PERIODICAL') {
if (typeof temp === 'object') { const temp = parseMetrics(intermediateData[item].data);
intermediateArr.push(temp[value]); if (typeof temp === 'object') {
} else { intermediateArr.push(temp[value]);
intermediateArr.push(temp); } else {
intermediateArr.push(temp);
}
} }
}); });
} else { } else {
......
...@@ -45,6 +45,10 @@ function parseMetrics(metricData: string): any { ...@@ -45,6 +45,10 @@ function parseMetrics(metricData: string): any {
} }
} }
/**
 * Report whether a value is an array.
 *
 * Thin wrapper over `Array.isArray`, so the result is always a real boolean.
 *
 * @param list - value to inspect; any value (including `null`/`undefined`) is accepted.
 * @returns `true` when `list` is an array, `false` for everything else.
 */
const isArrayType = (list: unknown): boolean => {
    return Array.isArray(list);
};
// get final result value // get final result value
// draw Accuracy point graph // draw Accuracy point graph
const getFinalResult = (final?: MetricDataRecord[]): number => { const getFinalResult = (final?: MetricDataRecord[]): number => {
...@@ -52,12 +56,14 @@ const getFinalResult = (final?: MetricDataRecord[]): number => { ...@@ -52,12 +56,14 @@ const getFinalResult = (final?: MetricDataRecord[]): number => {
let showDefault = 0; let showDefault = 0;
if (final) { if (final) {
acc = parseMetrics(final[final.length - 1].data); acc = parseMetrics(final[final.length - 1].data);
if (typeof (acc) === 'object') { if (typeof (acc) === 'object' && !isArrayType(acc)) {
if (acc.default) { if (acc.default) {
showDefault = acc.default; showDefault = acc.default;
} }
} else { } else if (typeof (acc) === 'number') {
showDefault = acc; showDefault = acc;
} else {
showDefault = NaN;
} }
return showDefault; return showDefault;
} else { } else {
...@@ -72,8 +78,13 @@ const getFinal = (final?: MetricDataRecord[]): FinalType | undefined => { ...@@ -72,8 +78,13 @@ const getFinal = (final?: MetricDataRecord[]): FinalType | undefined => {
showDefault = parseMetrics(final[final.length - 1].data); showDefault = parseMetrics(final[final.length - 1].data);
if (typeof showDefault === 'number') { if (typeof showDefault === 'number') {
showDefault = { default: showDefault }; showDefault = { default: showDefault };
return showDefault;
} else if (isArrayType(showDefault)) {
// not support final type
return undefined;
} else if (typeof showDefault === 'object' && showDefault.hasOwnProperty('default')) {
return showDefault;
} }
return showDefault;
} else { } else {
return undefined; return undefined;
} }
...@@ -205,5 +216,6 @@ function formatAccuracy(accuracy: number): string { ...@@ -205,5 +216,6 @@ function formatAccuracy(accuracy: number): string {
export { export {
convertTime, convertDuration, getFinalResult, getFinal, downFile, convertTime, convertDuration, getFinalResult, getFinal, downFile,
intermediateGraphOption, killJob, filterByStatus, filterDuration, intermediateGraphOption, killJob, filterByStatus, filterDuration,
formatAccuracy, formatTimestamp, metricAccuracy, parseMetrics formatAccuracy, formatTimestamp, metricAccuracy, parseMetrics,
isArrayType
}; };
import { MetricDataRecord, TrialJobInfo, TableObj, TableRecord, Parameters, FinalType } from '../interface'; import { MetricDataRecord, TrialJobInfo, TableObj, TableRecord, Parameters, FinalType } from '../interface';
import { getFinal, formatAccuracy, metricAccuracy, parseMetrics } from '../function'; import { getFinal, formatAccuracy, metricAccuracy, parseMetrics, isArrayType } from '../function';
class Trial implements TableObj { class Trial implements TableObj {
private metricsInitialized: boolean = false; private metricsInitialized: boolean = false;
...@@ -55,9 +55,11 @@ class Trial implements TableObj { ...@@ -55,9 +55,11 @@ class Trial implements TableObj {
} else if (this.intermediates.length > 0) { } else if (this.intermediates.length > 0) {
const temp = this.intermediates[this.intermediates.length - 1]; const temp = this.intermediates[this.intermediates.length - 1];
if (temp !== undefined) { if (temp !== undefined) {
if (typeof parseMetrics(temp.data) === 'object') { if (isArrayType(parseMetrics(temp.data))) {
return undefined;
} else if (typeof parseMetrics(temp.data) === 'object' && parseMetrics(temp.data).hasOwnProperty('default')) {
return parseMetrics(temp.data).default; return parseMetrics(temp.data).default;
} else { } else if (typeof parseMetrics(temp.data) === 'number') {
return parseMetrics(temp.data); return parseMetrics(temp.data);
} }
} else { } else {
...@@ -67,7 +69,6 @@ class Trial implements TableObj { ...@@ -67,7 +69,6 @@ class Trial implements TableObj {
return undefined; return undefined;
} }
} }
/* table obj start */ /* table obj start */
get tableRecord(): TableRecord { get tableRecord(): TableRecord {
......
...@@ -56,4 +56,12 @@ ...@@ -56,4 +56,12 @@
color: #333; color: #333;
font-size: 14px; font-size: 14px;
} }
}
/* overview-succeed-graph */
.showMess{
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
} }
\ No newline at end of file
...@@ -112,7 +112,7 @@ if __name__ == '__main__': ...@@ -112,7 +112,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=None) parser.add_argument("--config", type=str, default=None)
parser.add_argument("--exclude", type=str, default=None) parser.add_argument("--exclude", type=str, default=None)
parser.add_argument("--ts", type=str, choices=['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller'], default='local') parser.add_argument("--ts", type=str, choices=['local', 'remote', 'pai', 'paiYarn', 'kubeflow', 'frameworkcontroller'], default='local')
parser.add_argument("--local_gpu", action='store_true') parser.add_argument("--local_gpu", action='store_true')
parser.add_argument("--preinstall", action='store_true') parser.add_argument("--preinstall", action='store_true')
args = parser.parse_args() args = parser.parse_args()
......
...@@ -12,7 +12,7 @@ def update_training_service_config(args): ...@@ -12,7 +12,7 @@ def update_training_service_config(args):
config = get_yml_content(TRAINING_SERVICE_FILE) config = get_yml_content(TRAINING_SERVICE_FILE)
if args.nni_manager_ip is not None: if args.nni_manager_ip is not None:
config[args.ts]['nniManagerIp'] = args.nni_manager_ip config[args.ts]['nniManagerIp'] = args.nni_manager_ip
if args.ts == 'pai': if args.ts == 'paiYarn':
if args.pai_user is not None: if args.pai_user is not None:
config[args.ts]['paiYarnConfig']['userName'] = args.pai_user config[args.ts]['paiYarnConfig']['userName'] = args.pai_user
if args.pai_pwd is not None: if args.pai_pwd is not None:
...@@ -27,6 +27,23 @@ def update_training_service_config(args): ...@@ -27,6 +27,23 @@ def update_training_service_config(args):
config[args.ts]['trial']['outputDir'] = args.output_dir config[args.ts]['trial']['outputDir'] = args.output_dir
if args.vc is not None: if args.vc is not None:
config[args.ts]['trial']['virtualCluster'] = args.vc config[args.ts]['trial']['virtualCluster'] = args.vc
if args.ts == 'pai':
if args.pai_user is not None:
config[args.ts]['paiConfig']['userName'] = args.pai_user
if args.pai_host is not None:
config[args.ts]['paiConfig']['host'] = args.pai_host
if args.pai_token is not None:
config[args.ts]['paiConfig']['token'] = args.pai_token
if args.nni_docker_image is not None:
config[args.ts]['trial']['image'] = args.nni_docker_image
if args.nni_manager_nfs_mount_path is not None:
config[args.ts]['trial']['nniManagerNFSMountPath'] = args.nni_manager_nfs_mount_path
if args.container_nfs_mount_path is not None:
config[args.ts]['trial']['containerNFSMountPath'] = args.container_nfs_mount_path
if args.pai_storage_plugin is not None:
config[args.ts]['trial']['paiStoragePlugin'] = args.pai_storage_plugin
if args.vc is not None:
config[args.ts]['trial']['virtualCluster'] = args.vc
elif args.ts == 'kubeflow': elif args.ts == 'kubeflow':
if args.nfs_server is not None: if args.nfs_server is not None:
config[args.ts]['kubeflowConfig']['nfs']['server'] = args.nfs_server config[args.ts]['kubeflowConfig']['nfs']['server'] = args.nfs_server
...@@ -94,6 +111,10 @@ if __name__ == '__main__': ...@@ -94,6 +111,10 @@ if __name__ == '__main__':
parser.add_argument("--data_dir", type=str) parser.add_argument("--data_dir", type=str)
parser.add_argument("--output_dir", type=str) parser.add_argument("--output_dir", type=str)
parser.add_argument("--vc", type=str) parser.add_argument("--vc", type=str)
parser.add_argument("--pai_token", type=str)
parser.add_argument("--pai_storage_plugin", type=str)
parser.add_argument("--nni_manager_nfs_mount_path", type=str)
parser.add_argument("--container_nfs_mount_path", type=str)
# args for kubeflow and frameworkController # args for kubeflow and frameworkController
parser.add_argument("--nfs_server", type=str) parser.add_argument("--nfs_server", type=str)
parser.add_argument("--nfs_path", type=str) parser.add_argument("--nfs_path", type=str)
......
...@@ -51,9 +51,9 @@ jobs: ...@@ -51,9 +51,9 @@ jobs:
echo "TEST_IMG:$TEST_IMG" echo "TEST_IMG:$TEST_IMG"
cd test cd test
python3 generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) \ python3 generate_ts_config.py --ts pai --pai_host $(pai_host) --pai_user $(pai_user) --nni_docker_image $TEST_IMG --pai_storage_plugin $(pai_storage_plugin)\
--nni_docker_image $TEST_IMG --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip) --pai_token $(pai_token) --nni_manager_nfs_mount_path $(nni_manager_nfs_mount_path) --container_nfs_mount_path $(container_nfs_mount_path) --nni_manager_ip $(nni_manager_ip)
PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts pai PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts pai --exclude multi_phase
PATH=$HOME/.local/bin:$PATH python3 metrics_test.py PATH=$HOME/.local/bin:$PATH python3 metrics_test.py
displayName: 'integration test' displayName: 'integration test'
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
jobs:
# Integration-test job for the paiYarn training service.
- job: 'integration_test_paiYarn'
  # NOTE(review): per Azure Pipelines docs, 0 requests the maximum timeout the agent allows - confirm.
  timeoutInMinutes: 0

  steps:
  # Make sure pip/setuptools are recent enough to build and install the package.
  - script: python3 -m pip install --upgrade pip setuptools --user
    displayName: 'Install python tools'

  # Build the wheel only when a fresh docker image is going to be built from it.
  - script: |
      cd deployment/pypi
      echo 'building prerelease package...'
      make build
      ls $(Build.SourcesDirectory)/deployment/pypi/dist/
    condition: eq( variables['build_docker_img'], 'true' )
    displayName: 'build nni bdist_wheel'

  # Install NNI on the build agent from the checked-out sources.
  - script: |
      source install.sh
    displayName: 'Install nni toolkit via source code'

  # Extra tuners/advisors exercised by the integration tests.
  - script: |
      sudo apt-get install swig -y
      PATH=$HOME/.local/bin:$PATH nnictl package install --name=SMAC
      PATH=$HOME/.local/bin:$PATH nnictl package install --name=BOHB
    displayName: 'Install dependencies for integration tests in PAI mode'

  # Pick (or build and push) the test image, generate the paiYarn training-service
  # config from pipeline variables, then run the config and metrics tests.
  - script: |
      set -e
      if [ $(build_docker_img) = 'true' ]
      then
        cd deployment/pypi
        docker login -u $(docker_hub_user) -p $(docker_hub_pwd)
        echo 'updating docker file for installing nni from local...'
        # update Dockerfile to install NNI in docker image from whl file built in last step
        # (-i and -e are passed separately: a fused "-ie" makes GNU sed write a backup file with suffix "e")
        sed -i -e 's/RUN python3 -m pip --no-cache-dir install nni/COPY .\/dist\/* .\nRUN python3 -m pip install nni-*.whl/' ../docker/Dockerfile
        cat ../docker/Dockerfile
        export IMG_TAG=`date -u +%y%m%d%H%M`

        echo 'build and upload docker image'
        docker build -f ../docker/Dockerfile -t $(test_docker_img_name):$IMG_TAG .
        docker push $(test_docker_img_name):$IMG_TAG

        export TEST_IMG=$(test_docker_img_name):$IMG_TAG
        cd ../../
      else
        export TEST_IMG=$(existing_docker_img)
      fi

      echo "TEST_IMG:$TEST_IMG"
      cd test
      python3 generate_ts_config.py --ts paiYarn --pai_host $(pai_host) --pai_user $(pai_user) --pai_pwd $(pai_pwd) --vc $(pai_virtual_cluster) \
      --nni_docker_image $TEST_IMG --data_dir $(data_dir) --output_dir $(output_dir) --nni_manager_ip $(nni_manager_ip)
      PATH=$HOME/.local/bin:$PATH python3 config_test.py --ts paiYarn
      PATH=$HOME/.local/bin:$PATH python3 metrics_test.py
    displayName: 'integration test'
...@@ -52,7 +52,7 @@ frameworkcontroller: ...@@ -52,7 +52,7 @@ frameworkcontroller:
local: local:
trainingServicePlatform: local trainingServicePlatform: local
pai: paiYarn:
nniManagerIp: nniManagerIp:
maxExecDuration: 15m maxExecDuration: 15m
paiYarnConfig: paiYarnConfig:
...@@ -68,6 +68,21 @@ pai: ...@@ -68,6 +68,21 @@ pai:
memoryMB: 8192 memoryMB: 8192
outputDir: outputDir:
virtualCluster: virtualCluster:
# Template for the `pai` training service.
# NOTE(review): the empty values look like placeholders that the config
# generator script fills from command-line arguments - confirm against the caller.
pai:
  nniManagerIp:
  maxExecDuration: 15m
  paiConfig:
    host:
    userName:
  trainingServicePlatform: pai
  trial:
    gpuNum: 1
    cpuNum: 1
    image:
    memoryMB: 8192
    nniManagerNFSMountPath:
    containerNFSMountPath:
    paiStoragePlugin:
remote: remote:
machineList: machineList:
- ip: - ip:
......
...@@ -32,7 +32,8 @@ common_schema = { ...@@ -32,7 +32,8 @@ common_schema = {
'trialConcurrency': setNumberRange('trialConcurrency', int, 1, 99999), 'trialConcurrency': setNumberRange('trialConcurrency', int, 1, 99999),
Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')), Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999), Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
'trainingServicePlatform': setChoice('trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn'), 'trainingServicePlatform': setChoice(
'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts'),
Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'), Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
Optional('multiPhase'): setType('multiPhase', bool), Optional('multiPhase'): setType('multiPhase', bool),
Optional('multiThread'): setType('multiThread', bool), Optional('multiThread'): setType('multiThread', bool),
...@@ -297,6 +298,27 @@ pai_config_schema = { ...@@ -297,6 +298,27 @@ pai_config_schema = {
}) })
} }
# Schema fragment for the `trial` section of a DLTS experiment config.
# None of the keys is wrapped in Optional, so all four are expected to be present.
dlts_trial_schema = {
    'trial':{
        # command and image are validated as str
        'command': setType('command', str),
        # codeDir goes through the shared path check helper
        'codeDir': setPathCheck('codeDir'),
        # gpuNum is validated as an int against the bounds 0 and 99999
        'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
        'image': setType('image', str),
    }
}
# Schema fragment for the `dltsConfig` section of a DLTS experiment config.
# Only `dashboard` is mandatory; the remaining keys are wrapped in Optional.
# Every value is validated as str.
dlts_config_schema = {
    'dltsConfig': {
        'dashboard': setType('dashboard', str),

        Optional('cluster'): setType('cluster', str),
        Optional('team'): setType('team', str),

        Optional('email'): setType('email', str),
        Optional('password'): setType('password', str),
    }
}
kubeflow_trial_schema = { kubeflow_trial_schema = {
'trial':{ 'trial':{
'codeDir': setPathCheck('codeDir'), 'codeDir': setPathCheck('codeDir'),
...@@ -438,6 +460,8 @@ PAI_CONFIG_SCHEMA = Schema({**common_schema, **pai_trial_schema, **pai_config_sc ...@@ -438,6 +460,8 @@ PAI_CONFIG_SCHEMA = Schema({**common_schema, **pai_trial_schema, **pai_config_sc
PAI_YARN_CONFIG_SCHEMA = Schema({**common_schema, **pai_yarn_trial_schema, **pai_yarn_config_schema}) PAI_YARN_CONFIG_SCHEMA = Schema({**common_schema, **pai_yarn_trial_schema, **pai_yarn_config_schema})
DLTS_CONFIG_SCHEMA = Schema({**common_schema, **dlts_trial_schema, **dlts_config_schema})
KUBEFLOW_CONFIG_SCHEMA = Schema({**common_schema, **kubeflow_trial_schema, **kubeflow_config_schema}) KUBEFLOW_CONFIG_SCHEMA = Schema({**common_schema, **kubeflow_trial_schema, **kubeflow_config_schema})
FRAMEWORKCONTROLLER_CONFIG_SCHEMA = Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema}) FRAMEWORKCONTROLLER_CONFIG_SCHEMA = Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema})
...@@ -99,7 +99,7 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False, ...@@ -99,7 +99,7 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
node_command = 'node' node_command = 'node'
if sys.platform == 'win32': if sys.platform == 'win32':
node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe') node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe')
cmds = [node_command, entry_file, '--port', str(port), '--mode', platform] cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform]
if mode == 'view': if mode == 'view':
cmds += ['--start_mode', 'resume'] cmds += ['--start_mode', 'resume']
cmds += ['--readonly', 'true'] cmds += ['--readonly', 'true']
...@@ -289,6 +289,25 @@ def set_frameworkcontroller_config(experiment_config, port, config_file_name): ...@@ -289,6 +289,25 @@ def set_frameworkcontroller_config(experiment_config, port, config_file_name):
#set trial_config #set trial_config
return set_trial_config(experiment_config, port, config_file_name), err_message return set_trial_config(experiment_config, port, config_file_name), err_message
def set_dlts_config(experiment_config, port, config_file_name):
    '''Send the DLTS cluster settings to the REST server, then the manager IP and trial config.

    Returns a ``(result, error_message)`` pair:
    * ``(False, text)`` when the PUT of ``dltsConfig`` fails; ``text`` is the
      response body, or ``None`` when there was no response at all. When a
      response exists its body is also appended, pretty-printed, to the
      experiment's stderr log.
    * whatever ``setNNIManagerIp`` returned when that step reports failure.
    * ``(set_trial_config(...), None)`` otherwise.
    '''
    payload = json.dumps({'dlts_config': experiment_config['dltsConfig']})
    response = rest_put(cluster_metadata_url(port), payload, REST_TIME_OUT)
    err_message = None

    if response and response.status_code == 200:
        ip_ok, ip_message = setNNIManagerIp(experiment_config, port, config_file_name)
        if not ip_ok:
            return ip_ok, ip_message
        # set trial_config
        return set_trial_config(experiment_config, port, config_file_name), err_message

    # Cluster metadata was rejected (or the request produced no response).
    if response is not None:
        err_message = response.text
        _, stderr_full_path = get_log_path(config_file_name)
        pretty_error = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
        with open(stderr_full_path, 'a+') as fout:
            fout.write(pretty_error)
    return False, err_message
def set_experiment(experiment_config, mode, port, config_file_name): def set_experiment(experiment_config, mode, port, config_file_name):
'''Call startExperiment (rest POST /experiment) with yaml file content''' '''Call startExperiment (rest POST /experiment) with yaml file content'''
request_data = dict() request_data = dict()
...@@ -389,6 +408,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res ...@@ -389,6 +408,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
config_result, err_msg = set_kubeflow_config(experiment_config, port, config_file_name) config_result, err_msg = set_kubeflow_config(experiment_config, port, config_file_name)
elif platform == 'frameworkcontroller': elif platform == 'frameworkcontroller':
config_result, err_msg = set_frameworkcontroller_config(experiment_config, port, config_file_name) config_result, err_msg = set_frameworkcontroller_config(experiment_config, port, config_file_name)
elif platform == 'dlts':
config_result, err_msg = set_dlts_config(experiment_config, port, config_file_name)
else: else:
raise Exception(ERROR_INFO % 'Unsupported platform!') raise Exception(ERROR_INFO % 'Unsupported platform!')
exit(1) exit(1)
...@@ -525,7 +546,15 @@ def create_experiment(args): ...@@ -525,7 +546,15 @@ def create_experiment(args):
nni_config.set_config('experimentConfig', experiment_config) nni_config.set_config('experimentConfig', experiment_config)
nni_config.set_config('restServerPort', args.port) nni_config.set_config('restServerPort', args.port)
launch_experiment(args, experiment_config, 'new', config_file_name) try:
launch_experiment(args, experiment_config, 'new', config_file_name)
except Exception as exception:
nni_config = Config(config_file_name)
restServerPid = nni_config.get_config('restServerPid')
if restServerPid:
kill_command(restServerPid)
print_error(exception)
exit(1)
def manage_stopped_experiment(args, mode): def manage_stopped_experiment(args, mode):
'''view a stopped experiment''' '''view a stopped experiment'''
...@@ -553,8 +582,16 @@ def manage_stopped_experiment(args, mode): ...@@ -553,8 +582,16 @@ def manage_stopped_experiment(args, mode):
new_config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8)) new_config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
new_nni_config = Config(new_config_file_name) new_nni_config = Config(new_config_file_name)
new_nni_config.set_config('experimentConfig', experiment_config) new_nni_config.set_config('experimentConfig', experiment_config)
launch_experiment(args, experiment_config, mode, new_config_file_name, experiment_id)
new_nni_config.set_config('restServerPort', args.port) new_nni_config.set_config('restServerPort', args.port)
try:
launch_experiment(args, experiment_config, mode, new_config_file_name, experiment_id)
except Exception as exception:
nni_config = Config(new_config_file_name)
restServerPid = nni_config.get_config('restServerPid')
if restServerPid:
kill_command(restServerPid)
print_error(exception)
exit(1)
def view_experiment(args): def view_experiment(args):
'''view a stopped experiment''' '''view a stopped experiment'''
......
...@@ -5,8 +5,9 @@ import os ...@@ -5,8 +5,9 @@ import os
import json import json
from schema import SchemaError from schema import SchemaError
from schema import Schema from schema import Schema
from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, PAI_YARN_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA,\ from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, PAI_YARN_CONFIG_SCHEMA, \
FRAMEWORKCONTROLLER_CONFIG_SCHEMA, tuner_schema_dict, advisor_schema_dict, assessor_schema_dict DLTS_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA, FRAMEWORKCONTROLLER_CONFIG_SCHEMA, \
tuner_schema_dict, advisor_schema_dict, assessor_schema_dict
from .common_utils import print_error, print_warning, print_normal, get_yml_content from .common_utils import print_error, print_warning, print_normal, get_yml_content
def expand_path(experiment_config, key): def expand_path(experiment_config, key):
...@@ -147,7 +148,9 @@ def validate_kubeflow_operators(experiment_config): ...@@ -147,7 +148,9 @@ def validate_kubeflow_operators(experiment_config):
def validate_common_content(experiment_config): def validate_common_content(experiment_config):
'''Validate whether the common values in experiment_config is valid''' '''Validate whether the common values in experiment_config is valid'''
if not experiment_config.get('trainingServicePlatform') or \ if not experiment_config.get('trainingServicePlatform') or \
experiment_config.get('trainingServicePlatform') not in ['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn']: experiment_config.get('trainingServicePlatform') not in [
'local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts'
]:
print_error('Please set correct trainingServicePlatform!') print_error('Please set correct trainingServicePlatform!')
exit(1) exit(1)
schema_dict = { schema_dict = {
...@@ -156,7 +159,8 @@ def validate_common_content(experiment_config): ...@@ -156,7 +159,8 @@ def validate_common_content(experiment_config):
'pai': PAI_CONFIG_SCHEMA, 'pai': PAI_CONFIG_SCHEMA,
'paiYarn': PAI_YARN_CONFIG_SCHEMA, 'paiYarn': PAI_YARN_CONFIG_SCHEMA,
'kubeflow': KUBEFLOW_CONFIG_SCHEMA, 'kubeflow': KUBEFLOW_CONFIG_SCHEMA,
'frameworkcontroller': FRAMEWORKCONTROLLER_CONFIG_SCHEMA 'frameworkcontroller': FRAMEWORKCONTROLLER_CONFIG_SCHEMA,
'dlts': DLTS_CONFIG_SCHEMA,
} }
separate_schema_dict = { separate_schema_dict = {
'tuner': tuner_schema_dict, 'tuner': tuner_schema_dict,
......
...@@ -11,7 +11,7 @@ from .updater import update_searchspace, update_concurrency, update_duration, up ...@@ -11,7 +11,7 @@ from .updater import update_searchspace, update_concurrency, update_duration, up
from .nnictl_utils import stop_experiment, trial_ls, trial_kill, list_experiment, experiment_status,\ from .nnictl_utils import stop_experiment, trial_ls, trial_kill, list_experiment, experiment_status,\
log_trial, experiment_clean, platform_clean, experiment_list, \ log_trial, experiment_clean, platform_clean, experiment_list, \
monitor_experiment, export_trials_data, trial_codegen, webui_url, \ monitor_experiment, export_trials_data, trial_codegen, webui_url, \
get_config, log_stdout, log_stderr, search_space_auto_gen get_config, log_stdout, log_stderr, search_space_auto_gen, webui_nas
from .package_management import package_install, package_show from .package_management import package_install, package_show
from .constants import DEFAULT_REST_PORT from .constants import DEFAULT_REST_PORT
from .tensorboard_utils import start_tensorboard, stop_tensorboard from .tensorboard_utils import start_tensorboard, stop_tensorboard
...@@ -158,6 +158,10 @@ def parse_args(): ...@@ -158,6 +158,10 @@ def parse_args():
parser_webui_url = parser_webui_subparsers.add_parser('url', help='show the url of web ui') parser_webui_url = parser_webui_subparsers.add_parser('url', help='show the url of web ui')
parser_webui_url.add_argument('id', nargs='?', help='the id of experiment') parser_webui_url.add_argument('id', nargs='?', help='the id of experiment')
parser_webui_url.set_defaults(func=webui_url) parser_webui_url.set_defaults(func=webui_url)
parser_webui_nas = parser_webui_subparsers.add_parser('nas', help='show nas ui')
parser_webui_nas.add_argument('--port', default=6060, type=int, help='port of nas ui')
parser_webui_nas.add_argument('--logdir', default='.', type=str, help='the logdir where nas ui will read data')
parser_webui_nas.set_defaults(func=webui_nas)
#parse config command #parse config command
parser_config = subparsers.add_parser('config', help='get config information') parser_config = subparsers.add_parser('config', help='get config information')
......
...@@ -8,6 +8,7 @@ import json ...@@ -8,6 +8,7 @@ import json
import time import time
import re import re
import shutil import shutil
import subprocess
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from subprocess import Popen from subprocess import Popen
...@@ -388,6 +389,17 @@ def webui_url(args): ...@@ -388,6 +389,17 @@ def webui_url(args):
nni_config = Config(get_config_filename(args)) nni_config = Config(get_config_filename(args))
print_normal('{0} {1}'.format('Web UI url:', ' '.join(nni_config.get_config('webuiUrl')))) print_normal('{0} {1}'.format('Web UI url:', ' '.join(nni_config.get_config('webuiUrl'))))
def webui_nas(args):
    '''Start the NAS UI node server.

    Reads ``args.port`` and ``args.logdir`` and passes them to
    ``src/nasui/server.js`` (a relative path). A KeyboardInterrupt raised
    while the command is being launched or is running is swallowed.
    '''
    print_normal('Starting NAS UI...')
    # TODO: find file path on installing with pypi
    # TODO: use correct node on win32
    try:
        server_options = ['--port', str(args.port), '--logdir', args.logdir]
        subprocess.run(['node', 'src/nasui/server.js'] + server_options)
    except KeyboardInterrupt:
        return
def local_clean(directory): def local_clean(directory):
'''clean up local data''' '''clean up local data'''
print_normal('removing folder {0}'.format(directory)) print_normal('removing folder {0}'.format(directory))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment