Unverified Commit e1ae623f authored by SparkSnail's avatar SparkSnail Committed by GitHub

Merge pull request #147 from Microsoft/master

merge master
parents f796c60b 63697ec5
@@ -203,6 +203,10 @@ machineList:
 __logLevel__ sets log level for the experiment, available log levels are: `trace, debug, info, warning, error, fatal`. The default value is `info`.
+* __logCollection__
+  * Description
+  __logCollection__ sets the way to collect logs on the remote, pai, kubeflow and frameworkcontroller platforms. There are two ways to collect logs: if it is set to `http`, the trial keeper posts the log content back through HTTP requests, which may slow down log processing in the trial keeper; if it is set to `none`, the trial keeper posts only job metrics and does not send log content back. If your log content is large, consider setting this parameter to `none`.
 * __tuner__
   * Description
@@ -227,12 +231,17 @@ machineList:
 * __classArgs__
   __classArgs__ specifies the arguments of tuner algorithm.
 * __gpuNum__
   __gpuNum__ specifies the gpu number to run the tuner process. The value of this field should be a positive number.
 Note: users could only specify one way to set tuner, for example, set {tunerName, optimizationMode} or {tunerCommand, tunerCwd}, and could not set them both.
+* __includeIntermediateResults__
+  If __includeIntermediateResults__ is true, the last intermediate result of a trial that is early stopped by the assessor is sent to the tuner as its final result. The default value of __includeIntermediateResults__ is false.
 * __assessor__
   * Description
...
**How to Implement TrainingService in NNI**
===
## Overview
TrainingService is the module in NNI responsible for platform management and job scheduling. TrainingService is designed to be easy to implement: NNI defines an abstract class TrainingService as the parent class of all training services, and users who want to implement a customized TrainingService only need to inherit the parent class and complete a child class of their own.
## System architecture
![](../img/NNIDesign.jpg)
The picture shows NNI's brief system architecture. NNIManager is the core management module of the system, in charge of calling TrainingService to manage trial jobs and of the communication between different modules. Dispatcher is a message processing center responsible for message dispatch. TrainingService is the module that manages trial jobs; it communicates with the NNIManager module and has a different instance for each training platform. For the time being, NNI supports the local platform, the [remote platform](RemoteMachineMode.md), the [PAI platform](PAIMode.md), the [kubeflow platform](KubeflowMode.md) and the [FrameworkController platform](FrameworkController.md).
This document introduces the brief design of TrainingService. Users who want to add a new TrainingService instance only need to complete a child class that implements TrainingService; they do not need to understand the implementation details of NNIManager, Dispatcher or other modules.
## Folder structure of code
NNI's folder structure is shown below:
```
nni
|- deployment
|- docs
|- examples
|- src
| |- nni_manager
| | |- common
| | |- config
| | |- core
| | |- coverage
| | |- dist
| | |- rest_server
| | |- training_service
| | | |- common
| | | |- kubernetes
| | | |- local
| | | |- pai
| | | |- remote_machine
| | | |- test
| |- sdk
| |- webui
|- test
|- tools
| |- nni_annotation
| |- nni_cmd
| |- nni_gpu_tool
| |- nni_trial_tool
```
The `nni/src/` folder stores most of NNI's source code, covering NNIManager, TrainingService, the SDK, the WebUI and other modules. Users can find the abstract class of TrainingService in the `nni/src/nni_manager/common/trainingService.ts` file, and should put their own TrainingService implementation in the `nni/src/nni_manager/training_service` folder. Users who implement their own TrainingService code should also supplement it with unit tests, placed in the `nni/src/nni_manager/training_service/test` folder.
## Function annotation of TrainingService
```
abstract class TrainingService {
    public abstract listTrialJobs(): Promise<TrialJobDetail[]>;
    public abstract getTrialJob(trialJobId: string): Promise<TrialJobDetail>;
    public abstract addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void;
    public abstract removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void;
    public abstract submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail>;
    public abstract updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail>;
    public abstract get isMultiPhaseJobSupported(): boolean;
    public abstract cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise<void>;
    public abstract setClusterMetadata(key: string, value: string): Promise<void>;
    public abstract getClusterMetadata(key: string): Promise<string>;
    public abstract cleanUp(): Promise<void>;
    public abstract run(): Promise<void>;
}
```
The abstract parent class TrainingService declares the functions above; users need to inherit the parent class and implement all of these abstract functions. A minimal sketch of such a child class is shown below; the sections that follow then walk through the platform-specific functions one by one.
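The sketch below is a compileable starting point, not an existing NNI platform: the `DemoTrainingService` name, the assumed file location under `nni/src/nni_manager/training_service/demo/`, and every method body are illustrative assumptions; only the method signatures come from `trainingService.ts`.
```
import { EventEmitter } from 'events';
import {
    JobApplicationForm, TrainingService, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService';

class DemoTrainingService extends TrainingService {
    // In-memory bookkeeping for the sketch; a real platform would track more state.
    protected readonly metricsEmitter: EventEmitter = new EventEmitter();
    protected readonly trialJobsMap: Map<string, TrialJobDetail> = new Map<string, TrialJobDetail>();
    protected nextSequenceId: number = 0;
    protected stopping: boolean = false;

    public async listTrialJobs(): Promise<TrialJobDetail[]> {
        return Array.from(this.trialJobsMap.values());
    }

    public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
        const job: TrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
        if (job === undefined) {
            throw new Error(`Trial job not found: ${trialJobId}`);
        }
        return job;
    }

    public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.on('metric', listener);
    }

    public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.removeListener('metric', listener);
    }

    public get isMultiPhaseJobSupported(): boolean {
        return false;
    }

    // The platform-specific functions below are only stubbed out here;
    // the following sections discuss each of them in turn.
    public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
        throw new Error('To be implemented: submit a job to the platform');
    }

    public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
        throw new Error('To be implemented: refresh the job status');
    }

    public async cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise<void> {
        throw new Error('To be implemented: kill the job on the platform');
    }

    public async setClusterMetadata(key: string, value: string): Promise<void> {
        throw new Error('To be implemented: store platform configuration');
    }

    public async getClusterMetadata(key: string): Promise<string> {
        throw new Error('To be implemented: return platform configuration');
    }

    public async cleanUp(): Promise<void> {
        this.stopping = true;
    }

    public async run(): Promise<void> {
        while (!this.stopping) {
            // schedule queued trials, poll job statuses, emit metrics ...
            await new Promise((resolve: Function) => setTimeout(resolve, 1000));
        }
    }
}
```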
__setClusterMetadata(key: string, value: string)__
ClusterMetadata is the data related to platform details. For example, the ClusterMetadata defined for the remote machine platform is:
```
export class RemoteMachineMeta {
    public readonly ip : string;
    public readonly port : number;
    public readonly username : string;
    public readonly passwd?: string;
    public readonly sshKeyPath?: string;
    public readonly passphrase?: string;
    public gpuSummary : GPUSummary | undefined;
    /* GPU Reservation info, the key is GPU index, the value is the job id which reserves this GPU*/
    public gpuReservation : Map<number, string>;

    constructor(ip : string, port : number, username : string, passwd : string,
        sshKeyPath : string, passphrase : string) {
        this.ip = ip;
        this.port = port;
        this.username = username;
        this.passwd = passwd;
        this.sshKeyPath = sshKeyPath;
        this.passphrase = passphrase;
        this.gpuReservation = new Map<number, string>();
    }
}
```
The metadata includes the host address, the username and other configuration related to the platform. Users need to define their own metadata format and set the metadata instance in this function. This function is called before the experiment is started, to set the configuration of, for instance, remote machines.
__getClusterMetadata(key: string)__
This function returns the metadata value according to the key; it can be left empty if users don't need to use it. A combined sketch of both metadata functions is shown below.
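For illustration, the two metadata functions of the hypothetical `DemoTrainingService` above could store and return a platform-specific configuration object. The `DemoClusterMeta` shape, the `demo_cluster_config` key and the JSON layout are all assumptions:
```
// Illustrative metadata type for the demo platform (not part of NNI).
class DemoClusterMeta {
    public readonly host: string = '';
    public readonly token: string = '';
}

// Inside DemoTrainingService:
private clusterMeta?: DemoClusterMeta;

public async setClusterMetadata(key: string, value: string): Promise<void> {
    if (key === 'demo_cluster_config') {
        // the value arrives as a string; parse it into the platform's own format
        this.clusterMeta = <DemoClusterMeta>JSON.parse(value);
    }
}

public async getClusterMetadata(key: string): Promise<string> {
    // return the stored configuration, or an empty string if it was never set
    return this.clusterMeta === undefined ? '' : JSON.stringify(this.clusterMeta);
}
```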
__submitTrialJob(form: JobApplicationForm)__
submitTrialJob is a function to submit new trial jobs; users should generate a job instance of the TrialJobDetail type. TrialJobDetail is defined as follows:
```
interface TrialJobDetail {
    readonly id: string;
    readonly status: TrialJobStatus;
    readonly submitTime: number;
    readonly startTime?: number;
    readonly endTime?: number;
    readonly tags?: string[];
    readonly url?: string;
    readonly workingDirectory: string;
    readonly form: JobApplicationForm;
    readonly sequenceId: number;
    isEarlyStopped?: boolean;
}
```
Depending on the implementation, users can put the job detail into a job queue and have a loop keep fetching jobs from the queue to prepare and start them, or they can finish the whole preparing and starting process inside this function and return the job detail once the job is submitted. A queue-based sketch is shown below.
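A queue-based sketch for the hypothetical `DemoTrainingService`: the job is only registered as `WAITING` here, and the `run()` loop (sketched later) picks it up and starts it. The `jobQueue` field, the id scheme and the working directory are illustrative assumptions:
```
// Inside DemoTrainingService:
private readonly jobQueue: string[] = [];

public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
    const trialJob: TrialJobDetail = {
        id: `demo_trial_${this.nextSequenceId}`,   // a real implementation should generate unique ids
        status: 'WAITING',
        submitTime: Date.now(),
        workingDirectory: '/tmp/nni/demo',         // illustrative working directory
        form: form,
        sequenceId: this.nextSequenceId++
    };
    this.trialJobsMap.set(trialJob.id, trialJob);
    this.jobQueue.push(trialJob.id);               // run() will start it later
    return trialJob;
}
```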
__cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean)__
When this function is called, the trial started by the platform should be canceled. Different platforms have different methods to cancel a running job; this function should be implemented according to the specific platform.
__updateTrialJob(trialJobId: string, form: JobApplicationForm)__
This function is called to update the trial job's status: the status should be detected on the specific platform and updated to `RUNNING`, `SUCCEEDED`, `FAILED` etc.
__getTrialJob(trialJobId: string)__
This function returns the TrialJobDetail instance corresponding to trialJobId.
__listTrialJobs()__
Users should put the detail information of all trial jobs into a list and return the list.
__addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void)__
NNI holds an EventEmitter to collect job metrics; whenever a new job metric is detected, the EventEmitter is triggered. Users should register the listener on their EventEmitter in this function, as in the sketch below.
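For example, in the `DemoTrainingService` sketch the listeners attach to the shared `metricsEmitter`, and the platform-specific polling code emits every newly collected metric to them. The `onMetricCollected` helper is a made-up stand-in for that platform logic, and the metric shape (trial job `id` plus serialized `data`) follows `TrialJobMetric` in `common/trainingService.ts`:
```
// Inside DemoTrainingService: hypothetical helper called from platform polling code.
private onMetricCollected(trialJobId: string, metricData: string): void {
    // every listener registered through addTrialJobMetricListener receives this
    this.metricsEmitter.emit('metric', { id: trialJobId, data: metricData });
}
```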
__removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void)__
Remove the listener from the EventEmitter.
__run()__
The run() function is the main loop of TrainingService. Users can set up a while loop here to execute their logic, and exit the loop when the experiment is stopped, as sketched below.
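A slightly more realistic `run()` for the hypothetical `DemoTrainingService`, consuming the `jobQueue` filled by `submitTrialJob` above. The polling interval and the `startTrialJob`/`refreshTrialJobStatuses` helpers are illustrative assumptions:
```
// Inside DemoTrainingService:
public async run(): Promise<void> {
    while (!this.stopping) {
        // start every trial that submitTrialJob has queued
        while (this.jobQueue.length > 0) {
            const trialJobId: string | undefined = this.jobQueue.shift();
            if (trialJobId !== undefined) {
                await this.startTrialJob(trialJobId);    // hypothetical platform launch
            }
        }
        await this.refreshTrialJobStatuses();            // hypothetical status polling
        await new Promise((resolve: Function) => setTimeout(resolve, 3000));
    }
}
```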
__cleanUp()__
This function is called to clean up the environment when an experiment is stopped. Users should do the platform-related cleanup operations in this function.
## TrialKeeper tool
NNI offers a TrialKeeper tool to help maintain trial jobs. Users can find its source code in the `nni/tools/nni_trial_tool` folder. If users want to run trial jobs on a cloud platform, this tool is a good choice for keeping the trial running on that platform.
The running architecture of TrialKeeper is shown below:
![](../img/trialkeeper.jpg)
When users submit a trial job to a cloud platform, they should wrap their trial command in TrialKeeper and start a TrialKeeper process on the cloud platform. Note that TrialKeeper uses a RESTful server to communicate with TrainingService, so users should start a RESTful server on the local machine to receive the metrics sent by TrialKeeper. The source code of the RESTful server can be found in `nni/src/nni_manager/training_service/common/clusterJobRestServer.ts`.
## Reference
For more information about how to debug, please refer to [How to Debug](HowToDebug.md).
For guidelines on how to contribute, please refer to [Contributing](CONTRIBUTING).
\ No newline at end of file
@@ -8,4 +8,5 @@ References
 Python API <sdk_reference>
 Annotation <AnnotationSpec>
 Configuration<ExperimentConfig>
-Search Space <SearchSpaceSpec>
\ No newline at end of file
+Search Space <SearchSpaceSpec>
+TrainingService <HowToImplementTrainingService>
\ No newline at end of file
@@ -37,6 +37,7 @@ interface ExperimentParams {
     multiPhase?: boolean;
     multiThread?: boolean;
     versionCheck?: boolean;
+    logCollection?: string;
     tuner?: {
         className: string;
         builtinTunerName?: string;
@@ -45,6 +46,7 @@ interface ExperimentParams {
         classFileName?: string;
         checkpointDir: string;
         gpuNum?: number;
+        includeIntermediateResults?: boolean;
     };
     assessor?: {
         className: string;
...
@@ -131,6 +131,10 @@ class NNIManager implements Manager {
         if (expParams.versionCheck !== undefined) {
             this.trainingService.setClusterMetadata('version_check', expParams.versionCheck.toString());
         }
+        // Set up logCollection config
+        if (expParams.logCollection !== undefined) {
+            this.trainingService.setClusterMetadata('log_collection', expParams.logCollection.toString());
+        }
         const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.advisor,
             expParams.multiPhase, expParams.multiThread);
@@ -273,11 +277,17 @@ class NNIManager implements Manager {
             newCwd = cwd;
         }
         // TO DO: add CUDA_VISIBLE_DEVICES
+        let includeIntermediateResultsEnv: boolean | undefined = false;
+        if (this.experimentProfile.params.tuner !== undefined) {
+            includeIntermediateResultsEnv = this.experimentProfile.params.tuner.includeIntermediateResults;
+        }
         let nniEnv = {
             NNI_MODE: mode,
             NNI_CHECKPOINT_DIRECTORY: dataDirectory,
             NNI_LOG_DIRECTORY: getLogDir(),
-            NNI_LOG_LEVEL: getLogLevel()
+            NNI_LOG_LEVEL: getLogLevel(),
+            NNI_INCLUDE_INTERMEDIATE_RESULTS: includeIntermediateResultsEnv
         };
         let newEnv = Object.assign({}, process.env, nniEnv);
         const tunerProc: ChildProcess = spawn(command, [], {
@@ -630,7 +640,7 @@ class NNIManager implements Manager {
     }
     private async onTunerCommand(commandType: string, content: string): Promise<void> {
-        this.log.info(`NNIManaer received command from dispatcher: ${commandType}, ${content}`);
+        this.log.info(`NNIManager received command from dispatcher: ${commandType}, ${content}`);
         switch (commandType) {
             case INITIALIZED:
                 // Tuner is intialized, search space is set, request tuner to generate hyper parameters
...
@@ -106,7 +106,7 @@ describe('core/ipcInterface.terminate', (): void => {
             assert.ok(!procError);
             deferred.resolve();
         },
-        2000);
+        5000);
     return deferred.promise;
 });
...
@@ -142,6 +142,7 @@ export namespace ValidationSchemas {
         multiPhase: joi.boolean(),
         multiThread: joi.boolean(),
         versionCheck: joi.boolean(),
+        logCollection: joi.string(),
         advisor: joi.object({
             builtinAdvisorName: joi.string().valid('Hyperband'),
             codeDir: joi.string(),
@@ -158,7 +159,8 @@ export namespace ValidationSchemas {
             className: joi.string(),
             classArgs: joi.any(),
             gpuNum: joi.number().min(0),
-            checkpointDir: joi.string().allow('')
+            checkpointDir: joi.string().allow(''),
+            includeIntermediateResults: joi.boolean()
         }),
         assessor: joi.object({
             builtinAssessorName: joi.string().valid('Medianstop', 'Curvefitting'),
...
@@ -32,5 +32,6 @@ export enum TrialConfigMetadataKey {
     KUBEFLOW_CLUSTER_CONFIG = 'kubeflow_config',
     NNI_MANAGER_IP = 'nni_manager_ip',
     FRAMEWORKCONTROLLER_CLUSTER_CONFIG = 'frameworkcontroller_config',
-    VERSION_CHECK = 'version_check'
+    VERSION_CHECK = 'version_check',
+    LOG_COLLECTION = 'log_collection'
 }
@@ -270,6 +270,9 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
             case TrialConfigMetadataKey.VERSION_CHECK:
                 this.versionCheck = (value === 'true' || value === 'True');
                 break;
+            case TrialConfigMetadataKey.LOG_COLLECTION:
+                this.logCollection = value;
+                break;
             default:
                 break;
         }
...
@@ -320,6 +320,9 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
             case TrialConfigMetadataKey.VERSION_CHECK:
                 this.versionCheck = (value === 'true' || value === 'True');
                 break;
+            case TrialConfigMetadataKey.LOG_COLLECTION:
+                this.logCollection = value;
+                break;
             default:
                 break;
         }
...
@@ -71,5 +71,5 @@ mkdir -p $NNI_OUTPUT_DIR
 cp -rT $NNI_CODE_DIR $NNI_SYS_DIR
 cd $NNI_SYS_DIR
 sh install_nni.sh
-python3 -m nni_trial_tool.trial_keeper --trial_command '{8}' --nnimanager_ip {9} --nnimanager_port {10} --version '{11}'`
+python3 -m nni_trial_tool.trial_keeper --trial_command '{8}' --nnimanager_ip {9} --nnimanager_port {10} --version '{11}' --log_collection '{12}'`
 + `1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr`
@@ -62,6 +62,7 @@ abstract class KubernetesTrainingService {
     protected kubernetesJobRestServer?: KubernetesJobRestServer;
     protected kubernetesClusterConfig?: KubernetesClusterConfig;
     protected versionCheck?: boolean = true;
+    protected logCollection: string;
     constructor() {
         this.log = getLogger();
@@ -72,6 +73,7 @@ abstract class KubernetesTrainingService {
         this.nextTrialSequenceId = -1;
         this.CONTAINER_MOUNT_PATH = '/tmp/mount';
         this.genericK8sClient = new GeneralK8sClient();
+        this.logCollection = 'none';
     }
     public generatePodResource(memory: number, cpuNum: number, gpuNum: number) {
@@ -204,7 +206,8 @@ abstract class KubernetesTrainingService {
             command,
             nniManagerIp,
             this.kubernetesRestServerPort,
-            version
+            version,
+            this.logCollection
         );
         return Promise.resolve(runScript);
     }
...
@@ -64,7 +64,7 @@ export const PAI_TRIAL_COMMAND_FORMAT: string =
 `export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4}
 && cd $NNI_SYS_DIR && sh install_nni.sh
 && python3 -m nni_trial_tool.trial_keeper --trial_command '{5}' --nnimanager_ip '{6}' --nnimanager_port '{7}'
---pai_hdfs_output_dir '{8}' --pai_hdfs_host '{9}' --pai_user_name {10} --nni_hdfs_exp_dir '{11}' --webhdfs_path '/webhdfs/api/v1' --version '{12}'`;
+--pai_hdfs_output_dir '{8}' --pai_hdfs_host '{9}' --pai_user_name {10} --nni_hdfs_exp_dir '{11}' --webhdfs_path '/webhdfs/api/v1' --version '{12}' --log_collection '{13}'`;
 export const PAI_OUTPUT_DIR_FORMAT: string =
 `hdfs://{0}:9000/`;
...
@@ -76,6 +76,7 @@ class PAITrainingService implements TrainingService {
     private nniManagerIpConfig?: NNIManagerIpConfig;
     private copyExpCodeDirPromise?: Promise<void>;
     private versionCheck?: boolean = true;
+    private logCollection: string;
     constructor() {
         this.log = getLogger();
@@ -88,6 +89,7 @@ class PAITrainingService implements TrainingService {
         this.hdfsDirPattern = 'hdfs://(?<host>([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(?<baseDir>/.*)?';
         this.nextTrialSequenceId = -1;
         this.paiTokenUpdateInterval = 7200000; //2hours
+        this.logCollection = 'none';
         this.log.info('Construct OpenPAI training service.');
     }
@@ -228,7 +230,8 @@ class PAITrainingService implements TrainingService {
             this.hdfsOutputHost,
             this.paiClusterConfig.userName,
             HDFSClientUtility.getHdfsExpCodeDir(this.paiClusterConfig.userName),
-            version
+            version,
+            this.logCollection
         ).replace(/\r\n|\n|\r/gm, '');
         console.log(`nniPAItrial command is ${nniPaiTrialCommand.trim()}`);
@@ -442,6 +445,9 @@ class PAITrainingService implements TrainingService {
             case TrialConfigMetadataKey.VERSION_CHECK:
                 this.versionCheck = (value === 'true' || value === 'True');
                 break;
+            case TrialConfigMetadataKey.LOG_COLLECTION:
+                this.logCollection = value;
+                break;
             default:
                 //Reject for unknown keys
                 throw new Error(`Uknown key: ${key}`);
...
@@ -250,8 +250,8 @@ export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={
 cd $NNI_SYS_DIR
 sh install_nni.sh
 echo $$ >{6}
-python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8}' --nnimanager_port '{9}' --version '{10}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr
-echo $? \`date +%s%3N\` >{11}`;
+python3 -m nni_trial_tool.trial_keeper --trial_command '{7}' --nnimanager_ip '{8}' --nnimanager_port '{9}' --version '{10}' --log_collection '{11}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr
+echo $? \`date +%s%3N\` >{12}`;
 export const HOST_JOB_SHELL_FORMAT: string =
 `#!/bin/bash
...
@@ -77,6 +77,7 @@ class RemoteMachineTrainingService implements TrainingService {
     private readonly remoteOS: string;
     private nniManagerIpConfig?: NNIManagerIpConfig;
     private versionCheck: boolean = true;
+    private logCollection: string;
     constructor(@component.Inject timer: ObservableTimer) {
         this.remoteOS = 'linux';
@@ -91,6 +92,7 @@ class RemoteMachineTrainingService implements TrainingService {
         this.timer = timer;
         this.log = getLogger();
         this.trialSequenceId = -1;
+        this.logCollection = 'none';
         this.log.info('Construct remote machine training service.');
     }
@@ -376,6 +378,9 @@ class RemoteMachineTrainingService implements TrainingService {
             case TrialConfigMetadataKey.VERSION_CHECK:
                 this.versionCheck = (value === 'true' || value === 'True');
                 break;
+            case TrialConfigMetadataKey.LOG_COLLECTION:
+                this.logCollection = value;
+                break;
             default:
                 //Reject for unknown keys
                 throw new Error(`Uknown key: ${key}`);
@@ -598,6 +603,7 @@ class RemoteMachineTrainingService implements TrainingService {
             nniManagerIp,
             this.remoteRestServerPort,
             version,
+            this.logCollection,
             path.join(trialWorkingFolder, '.nni', 'code')
         )
...
@@ -18,14 +18,15 @@
 # OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 # ==================================================================================================
+import os
 import logging
 from collections import defaultdict
 import json_tricks
-import threading
 from .protocol import CommandType, send
 from .msg_dispatcher_base import MsgDispatcherBase
 from .assessor import AssessResult
+from .common import multi_thread_enabled
 _logger = logging.getLogger(__name__)
@@ -70,7 +71,7 @@ def _pack_parameter(parameter_id, params, customized=False):
 class MsgDispatcher(MsgDispatcherBase):
     def __init__(self, tuner, assessor=None):
-        super().__init__()
+        super(MsgDispatcher, self).__init__()
         self.tuner = tuner
         self.assessor = assessor
         if assessor is None:
@@ -87,9 +88,8 @@ class MsgDispatcher(MsgDispatcherBase):
         self.assessor.save_checkpoint()
     def handle_initialize(self, data):
-        '''
-        data is search space
-        '''
+        """Data is search space
+        """
         self.tuner.update_search_space(data)
         send(CommandType.Initialized, '')
         return True
@@ -126,12 +126,7 @@ class MsgDispatcher(MsgDispatcherBase):
         - 'type': report type, support {'FINAL', 'PERIODICAL'}
         """
         if data['type'] == 'FINAL':
-            id_ = data['parameter_id']
-            value = data['value']
-            if id_ in _customized_parameter_ids:
-                self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
-            else:
-                self.tuner.receive_trial_result(id_, _trial_params[id_], value)
+            self._handle_final_metric_data(data)
         elif data['type'] == 'PERIODICAL':
             if self.assessor is not None:
                 self._handle_intermediate_metric_data(data)
@@ -157,7 +152,19 @@ class MsgDispatcher(MsgDispatcherBase):
         self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
         return True
+    def _handle_final_metric_data(self, data):
+        """Call tuner to process final results
+        """
+        id_ = data['parameter_id']
+        value = data['value']
+        if id_ in _customized_parameter_ids:
+            self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
+        else:
+            self.tuner.receive_trial_result(id_, _trial_params[id_], value)
     def _handle_intermediate_metric_data(self, data):
+        """Call assessor to process intermediate results
+        """
         if data['type'] != 'PERIODICAL':
             return True
         if self.assessor is None:
@@ -187,5 +194,20 @@ class MsgDispatcher(MsgDispatcherBase):
         if result is AssessResult.Bad:
             _logger.debug('BAD, kill %s', trial_job_id)
             send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
+            # notify tuner
+            _logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS'))
+            if os.environ.get('NNI_INCLUDE_INTERMEDIATE_RESULTS') == 'true':
+                self._earlystop_notify_tuner(data)
         else:
             _logger.debug('GOOD')
+    def _earlystop_notify_tuner(self, data):
+        """Send last intermediate result as final result to tuner in case the
+        trial is early stopped.
+        """
+        _logger.debug('Early stop notify tuner data: [%s]', data)
+        data['type'] = 'FINAL'
+        if multi_thread_enabled():
+            self._handle_final_metric_data(data)
+        else:
+            self.enqueue_command(CommandType.ReportMetricData, data)
@@ -19,14 +19,13 @@
 # ==================================================================================================
 #import json_tricks
-import logging
 import os
-from queue import Queue
-import sys
+import threading
+import logging
 from multiprocessing.dummy import Pool as ThreadPool
+from queue import Queue, Empty
 import json_tricks
 from .common import init_logger, multi_thread_enabled
 from .recoverable import Recoverable
 from .protocol import CommandType, receive
@@ -34,57 +33,109 @@ from .protocol import CommandType, receive
 init_logger('dispatcher.log')
 _logger = logging.getLogger(__name__)
+QUEUE_LEN_WARNING_MARK = 20
+_worker_fast_exit_on_terminate = True
 class MsgDispatcherBase(Recoverable):
     def __init__(self):
         if multi_thread_enabled():
             self.pool = ThreadPool()
             self.thread_results = []
+        else:
+            self.stopping = False
+            self.default_command_queue = Queue()
+            self.assessor_command_queue = Queue()
+            self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
+            self.assessor_worker = threading.Thread(target=self.command_queue_worker, args=(self.assessor_command_queue,))
+            self.default_worker.start()
+            self.assessor_worker.start()
+            self.worker_exceptions = []
     def run(self):
         """Run the tuner.
         This function will never return unless raise.
         """
+        _logger.info('Start dispatcher')
         mode = os.getenv('NNI_MODE')
         if mode == 'resume':
             self.load_checkpoint()
         while True:
-            _logger.debug('waiting receive_message')
             command, data = receive()
+            if data:
+                data = json_tricks.loads(data)
             if command is None or command is CommandType.Terminate:
                 break
             if multi_thread_enabled():
-                result = self.pool.map_async(self.handle_request_thread, [(command, data)])
+                result = self.pool.map_async(self.process_command_thread, [(command, data)])
                 self.thread_results.append(result)
                 if any([thread_result.ready() and not thread_result.successful() for thread_result in self.thread_results]):
                     _logger.debug('Caught thread exception')
                     break
             else:
-                self.handle_request((command, data))
+                self.enqueue_command(command, data)
+        _logger.info('Dispatcher exiting...')
+        self.stopping = True
         if multi_thread_enabled():
             self.pool.close()
             self.pool.join()
+        else:
+            self.default_worker.join()
+            self.assessor_worker.join()
         _logger.info('Terminated by NNI manager')
-    def handle_request_thread(self, request):
+    def command_queue_worker(self, command_queue):
+        """Process commands in command queues.
+        """
+        while True:
+            try:
+                # set timeout to ensure self.stopping is checked periodically
+                command, data = command_queue.get(timeout=3)
+                try:
+                    self.process_command(command, data)
+                except Exception as e:
+                    _logger.exception(e)
+                    self.worker_exceptions.append(e)
+                    break
+            except Empty:
+                pass
+            if self.stopping and (_worker_fast_exit_on_terminate or command_queue.empty()):
+                break
+    def enqueue_command(self, command, data):
+        """Enqueue command into command queues
+        """
+        if command == CommandType.TrialEnd or (command == CommandType.ReportMetricData and data['type'] == 'PERIODICAL'):
+            self.assessor_command_queue.put((command, data))
+        else:
+            self.default_command_queue.put((command, data))
+        qsize = self.default_command_queue.qsize()
+        if qsize >= QUEUE_LEN_WARNING_MARK:
+            _logger.warning('default queue length: %d', qsize)
+        qsize = self.assessor_command_queue.qsize()
+        if qsize >= QUEUE_LEN_WARNING_MARK:
+            _logger.warning('assessor queue length: %d', qsize)
+    def process_command_thread(self, request):
+        """Worker thread to process a command.
+        """
+        command, data = request
         if multi_thread_enabled():
             try:
-                self.handle_request(request)
+                self.process_command(command, data)
             except Exception as e:
                 _logger.exception(str(e))
                 raise
         else:
             pass
-    def handle_request(self, request):
-        command, data = request
-        _logger.debug('handle request: command: [{}], data: [{}]'.format(command, data))
-        if data:
-            data = json_tricks.loads(data)
+    def process_command(self, command, data):
+        _logger.debug('process_command: command: [{}], data: [{}]'.format(command, data))
         command_handlers = {
             # Tunner commands:
...