Unverified Commit 3fe117f0 authored by SparkSnail, committed by GitHub

Merge pull request #231 from microsoft/master

merge master
parents 129c4a53 74250987
......@@ -48,12 +48,12 @@ Note: You should set `trainingServicePlatform: pai` in NNI config YAML file if y
Compared with [LocalMode](LocalMode.md) and [RemoteMachineMode](RemoteMachineMode.md), the trial configuration in pai mode has these additional keys:
* cpuNum
* Required key. Should be a positive number based on your trial program's CPU requirement.
* Optional key. Should be a positive number based on your trial program's CPU requirement. If it is not set in the trial configuration, it must be set in the config file specified by the `paiConfigPath` field.
* memoryMB
* Required key. Should be a positive number based on your trial program's memory requirement.
* Optional key. Should be a positive number based on your trial program's memory requirement. If it is not set in the trial configuration, it must be set in the config file specified by the `paiConfigPath` field.
* image
* Required key. In pai mode, your trial program will be scheduled by OpenPAI to run in a [Docker container](https://www.docker.com/). This key is used to specify the Docker image used to create the container in which your trial will run.
* We have already built a Docker image [msranni/nni](https://hub.docker.com/r/msranni/nni/) on [Docker Hub](https://hub.docker.com/). It contains the NNI python packages, Node modules and JavaScript artifact files required to start an experiment, and all of NNI's dependencies. The Dockerfile used to build this image can be found [here](https://github.com/Microsoft/nni/tree/master/deployment/docker/Dockerfile). You can either use this image directly in your config file, or build your own image based on it.
* Optional key. In pai mode, your trial program will be scheduled by OpenPAI to run in a [Docker container](https://www.docker.com/). This key is used to specify the Docker image used to create the container in which your trial will run.
* We have already built a Docker image [msranni/nni](https://hub.docker.com/r/msranni/nni/) on [Docker Hub](https://hub.docker.com/). It contains the NNI python packages, Node modules and JavaScript artifact files required to start an experiment, and all of NNI's dependencies. The Dockerfile used to build this image can be found [here](https://github.com/Microsoft/nni/tree/master/deployment/docker/Dockerfile). You can either use this image directly in your config file, or build your own image based on it. If it is not set in the trial configuration, it must be set in the config file specified by the `paiConfigPath` field.
* virtualCluster
* Optional key. Set the virtualCluster of OpenPAI. If omitted, the job will run on the default virtual cluster.
* nniManagerNFSMountPath
......@@ -61,7 +61,9 @@ Compared with [LocalMode](LocalMode.md) and [RemoteMachineMode](RemoteMachineMod
* containerNFSMountPath
* Required key. Set the mount path inside the container used on PAI.
* paiStoragePlugin
* Required key. Set the name of the storage plugin used in PAI.
* Optional key. Set the name of the storage plugin used in PAI. If it is not set in the trial configuration, it must be set in the config file specified by the `paiConfigPath` field.
* paiConfigPath
* Optional key. Set the file path of the OpenPAI job configuration; the file is in YAML format (see the example below).
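
For illustration, a trial configuration that relies on `paiConfigPath` for the optional keys could look like the sketch below. The command, paths and file names are placeholders, not values taken from this repository.

```yaml
# Minimal sketch of an NNI config using paiConfigPath (placeholder values).
trainingServicePlatform: pai
trial:
  command: python3 mnist.py
  codeDir: ~/nni/examples/trials/mnist
  nniManagerNFSMountPath: /home/user/mnt
  containerNFSMountPath: /mnt/data/user
  # cpuNum, memoryMB, image and paiStoragePlugin are omitted here, so they must be
  # provided by the OpenPAI job config file referenced below.
  paiConfigPath: ~/nni/pai_job_config.yml
```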
Once you have filled in the NNI experiment config file and saved it (for example, as exp_pai.yml), run the following command
......
......@@ -49,7 +49,7 @@ nnictl support commands:
|--config, -c| True| |YAML configure file of the experiment|
|--port, -p|False| |the port of restful server|
|--debug, -d|False||set debug mode|
|--watch, -w|False||set watch mode|
|--foreground, -f|False||set foreground mode, print log content to terminal|
* Examples
......@@ -98,7 +98,7 @@ Debug mode will disable version check function in Trialkeeper.
|id| True| |The id of the experiment you want to resume|
|--port, -p| False| |Rest port of the experiment you want to resume|
|--debug, -d|False||set debug mode|
|--watch, -w|False||set watch mode|
|--foreground, -f|False||set foreground mode, print log content to terminal|
* Example
......
......@@ -4,13 +4,11 @@
'use strict';
import * as fs from 'fs';
import * as path from 'path';
import { Writable } from 'stream';
import { WritableStreamBuffer } from 'stream-buffers';
import { format } from 'util';
import * as component from '../common/component';
import { getExperimentStartupInfo, isReadonly } from './experimentStartupInfo';
import { getLogDir } from './utils';
const FATAL: number = 1;
const ERROR: number = 2;
......@@ -55,23 +53,21 @@ class BufferSerialEmitter {
@component.Singleton
class Logger {
private DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log');
private level: number = INFO;
private bufferSerialEmitter: BufferSerialEmitter;
private writable: Writable;
private bufferSerialEmitter?: BufferSerialEmitter;
private writable?: Writable;
private readonly: boolean = false;
constructor(fileName?: string) {
let logFile: string | undefined = fileName;
if (logFile === undefined) {
logFile = this.DEFAULT_LOGFILE;
}
const logFile: string | undefined = fileName;
if (logFile) {
this.writable = fs.createWriteStream(logFile, {
flags: 'a+',
encoding: 'utf8',
autoClose: true
});
this.bufferSerialEmitter = new BufferSerialEmitter(this.writable);
}
const logLevelName: string = getExperimentStartupInfo()
.getLogLevel();
......@@ -84,8 +80,10 @@ class Logger {
}
public close(): void {
if (this.writable) {
this.writable.destroy();
}
}
public trace(...param: any[]): void {
if (this.level >= TRACE) {
......@@ -128,12 +126,15 @@ class Logger {
*/
private log(level: string, param: any[]): void {
if (!this.readonly) {
const logContent = `[${(new Date()).toLocaleString()}] ${level} ${format(param)}\n`;
if (this.writable && this.bufferSerialEmitter) {
const buffer: WritableStreamBuffer = new WritableStreamBuffer();
buffer.write(`[${(new Date()).toLocaleString()}] ${level} `);
buffer.write(format(param));
buffer.write('\n');
buffer.write(logContent);
buffer.end();
this.bufferSerialEmitter.feed(buffer.getContents());
} else {
console.log(logContent);
}
}
}
}
......
......@@ -6,6 +6,7 @@
import { Container, Scope } from 'typescript-ioc';
import * as fs from 'fs';
import * as path from 'path';
import * as component from './common/component';
import { Database, DataStore } from './common/datastore';
import { setExperimentStartupInfo } from './common/experimentStartupInfo';
......@@ -34,7 +35,7 @@ function initStartupInfo(
setExperimentStartupInfo(createNew, expId, basePort, logDirectory, experimentLogLevel, readonly);
}
async function initContainer(platformMode: string, logFileName?: string): Promise<void> {
async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> {
if (platformMode === 'local') {
Container.bind(TrainingService)
.to(LocalTrainingService)
......@@ -71,6 +72,12 @@ async function initContainer(platformMode: string, logFileName?: string): Promis
Container.bind(DataStore)
.to(NNIDataStore)
.scope(Scope.Singleton);
const DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log');
if (foreground) {
logFileName = undefined;
} else if (logFileName === undefined) {
logFileName = DEFAULT_LOGFILE;
}
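// When foreground mode clears the file name, Logger skips creating a write stream and log()
// falls back to console.log, so nnictl can read the output from the child process's stdout.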
Container.bind(Logger).provider({
get: (): Logger => new Logger(logFileName)
});
......@@ -81,7 +88,7 @@ async function initContainer(platformMode: string, logFileName?: string): Promis
function usage(): void {
console.info('usage: node main.js --port <port> --mode \
<local/remote/pai/kubeflow/frameworkcontroller/paiYarn> --start_mode <new/resume> --experiment_id <id>');
<local/remote/pai/kubeflow/frameworkcontroller/paiYarn> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>');
}
const strPort: string = parseArg(['--port', '-p']);
......@@ -90,6 +97,14 @@ if (!strPort || strPort.length === 0) {
process.exit(1);
}
const foregroundArg: string = parseArg(['--foreground', '-f']);
if (foregroundArg && !['true', 'false'].includes(foregroundArg.toLowerCase())) {
console.log(`FATAL: foreground property should only be true or false`);
usage();
process.exit(1);
}
const foreground: boolean = foregroundArg.toLowerCase() === 'true';
const port: number = parseInt(strPort, 10);
const mode: string = parseArg(['--mode', '-m']);
......@@ -138,7 +153,7 @@ initStartupInfo(startMode, experimentId, port, logDir, logLevel, readonly);
mkDirP(getLogDir())
.then(async () => {
try {
await initContainer(mode);
await initContainer(foreground, mode);
const restServer: NNIRestServer = component.get(NNIRestServer);
await restServer.start();
const log: Logger = getLogger();
......@@ -162,6 +177,15 @@ function getStopSignal(): any {
}
}
function getCtrlCSignal(): any {
return 'SIGINT';
}
process.on(getCtrlCSignal(), async () => {
const log: Logger = getLogger();
log.info(`Get SIGINT signal!`);
});
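// Note: registering a SIGINT handler overrides the default exit on Ctrl+C; presumably this lets
// nnictl's foreground mode catch the KeyboardInterrupt itself and stop the manager via kill_command.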
process.on(getStopSignal(), async () => {
const log: Logger = getLogger();
let hasError: boolean = false;
......
......@@ -13,6 +13,7 @@
"azure-storage": "^2.10.2",
"chai-as-promised": "^7.1.1",
"child-process-promise": "^2.2.1",
"deepmerge": "^4.2.2",
"express": "^4.16.3",
"express-joi-validator": "^2.0.0",
"js-base64": "^2.4.9",
......
......@@ -38,6 +38,7 @@ export namespace ValidationSchemas {
authFile: joi.string(),
nniManagerNFSMountPath: joi.string().min(1),
containerNFSMountPath: joi.string().min(1),
paiConfigPath: joi.string(),
paiStoragePlugin: joi.string().min(1),
nasMode: joi.string().valid('classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
portList: joi.array().items(joi.object({
......
......@@ -31,10 +31,11 @@ export class NNIPAIK8STrialConfig extends TrialConfig {
public readonly nniManagerNFSMountPath: string;
public readonly containerNFSMountPath: string;
public readonly paiStoragePlugin: string;
public readonly paiConfigPath?: string;
constructor(command: string, codeDir: string, gpuNum: number, cpuNum: number, memoryMB: number,
image: string, nniManagerNFSMountPath: string, containerNFSMountPath: string,
paiStoragePlugin: string, virtualCluster?: string) {
paiStoragePlugin: string, virtualCluster?: string, paiConfigPath?: string) {
super(command, codeDir, gpuNum);
this.cpuNum = cpuNum;
this.memoryMB = memoryMB;
......@@ -43,5 +44,6 @@ export class NNIPAIK8STrialConfig extends TrialConfig {
this.nniManagerNFSMountPath = nniManagerNFSMountPath;
this.containerNFSMountPath = containerNFSMountPath;
this.paiStoragePlugin = paiStoragePlugin;
this.paiConfigPath = paiConfigPath;
}
}
......@@ -44,6 +44,7 @@ import { PAIClusterConfig, PAITrialJobDetail } from '../paiConfig';
import { PAIJobRestServer } from '../paiJobRestServer';
const yaml = require('js-yaml');
const deepmerge = require('deepmerge');
/**
* Training Service implementation for OpenPAI (Open Platform for AI)
......@@ -189,8 +190,20 @@ class PAIK8STrainingService extends PAITrainingService {
}
}
if (this.paiTrialConfig.paiConfigPath) {
try {
const additionalPAIConfig = yaml.safeLoad(fs.readFileSync(this.paiTrialConfig.paiConfigPath, 'utf8'));
// deepmerge(x, y): if an element at the same key is present in both x and y, the value from y appears in the result.
// Reference: https://github.com/TehShrike/deepmerge
const overwriteMerge = (destinationArray: any, sourceArray: any, options: any) => sourceArray;
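// With this arrayMerge, values from paiJobConfig win on conflicting keys, and arrays from the
// user-provided file are replaced wholesale by the corresponding arrays generated by NNI.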
return yaml.safeDump(deepmerge(additionalPAIConfig, paiJobConfig, { arrayMerge: overwriteMerge }));
} catch (error) {
this.log.error(`Error occurred while loading and merging ${this.paiTrialConfig.paiConfigPath}: ${error}`);
}
} else {
return yaml.safeDump(paiJobConfig);
}
}
protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
const deferred: Deferred<boolean> = new Deferred<boolean>();
......@@ -258,7 +271,7 @@ class PAIK8STrainingService extends PAITrainingService {
this.log.info(`nniPAItrial command is ${nniPaiTrialCommand.trim()}`);
const paiJobConfig = this.generateJobConfigInYamlFormat(trialJobId, nniPaiTrialCommand);
this.log.debug(paiJobConfig);
// Step 3. Submit PAI job via Rest call
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const submitJobRequest: request.Options = {
......
......@@ -1112,6 +1112,11 @@ deepmerge@^2.1.1:
version "2.2.1"
resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-2.2.1.tgz#5d3ff22a01c00f645405a2fbc17d0778a1801170"
deepmerge@^4.2.2:
version "4.2.2"
resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-4.2.2.tgz#44d2ea3679b8f4d4ffba33f03d865fc1e7bf4955"
integrity sha512-FJ3UgI4gIl+PHZm53knsuSFpE+nESMr7M4v9QcgB7S63Kj/6WqMiFQJpBBYz1Pt+66bZpP3Q7Lye0Oo9MPKEdg==
default-require-extensions@^2.0.0:
version "2.0.0"
resolved "https://registry.yarnpkg.com/default-require-extensions/-/default-require-extensions-2.0.0.tgz#f5f8fbb18a7d6d50b21f641f649ebb522cfe24f7"
......
......@@ -113,7 +113,7 @@ class AGP_Pruner(Pruner):
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask
# if we want to generate a new mask, we should update the weight first
w_abs = weight.abs() * mask
w_abs = weight.abs() * mask['weight']
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
self.mask_dict.update({op_name: new_mask})
......
......@@ -271,16 +271,17 @@ pai_yarn_config_schema = {
pai_trial_schema = {
'trial':{
'command': setType('command', str),
'codeDir': setPathCheck('codeDir'),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
'memoryMB': setType('memoryMB', int),
'image': setType('image', str),
Optional('virtualCluster'): setType('virtualCluster', str),
'nniManagerNFSMountPath': setPathCheck('nniManagerNFSMountPath'),
'containerNFSMountPath': setType('containerNFSMountPath', str),
'paiStoragePlugin': setType('paiStoragePlugin', str)
'command': setType('command', str),
Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
Optional('cpuNum'): setNumberRange('cpuNum', int, 0, 99999),
Optional('memoryMB'): setType('memoryMB', int),
Optional('image'): setType('image', str),
Optional('virtualCluster'): setType('virtualCluster', str),
Optional('paiStoragePlugin'): setType('paiStoragePlugin', str),
Optional('paiConfigPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'paiConfigPath')
}
}
......
......@@ -9,7 +9,7 @@ import random
import site
import time
import tempfile
from subprocess import Popen, check_call, CalledProcessError
from subprocess import Popen, check_call, CalledProcessError, PIPE, STDOUT
from nni_annotation import expand_annotations, generate_search_space
from nni.constants import ModuleName, AdvisorModuleName
from .launcher_utils import validate_all_content
......@@ -20,7 +20,7 @@ from .common_utils import get_yml_content, get_json_content, print_error, print_
detect_port, get_user, get_python_dir
from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER, PACKAGE_REQUIREMENTS
from .command_utils import check_output_command, kill_command
from .nnictl_utils import update_experiment, set_monitor
from .nnictl_utils import update_experiment
def get_log_path(config_file_name):
'''generate stdout and stderr log path'''
......@@ -78,17 +78,17 @@ def get_nni_installation_path():
print_error('Fail to find nni under python library')
exit(1)
def start_rest_server(port, platform, mode, config_file_name, experiment_id=None, log_dir=None, log_level=None):
def start_rest_server(args, platform, mode, config_file_name, experiment_id=None, log_dir=None, log_level=None):
'''Run nni manager process'''
if detect_port(port):
if detect_port(args.port):
print_error('Port %s is used by another process, please reset the port!\n' \
'You could use \'nnictl create --help\' to get help information' % port)
'You could use \'nnictl create --help\' to get help information' % args.port)
exit(1)
if (platform != 'local') and detect_port(int(port) + 1):
if (platform != 'local') and detect_port(int(args.port) + 1):
print_error('PAI mode needs an additional adjacent port %d, and the port %d is used by another process!\n' \
'You could set another port to start experiment!\n' \
'You could use \'nnictl create --help\' to get help information' % ((int(port) + 1), (int(port) + 1)))
'You could use \'nnictl create --help\' to get help information' % ((int(args.port) + 1), (int(args.port) + 1)))
exit(1)
print_normal('Starting restful server...')
......@@ -99,7 +99,7 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
node_command = 'node'
if sys.platform == 'win32':
node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe')
cmds = [node_command, entry_file, '--port', str(port), '--mode', platform]
cmds = [node_command, entry_file, '--port', str(args.port), '--mode', platform]
if mode == 'view':
cmds += ['--start_mode', 'resume']
cmds += ['--readonly', 'true']
......@@ -111,6 +111,8 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
cmds += ['--log_level', log_level]
if mode in ['resume', 'view']:
cmds += ['--experiment_id', experiment_id]
if args.foreground:
cmds += ['--foreground', 'true']
stdout_full_path, stderr_full_path = get_log_path(config_file_name)
with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
......@@ -120,7 +122,13 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
stderr_file.write(log_header)
if sys.platform == 'win32':
from subprocess import CREATE_NEW_PROCESS_GROUP
if args.foreground:
process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=STDOUT, creationflags=CREATE_NEW_PROCESS_GROUP)
else:
process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file, creationflags=CREATE_NEW_PROCESS_GROUP)
else:
if args.foreground:
process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=PIPE)
else:
process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file)
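# In foreground mode the manager's output goes to a pipe that launch_experiment reads and echoes
# to the terminal, instead of being appended to the stdout/stderr log files.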
return process, str(time_now)
......@@ -424,7 +432,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
log_level = 'debug'
# start rest server
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
rest_process, start_time = start_rest_server(args, experiment_config['trainingServicePlatform'], \
mode, config_file_name, experiment_id, log_dir, log_level)
nni_config.set_config('restServerPid', rest_process.pid)
# Deal with annotation
......@@ -493,8 +501,14 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
experiment_config['experimentName'])
print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
if args.watch:
set_monitor(True, 3, args.port, rest_process.pid)
if args.foreground:
try:
while True:
log_content = rest_process.stdout.readline().strip().decode('utf-8')
print(log_content)
except KeyboardInterrupt:
kill_command(rest_process.pid)
print_normal('Stopping experiment...')
def create_experiment(args):
'''start a new experiment'''
......
......@@ -7,7 +7,7 @@ from schema import SchemaError
from schema import Schema
from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, PAI_YARN_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA,\
FRAMEWORKCONTROLLER_CONFIG_SCHEMA, tuner_schema_dict, advisor_schema_dict, assessor_schema_dict
from .common_utils import print_error, print_warning, print_normal
from .common_utils import print_error, print_warning, print_normal, get_yml_content
def expand_path(experiment_config, key):
'''Change '~' to user home directory'''
......@@ -63,6 +63,8 @@ def parse_path(experiment_config, config_path):
if experiment_config.get('machineList'):
for index in range(len(experiment_config['machineList'])):
expand_path(experiment_config['machineList'][index], 'sshKeyPath')
if experiment_config['trial'].get('paiConfigPath'):
expand_path(experiment_config['trial'], 'paiConfigPath')
#if users use relative path, convert it to absolute path
root_path = os.path.dirname(config_path)
......@@ -94,6 +96,8 @@ def parse_path(experiment_config, config_path):
if experiment_config.get('machineList'):
for index in range(len(experiment_config['machineList'])):
parse_relative_path(root_path, experiment_config['machineList'][index], 'sshKeyPath')
if experiment_config['trial'].get('paiConfigPath'):
parse_relative_path(root_path, experiment_config['trial'], 'paiConfigPath')
def validate_search_space_content(experiment_config):
'''Validate searchspace content,
......@@ -254,6 +258,45 @@ def validate_machine_list(experiment_config):
print_error('Please set machineList!')
exit(1)
def validate_pai_config_path(experiment_config):
'''validate paiConfigPath field'''
if experiment_config.get('trainingServicePlatform') == 'pai':
if experiment_config.get('trial', {}).get('paiConfigPath'):
# validate the file format of paiConfigPath, ensuring it is valid yaml
pai_config = get_yml_content(experiment_config['trial']['paiConfigPath'])
if experiment_config['trial'].get('image') is None:
if pai_config.get('prerequisites', [{}])[0].get('uri') is None:
print_error('Please set image field, or set image uri in your own paiConfig!')
exit(1)
experiment_config['trial']['image'] = pai_config['prerequisites'][0]['uri']
if experiment_config['trial'].get('gpuNum') is None:
if pai_config.get('taskRoles', {}).get('taskrole', {}).get('resourcePerInstance', {}).get('gpu') is None:
print_error('Please set gpuNum field, or set resourcePerInstance gpu in your own paiConfig!')
exit(1)
experiment_config['trial']['gpuNum'] = pai_config['taskRoles']['taskrole']['resourcePerInstance']['gpu']
if experiment_config['trial'].get('cpuNum') is None:
if pai_config.get('taskRoles', {}).get('taskrole', {}).get('resourcePerInstance', {}).get('cpu') is None:
print_error('Please set cpuNum field, or set resourcePerInstance cpu in your own paiConfig!')
exit(1)
experiment_config['trial']['cpuNum'] = pai_config['taskRoles']['taskrole']['resourcePerInstance']['cpu']
if experiment_config['trial'].get('memoryMB') is None:
if pai_config.get('taskRoles', {}).get('taskrole', {}).get('resourcePerInstance', {}).get('memoryMB') is None:
print_error('Please set memoryMB field, or set resourcePerInstance memoryMB in your own paiConfig!')
exit(1)
experiment_config['trial']['memoryMB'] = pai_config['taskRoles']['taskrole']['resourcePerInstance']['memoryMB']
if experiment_config['trial'].get('paiStoragePlugin') is None:
if pai_config.get('extras', {}).get('com.microsoft.pai.runtimeplugin', [{}])[0].get('plugin') is None:
print_error('Please set paiStoragePlugin field, or set plugin in your own paiConfig!')
exit(1)
experiment_config['trial']['paiStoragePlugin'] = pai_config['extras']['com.microsoft.pai.runtimeplugin'][0]['plugin']
else:
pai_trial_fields_required_list = ['image', 'gpuNum', 'cpuNum', 'memoryMB', 'paiStoragePlugin']
for trial_field in pai_trial_fields_required_list:
if experiment_config['trial'].get(trial_field) is None:
print_error('Please set {0} in the trial configuration, '\
'or set the path of an additional pai configuration file in paiConfigPath!'.format(trial_field))
exit(1)
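# For illustration only (hypothetical values): a minimal pai config file that satisfies every
# fallback lookup above would be parsed by get_yml_content into a dict shaped like this:
#
#   {
#       'prerequisites': [{'uri': 'msranni/nni'}],
#       'taskRoles': {'taskrole': {'resourcePerInstance': {'gpu': 1, 'cpu': 4, 'memoryMB': 8192}}},
#       'extras': {'com.microsoft.pai.runtimeplugin': [{'plugin': 'teamwise_storage'}]}
#   }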
def validate_pai_trial_conifg(experiment_config):
'''validate the trial config in pai platform'''
if experiment_config.get('trainingServicePlatform') in ['pai', 'paiYarn']:
......@@ -269,6 +312,7 @@ def validate_pai_trial_conifg(experiment_config):
print_warning(warning_information.format('dataDir'))
if experiment_config.get('trial').get('outputDir'):
print_warning(warning_information.format('outputDir'))
validate_pai_config_path(experiment_config)
def validate_all_content(experiment_config, config_path):
'''Validate whether experiment_config is valid'''
......
......@@ -51,7 +51,7 @@ def parse_args():
parser_start.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file')
parser_start.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server')
parser_start.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
parser_start.add_argument('--watch', '-w', action='store_true', help=' set watch mode')
parser_start.add_argument('--foreground', '-f', action='store_true', help=' set foreground mode, print log content to terminal')
parser_start.set_defaults(func=create_experiment)
# parse resume command
......@@ -59,7 +59,7 @@ def parse_args():
parser_resume.add_argument('id', nargs='?', help='The id of the experiment you want to resume')
parser_resume.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server')
parser_resume.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
parser_resume.add_argument('--watch', '-w', action='store_true', help=' set watch mode')
parser_resume.add_argument('--foreground', '-f', action='store_true', help=' set foreground mode, print log content to terminal')
parser_resume.set_defaults(func=resume_experiment)
# parse view command
......