Unverified commit 3fe117f0 authored by SparkSnail, committed by GitHub

Merge pull request #231 from microsoft/master

merge master
parents 129c4a53 74250987
@@ -48,12 +48,12 @@ Note: You should set `trainingServicePlatform: pai` in NNI config YAML file if y
 Compared with [LocalMode](LocalMode.md) and [RemoteMachineMode](RemoteMachineMode.md), trial configuration in pai mode have these additional keys:
 * cpuNum
-    * Required key. Should be positive number based on your trial program's CPU requirement
+    * Optional key. Should be positive number based on your trial program's CPU requirement. If it is not set in trial configuration, it should be set in the config file specified in `paiConfigPath` field.
 * memoryMB
-    * Required key. Should be positive number based on your trial program's memory requirement
+    * Optional key. Should be positive number based on your trial program's memory requirement. If it is not set in trial configuration, it should be set in the config file specified in `paiConfigPath` field.
 * image
-    * Required key. In pai mode, your trial program will be scheduled by OpenPAI to run in [Docker container](https://www.docker.com/). This key is used to specify the Docker image used to create the container in which your trial will run.
+    * Optional key. In pai mode, your trial program will be scheduled by OpenPAI to run in [Docker container](https://www.docker.com/). This key is used to specify the Docker image used to create the container in which your trial will run.
-    * We already build a docker image [nnimsra/nni](https://hub.docker.com/r/msranni/nni/) on [Docker Hub](https://hub.docker.com/). It contains NNI python packages, Node modules and javascript artifact files required to start experiment, and all of NNI dependencies. The docker file used to build this image can be found at [here](https://github.com/Microsoft/nni/tree/master/deployment/docker/Dockerfile). You can either use this image directly in your config file, or build your own image based on it.
+    * We already build a docker image [nnimsra/nni](https://hub.docker.com/r/msranni/nni/) on [Docker Hub](https://hub.docker.com/). It contains NNI python packages, Node modules and javascript artifact files required to start experiment, and all of NNI dependencies. The docker file used to build this image can be found at [here](https://github.com/Microsoft/nni/tree/master/deployment/docker/Dockerfile). You can either use this image directly in your config file, or build your own image based on it. If it is not set in trial configuration, it should be set in the config file specified in `paiConfigPath` field.
 * virtualCluster
     * Optional key. Set the virtualCluster of OpenPAI. If omitted, the job will run on default virtual cluster.
 * nniManagerNFSMountPath
@@ -61,7 +61,9 @@ Compared with [LocalMode](LocalMode.md) and [RemoteMachineMode](RemoteMachineMod
 * containerNFSMountPath
     * Required key. Set the mount path in your container used in PAI.
 * paiStoragePlugin
-    * Required key. Set the storage plugin name used in PAI.
+    * Optional key. Set the storage plugin name used in PAI. If it is not set in trial configuration, it should be set in the config file specified in `paiConfigPath` field.
+* paiConfigPath
+    * Optional key. Set the file path of pai job configuration, the file is in yaml format.
 Once complete to fill NNI experiment config file and save (for example, save as exp_pai.yml), then run the following command
......
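For reference, a minimal sketch of what the `trial` section of an experiment config might look like after this change; the paths, image, and values below are placeholders, and the fields made optional above can be omitted whenever the file referenced by `paiConfigPath` supplies them:

```yaml
# Hypothetical trial section of an NNI experiment config using the new optional fields.
trial:
  command: python3 mnist.py
  codeDir: ~/nni/examples/trials/mnist
  nniManagerNFSMountPath: /home/user/mnt     # required: NFS mount path on the NNI manager machine
  containerNFSMountPath: /mnt/data           # required: mount path inside the PAI container
  paiConfigPath: ~/nni/pai_config.yaml       # optional: PAI job config supplying the fields omitted here
  # gpuNum / cpuNum / memoryMB / image / paiStoragePlugin may be left out when the
  # file referenced in paiConfigPath provides them.
```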
@@ -49,7 +49,7 @@ nnictl support commands:
 |--config, -c| True| |YAML configure file of the experiment|
 |--port, -p|False| |the port of restful server|
 |--debug, -d|False||set debug mode|
-|--watch, -w|False||set watch mode|
+|--foreground, -f|False||set foreground mode, print log content to terminal|
 * Examples
@@ -98,7 +98,7 @@ Debug mode will disable version check function in Trialkeeper.
 |id| True| |The id of the experiment you want to resume|
 |--port, -p| False| |Rest port of the experiment you want to resume|
 |--debug, -d|False||set debug mode|
-|--watch, -w|False||set watch mode|
+|--foreground, -f|False||set foreground mode, print log content to terminal|
 * Example
......
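As a usage sketch (the config file name is just the example from the PAI docs above), the new flag streams the NNI manager log to the terminal instead of writing it to a log file:

```bash
# Start an experiment in foreground mode; Ctrl-C stops the experiment.
nnictl create --config exp_pai.yml --foreground
```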
@@ -4,13 +4,11 @@
 'use strict';
 import * as fs from 'fs';
-import * as path from 'path';
 import { Writable } from 'stream';
 import { WritableStreamBuffer } from 'stream-buffers';
 import { format } from 'util';
 import * as component from '../common/component';
 import { getExperimentStartupInfo, isReadonly } from './experimentStartupInfo';
-import { getLogDir } from './utils';
 const FATAL: number = 1;
 const ERROR: number = 2;
@@ -55,23 +53,21 @@ class BufferSerialEmitter {
 @component.Singleton
 class Logger {
-    private DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log');
     private level: number = INFO;
-    private bufferSerialEmitter: BufferSerialEmitter;
-    private writable: Writable;
+    private bufferSerialEmitter?: BufferSerialEmitter;
+    private writable?: Writable;
     private readonly: boolean = false;
     constructor(fileName?: string) {
-        let logFile: string | undefined = fileName;
-        if (logFile === undefined) {
-            logFile = this.DEFAULT_LOGFILE;
-        }
-        this.writable = fs.createWriteStream(logFile, {
-            flags: 'a+',
-            encoding: 'utf8',
-            autoClose: true
-        });
-        this.bufferSerialEmitter = new BufferSerialEmitter(this.writable);
+        const logFile: string | undefined = fileName;
+        if (logFile) {
+            this.writable = fs.createWriteStream(logFile, {
+                flags: 'a+',
+                encoding: 'utf8',
+                autoClose: true
+            });
+            this.bufferSerialEmitter = new BufferSerialEmitter(this.writable);
+        }
         const logLevelName: string = getExperimentStartupInfo()
             .getLogLevel();
@@ -84,7 +80,9 @@ class Logger {
     }
     public close(): void {
-        this.writable.destroy();
+        if (this.writable) {
+            this.writable.destroy();
+        }
     }
     public trace(...param: any[]): void {
@@ -128,12 +126,15 @@ class Logger {
     */
    private log(level: string, param: any[]): void {
        if (!this.readonly) {
-           const buffer: WritableStreamBuffer = new WritableStreamBuffer();
-           buffer.write(`[${(new Date()).toLocaleString()}] ${level} `);
-           buffer.write(format(param));
-           buffer.write('\n');
-           buffer.end();
-           this.bufferSerialEmitter.feed(buffer.getContents());
+           const logContent = `[${(new Date()).toLocaleString()}] ${level} ${format(param)}\n`;
+           if (this.writable && this.bufferSerialEmitter) {
+               const buffer: WritableStreamBuffer = new WritableStreamBuffer();
+               buffer.write(logContent);
+               buffer.end();
+               this.bufferSerialEmitter.feed(buffer.getContents());
+           } else {
+               console.log(logContent);
+           }
        }
    }
 }
......
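A minimal sketch of the two resulting code paths, assuming the Logger is constructed directly rather than through the IoC container as main.ts does:

```typescript
// Hypothetical usage; the file path is a placeholder.
const fileLogger = new Logger('/tmp/nnimanager.log'); // writable + bufferSerialEmitter are set, log() buffers to the file
const consoleLogger = new Logger();                   // no file name: log() falls back to console.log (foreground mode)
consoleLogger.info('hello');                          // printed directly to the terminal
fileLogger.close();                                   // destroys the write stream only if one was created
```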
@@ -6,6 +6,7 @@
 import { Container, Scope } from 'typescript-ioc';
 import * as fs from 'fs';
+import * as path from 'path';
 import * as component from './common/component';
 import { Database, DataStore } from './common/datastore';
 import { setExperimentStartupInfo } from './common/experimentStartupInfo';
@@ -34,7 +35,7 @@ function initStartupInfo(
     setExperimentStartupInfo(createNew, expId, basePort, logDirectory, experimentLogLevel, readonly);
 }
-async function initContainer(platformMode: string, logFileName?: string): Promise<void> {
+async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> {
     if (platformMode === 'local') {
         Container.bind(TrainingService)
             .to(LocalTrainingService)
@@ -71,6 +72,12 @@ async function initContainer(platformMode: string, logFileName?: string): Promis
     Container.bind(DataStore)
         .to(NNIDataStore)
         .scope(Scope.Singleton);
+    const DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log');
+    if (foreground) {
+        logFileName = undefined;
+    } else if (logFileName === undefined) {
+        logFileName = DEFAULT_LOGFILE;
+    }
     Container.bind(Logger).provider({
         get: (): Logger => new Logger(logFileName)
     });
@@ -81,7 +88,7 @@ async function initContainer(platformMode: string, logFileName?: string): Promis
 function usage(): void {
     console.info('usage: node main.js --port <port> --mode \
-    <local/remote/pai/kubeflow/frameworkcontroller/paiYarn> --start_mode <new/resume> --experiment_id <id>');
+    <local/remote/pai/kubeflow/frameworkcontroller/paiYarn> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>');
 }
 const strPort: string = parseArg(['--port', '-p']);
@@ -90,6 +97,14 @@ if (!strPort || strPort.length === 0) {
     process.exit(1);
 }
+const foregroundArg: string = parseArg(['--foreground', '-f']);
+if (!('true' || 'false').includes(foregroundArg.toLowerCase())) {
+    console.log(`FATAL: foreground property should only be true or false`);
+    usage();
+    process.exit(1);
+}
+const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : false;
 const port: number = parseInt(strPort, 10);
 const mode: string = parseArg(['--mode', '-m']);
@@ -138,7 +153,7 @@ initStartupInfo(startMode, experimentId, port, logDir, logLevel, readonly);
 mkDirP(getLogDir())
     .then(async () => {
         try {
-            await initContainer(mode);
+            await initContainer(foreground, mode);
             const restServer: NNIRestServer = component.get(NNIRestServer);
             await restServer.start();
             const log: Logger = getLogger();
@@ -162,6 +177,15 @@ function getStopSignal(): any {
     }
 }
+function getCtrlCSignal(): any {
+    return 'SIGINT';
+}
+process.on(getCtrlCSignal(), async () => {
+    const log: Logger = getLogger();
+    log.info(`Get SIGINT signal!`);
+});
 process.on(getStopSignal(), async () => {
     const log: Logger = getLogger();
     let hasError: boolean = false;
......
@@ -13,6 +13,7 @@
     "azure-storage": "^2.10.2",
     "chai-as-promised": "^7.1.1",
     "child-process-promise": "^2.2.1",
+    "deepmerge": "^4.2.2",
     "express": "^4.16.3",
     "express-joi-validator": "^2.0.0",
     "js-base64": "^2.4.9",
......
@@ -38,6 +38,7 @@ export namespace ValidationSchemas {
     authFile: joi.string(),
     nniManagerNFSMountPath: joi.string().min(1),
     containerNFSMountPath: joi.string().min(1),
+    paiConfigPath: joi.string(),
     paiStoragePlugin: joi.string().min(1),
     nasMode: joi.string().valid('classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
     portList: joi.array().items(joi.object({
......
@@ -31,10 +31,11 @@ export class NNIPAIK8STrialConfig extends TrialConfig {
     public readonly nniManagerNFSMountPath: string;
     public readonly containerNFSMountPath: string;
     public readonly paiStoragePlugin: string;
+    public readonly paiConfigPath?: string;
     constructor(command: string, codeDir: string, gpuNum: number, cpuNum: number, memoryMB: number,
                 image: string, nniManagerNFSMountPath: string, containerNFSMountPath: string,
-                paiStoragePlugin: string, virtualCluster?: string) {
+                paiStoragePlugin: string, virtualCluster?: string, paiConfigPath?: string) {
         super(command, codeDir, gpuNum);
         this.cpuNum = cpuNum;
         this.memoryMB = memoryMB;
@@ -43,5 +44,6 @@ export class NNIPAIK8STrialConfig extends TrialConfig {
         this.nniManagerNFSMountPath = nniManagerNFSMountPath;
         this.containerNFSMountPath = containerNFSMountPath;
         this.paiStoragePlugin = paiStoragePlugin;
+        this.paiConfigPath = paiConfigPath;
     }
 }
@@ -44,6 +44,7 @@ import { PAIClusterConfig, PAITrialJobDetail } from '../paiConfig';
 import { PAIJobRestServer } from '../paiJobRestServer';
 const yaml = require('js-yaml');
+const deepmerge = require('deepmerge');
 /**
  * Training Service implementation for OpenPAI (Open Platform for AI)
@@ -189,7 +190,19 @@ class PAIK8STrainingService extends PAITrainingService {
             }
         }
-        return yaml.safeDump(paiJobConfig);
+        if (this.paiTrialConfig.paiConfigPath) {
+            try {
+                const additionalPAIConfig = yaml.safeLoad(fs.readFileSync(this.paiTrialConfig.paiConfigPath, 'utf8'));
+                //deepmerge(x, y), if an element at the same key is present for both x and y, the value from y will appear in the result.
+                //refer: https://github.com/TehShrike/deepmerge
+                const overwriteMerge = (destinationArray: any, sourceArray: any, options: any) => sourceArray;
+                return yaml.safeDump(deepmerge(additionalPAIConfig, paiJobConfig, { arrayMerge: overwriteMerge }));
+            } catch (error) {
+                this.log.error(`Error occurs during loading and merge ${this.paiTrialConfig.paiConfigPath} : ${error}`);
+            }
+        } else {
+            return yaml.safeDump(paiJobConfig);
+        }
     }
     protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
@@ -258,7 +271,7 @@ class PAIK8STrainingService extends PAITrainingService {
         this.log.info(`nniPAItrial command is ${nniPaiTrialCommand.trim()}`);
         const paiJobConfig = this.generateJobConfigInYamlFormat(trialJobId, nniPaiTrialCommand);
+        this.log.debug(paiJobConfig);
         // Step 3. Submit PAI job via Rest call
         // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
         const submitJobRequest: request.Options = {
......
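The merge behaviour noted in the comment above can be illustrated with a small standalone sketch; `arrayMerge` is a real deepmerge option, while the two config objects are made-up placeholders:

```typescript
const deepmerge = require('deepmerge');

// Values from the second argument win on key conflicts, and with this arrayMerge
// the NNI-generated arrays replace, rather than concatenate with, the user-supplied ones.
const overwriteMerge = (destinationArray: any[], sourceArray: any[]) => sourceArray;

const userConfig = { taskRoles: { taskrole: { dockerImage: 'user-image' } }, extras: { plugins: ['a'] } };
const nniConfig = { taskRoles: { taskrole: { commands: ['python3 trial.py'] } }, extras: { plugins: ['b'] } };

const merged = deepmerge(userConfig, nniConfig, { arrayMerge: overwriteMerge });
// merged.taskRoles.taskrole keeps dockerImage and gains commands; merged.extras.plugins is ['b'].
console.log(merged);
```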
@@ -1112,6 +1112,11 @@ deepmerge@^2.1.1:
   version "2.2.1"
   resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-2.2.1.tgz#5d3ff22a01c00f645405a2fbc17d0778a1801170"
+deepmerge@^4.2.2:
+  version "4.2.2"
+  resolved "https://registry.yarnpkg.com/deepmerge/-/deepmerge-4.2.2.tgz#44d2ea3679b8f4d4ffba33f03d865fc1e7bf4955"
+  integrity sha512-FJ3UgI4gIl+PHZm53knsuSFpE+nESMr7M4v9QcgB7S63Kj/6WqMiFQJpBBYz1Pt+66bZpP3Q7Lye0Oo9MPKEdg==
 default-require-extensions@^2.0.0:
   version "2.0.0"
   resolved "https://registry.yarnpkg.com/default-require-extensions/-/default-require-extensions-2.0.0.tgz#f5f8fbb18a7d6d50b21f641f649ebb522cfe24f7"
......
@@ -113,7 +113,7 @@ class AGP_Pruner(Pruner):
         if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
             return mask
         # if we want to generate new mask, we should update weigth first
-        w_abs = weight.abs() * mask
+        w_abs = weight.abs() * mask['weight']
         threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
         new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
         self.mask_dict.update({op_name: new_mask})
......
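The fixed line multiplies the weight magnitudes by the existing binary mask tensor (stored under the `'weight'` key) before thresholding. A self-contained sketch of that computation on a toy tensor, outside the pruner class:

```python
import torch

# Illustrative magnitude-threshold masking; the real logic lives in AGP_Pruner
# and operates on the dict-valued mask shown in the diff above.
weight = torch.randn(4, 4)
mask = {'weight': torch.ones_like(weight)}   # previous mask, keyed by 'weight'
k = 4                                        # number of smallest-magnitude entries to prune

w_abs = weight.abs() * mask['weight']        # already-pruned entries stay at zero
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
print(new_mask['weight'].sum())              # the k smallest-magnitude weights are masked out
```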
@@ -271,16 +271,17 @@ pai_yarn_config_schema = {
 pai_trial_schema = {
     'trial':{
-        'command': setType('command', str),
         'codeDir': setPathCheck('codeDir'),
-        'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
-        'cpuNum': setNumberRange('cpuNum', int, 0, 99999),
-        'memoryMB': setType('memoryMB', int),
-        'image': setType('image', str),
-        Optional('virtualCluster'): setType('virtualCluster', str),
         'nniManagerNFSMountPath': setPathCheck('nniManagerNFSMountPath'),
         'containerNFSMountPath': setType('containerNFSMountPath', str),
-        'paiStoragePlugin': setType('paiStoragePlugin', str)
+        'command': setType('command', str),
+        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
+        Optional('cpuNum'): setNumberRange('cpuNum', int, 0, 99999),
+        Optional('memoryMB'): setType('memoryMB', int),
+        Optional('image'): setType('image', str),
+        Optional('virtualCluster'): setType('virtualCluster', str),
+        Optional('paiStoragePlugin'): setType('paiStoragePlugin', str),
+        Optional('paiConfigPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'paiConfigPath')
     }
 }
......
@@ -9,7 +9,7 @@ import random
 import site
 import time
 import tempfile
-from subprocess import Popen, check_call, CalledProcessError
+from subprocess import Popen, check_call, CalledProcessError, PIPE, STDOUT
 from nni_annotation import expand_annotations, generate_search_space
 from nni.constants import ModuleName, AdvisorModuleName
 from .launcher_utils import validate_all_content
@@ -20,7 +20,7 @@ from .common_utils import get_yml_content, get_json_content, print_error, print_
     detect_port, get_user, get_python_dir
 from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER, PACKAGE_REQUIREMENTS
 from .command_utils import check_output_command, kill_command
-from .nnictl_utils import update_experiment, set_monitor
+from .nnictl_utils import update_experiment
 def get_log_path(config_file_name):
     '''generate stdout and stderr log path'''
@@ -78,17 +78,17 @@ def get_nni_installation_path():
     print_error('Fail to find nni under python library')
     exit(1)
-def start_rest_server(port, platform, mode, config_file_name, experiment_id=None, log_dir=None, log_level=None):
+def start_rest_server(args, platform, mode, config_file_name, experiment_id=None, log_dir=None, log_level=None):
     '''Run nni manager process'''
-    if detect_port(port):
+    if detect_port(args.port):
         print_error('Port %s is used by another process, please reset the port!\n' \
-                    'You could use \'nnictl create --help\' to get help information' % port)
+                    'You could use \'nnictl create --help\' to get help information' % args.port)
         exit(1)
-    if (platform != 'local') and detect_port(int(port) + 1):
+    if (platform != 'local') and detect_port(int(args.port) + 1):
         print_error('PAI mode need an additional adjacent port %d, and the port %d is used by another process!\n' \
                     'You could set another port to start experiment!\n' \
-                    'You could use \'nnictl create --help\' to get help information' % ((int(port) + 1), (int(port) + 1)))
+                    'You could use \'nnictl create --help\' to get help information' % ((int(args.port) + 1), (int(args.port) + 1)))
         exit(1)
     print_normal('Starting restful server...')
@@ -99,7 +99,7 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
     node_command = 'node'
     if sys.platform == 'win32':
         node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe')
-    cmds = [node_command, entry_file, '--port', str(port), '--mode', platform]
+    cmds = [node_command, entry_file, '--port', str(args.port), '--mode', platform]
     if mode == 'view':
         cmds += ['--start_mode', 'resume']
         cmds += ['--readonly', 'true']
@@ -111,6 +111,8 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
         cmds += ['--log_level', log_level]
     if mode in ['resume', 'view']:
         cmds += ['--experiment_id', experiment_id]
+    if args.foreground:
+        cmds += ['--foreground', 'true']
     stdout_full_path, stderr_full_path = get_log_path(config_file_name)
     with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
         time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
@@ -120,9 +122,15 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
         stderr_file.write(log_header)
         if sys.platform == 'win32':
             from subprocess import CREATE_NEW_PROCESS_GROUP
-            process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file, creationflags=CREATE_NEW_PROCESS_GROUP)
+            if args.foreground:
+                process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=STDOUT, creationflags=CREATE_NEW_PROCESS_GROUP)
+            else:
+                process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file, creationflags=CREATE_NEW_PROCESS_GROUP)
         else:
-            process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file)
+            if args.foreground:
+                process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=PIPE)
+            else:
+                process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file)
     return process, str(time_now)
 def set_trial_config(experiment_config, port, config_file_name):
@@ -424,7 +432,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
     if log_level not in ['trace', 'debug'] and (args.debug or experiment_config.get('debug') is True):
         log_level = 'debug'
     # start rest server
-    rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
+    rest_process, start_time = start_rest_server(args, experiment_config['trainingServicePlatform'], \
         mode, config_file_name, experiment_id, log_dir, log_level)
     nni_config.set_config('restServerPid', rest_process.pid)
     # Deal with annotation
@@ -493,8 +501,14 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
                                             experiment_config['experimentName'])
     print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
-    if args.watch:
-        set_monitor(True, 3, args.port, rest_process.pid)
+    if args.foreground:
+        try:
+            while True:
+                log_content = rest_process.stdout.readline().strip().decode('utf-8')
+                print(log_content)
+        except KeyboardInterrupt:
+            kill_command(rest_process.pid)
+            print_normal('Stopping experiment...')
 def create_experiment(args):
     '''start a new experiment'''
......
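A minimal, self-contained sketch of the foreground streaming pattern used above; the child command is a placeholder, while the real launcher pipes the node-based NNI manager and uses `kill_command`:

```python
import sys
from subprocess import Popen, PIPE, STDOUT

# Placeholder child process standing in for the NNI manager; the real launcher builds a
# node command for main.js and only switches stdout/stderr to PIPE when --foreground is set.
cmds = [sys.executable, '-c', "print('hello from the child process')"]
process = Popen(cmds, stdout=PIPE, stderr=STDOUT)
try:
    while True:
        line = process.stdout.readline()
        if not line:                       # EOF: the child has exited and the pipe is drained
            break
        print(line.strip().decode('utf-8'))
except KeyboardInterrupt:
    process.kill()                         # the real code calls kill_command(process.pid)
    print('Stopping experiment...')
```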
@@ -7,7 +7,7 @@ from schema import SchemaError
 from schema import Schema
 from .config_schema import LOCAL_CONFIG_SCHEMA, REMOTE_CONFIG_SCHEMA, PAI_CONFIG_SCHEMA, PAI_YARN_CONFIG_SCHEMA, KUBEFLOW_CONFIG_SCHEMA,\
     FRAMEWORKCONTROLLER_CONFIG_SCHEMA, tuner_schema_dict, advisor_schema_dict, assessor_schema_dict
-from .common_utils import print_error, print_warning, print_normal
+from .common_utils import print_error, print_warning, print_normal, get_yml_content
 def expand_path(experiment_config, key):
     '''Change '~' to user home directory'''
@@ -63,6 +63,8 @@ def parse_path(experiment_config, config_path):
     if experiment_config.get('machineList'):
         for index in range(len(experiment_config['machineList'])):
             expand_path(experiment_config['machineList'][index], 'sshKeyPath')
+    if experiment_config['trial'].get('paiConfigPath'):
+        expand_path(experiment_config['trial'], 'paiConfigPath')
     #if users use relative path, convert it to absolute path
     root_path = os.path.dirname(config_path)
@@ -94,6 +96,8 @@ def parse_path(experiment_config, config_path):
     if experiment_config.get('machineList'):
         for index in range(len(experiment_config['machineList'])):
             parse_relative_path(root_path, experiment_config['machineList'][index], 'sshKeyPath')
+    if experiment_config['trial'].get('paiConfigPath'):
+        parse_relative_path(root_path, experiment_config['trial'], 'paiConfigPath')
 def validate_search_space_content(experiment_config):
     '''Validate searchspace content,
@@ -254,6 +258,45 @@ def validate_machine_list(experiment_config):
         print_error('Please set machineList!')
         exit(1)
+def validate_pai_config_path(experiment_config):
+    '''validate paiConfigPath field'''
+    if experiment_config.get('trainingServicePlatform') == 'pai':
+        if experiment_config.get('trial', {}).get('paiConfigPath'):
+            # validate the file format of paiConfigPath, ensure it is yaml format
+            pai_config = get_yml_content(experiment_config['trial']['paiConfigPath'])
+            if experiment_config['trial'].get('image') is None:
+                if pai_config.get('prerequisites', [{}])[0].get('uri') is None:
+                    print_error('Please set image field, or set image uri in your own paiConfig!')
+                    exit(1)
+                experiment_config['trial']['image'] = pai_config['prerequisites'][0]['uri']
+            if experiment_config['trial'].get('gpuNum') is None:
+                if pai_config.get('taskRoles', {}).get('taskrole', {}).get('resourcePerInstance', {}).get('gpu') is None:
+                    print_error('Please set gpuNum field, or set resourcePerInstance gpu in your own paiConfig!')
+                    exit(1)
+                experiment_config['trial']['gpuNum'] = pai_config['taskRoles']['taskrole']['resourcePerInstance']['gpu']
+            if experiment_config['trial'].get('cpuNum') is None:
+                if pai_config.get('taskRoles', {}).get('taskrole', {}).get('resourcePerInstance', {}).get('cpu') is None:
+                    print_error('Please set cpuNum field, or set resourcePerInstance cpu in your own paiConfig!')
+                    exit(1)
+                experiment_config['trial']['cpuNum'] = pai_config['taskRoles']['taskrole']['resourcePerInstance']['cpu']
+            if experiment_config['trial'].get('memoryMB') is None:
+                if pai_config.get('taskRoles', {}).get('taskrole', {}).get('resourcePerInstance', {}).get('memoryMB', {}) is None:
+                    print_error('Please set memoryMB field, or set resourcePerInstance memoryMB in your own paiConfig!')
+                    exit(1)
+                experiment_config['trial']['memoryMB'] = pai_config['taskRoles']['taskrole']['resourcePerInstance']['memoryMB']
+            if experiment_config['trial'].get('paiStoragePlugin') is None:
+                if pai_config.get('extras', {}).get('com.microsoft.pai.runtimeplugin', [{}])[0].get('plugin') is None:
+                    print_error('Please set paiStoragePlugin field, or set plugin in your own paiConfig!')
+                    exit(1)
+                experiment_config['trial']['paiStoragePlugin'] = pai_config['extras']['com.microsoft.pai.runtimeplugin'][0]['plugin']
+        else:
+            pai_trial_fields_required_list = ['image', 'gpuNum', 'cpuNum', 'memoryMB', 'paiStoragePlugin']
+            for trial_field in pai_trial_fields_required_list:
+                if experiment_config['trial'].get(trial_field) is None:
+                    print_error('Please set {0} in trial configuration,\
+                        or set additional pai configuration file path in paiConfigPath!'.format(trial_field))
+                    exit(1)
 def validate_pai_trial_conifg(experiment_config):
     '''validate the trial config in pai platform'''
     if experiment_config.get('trainingServicePlatform') in ['pai', 'paiYarn']:
@@ -269,6 +312,7 @@ def validate_pai_trial_conifg(experiment_config):
             print_warning(warning_information.format('dataDir'))
         if experiment_config.get('trial').get('outputDir'):
             print_warning(warning_information.format('outputDir'))
+    validate_pai_config_path(experiment_config)
 def validate_all_content(experiment_config, config_path):
     '''Validate whether experiment_config is valid'''
......
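Based on the keys this validator reads, a user-supplied `paiConfigPath` file would look roughly like the sketch below; the names and values are placeholders, and only the keys the validator falls back to are shown:

```yaml
# Hypothetical PAI job config referenced by paiConfigPath. The validator above falls back
# to these keys when image, gpuNum, cpuNum, memoryMB or paiStoragePlugin are missing
# from the NNI trial configuration.
prerequisites:
  - type: dockerimage
    uri: msranni/nni:latest          # used as trial image if `image` is unset
taskRoles:
  taskrole:
    resourcePerInstance:
      gpu: 1                         # fallback for gpuNum
      cpu: 4                         # fallback for cpuNum
      memoryMB: 8192                 # fallback for memoryMB
extras:
  com.microsoft.pai.runtimeplugin:
    - plugin: teamwise_storage       # fallback for paiStoragePlugin
```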
@@ -51,7 +51,7 @@ def parse_args():
     parser_start.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file')
     parser_start.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server')
     parser_start.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
-    parser_start.add_argument('--watch', '-w', action='store_true', help=' set watch mode')
+    parser_start.add_argument('--foreground', '-f', action='store_true', help=' set foreground mode, print log content to terminal')
     parser_start.set_defaults(func=create_experiment)
     # parse resume command
@@ -59,7 +59,7 @@ def parse_args():
     parser_resume.add_argument('id', nargs='?', help='The id of the experiment you want to resume')
     parser_resume.add_argument('--port', '-p', default=DEFAULT_REST_PORT, dest='port', help='the port of restful server')
     parser_resume.add_argument('--debug', '-d', action='store_true', help=' set debug mode')
-    parser_resume.add_argument('--watch', '-w', action='store_true', help=' set watch mode')
+    parser_resume.add_argument('--foreground', '-f', action='store_true', help=' set foreground mode, print log content to terminal')
     parser_resume.set_defaults(func=resume_experiment)
     # parse view command
......