Unverified Commit 0b9d6ce6 authored by Chi Song's avatar Chi Song Committed by GitHub

Reuse OpenPAI jobs to run multiple trials (#2521)

Designed a new interface to support reusable training services; it currently applies only to OpenPAI, and is disabled by default.

Replace trial_keeper.py with trial_runner.py. The trial runner holds an environment, receives commands from the nni manager to run or stop a trial, and returns events to the nni manager.
Add a trial dispatcher, which inherits from the original training service interface. It is used to share as much code as possible across all training services, while staying isolated from them.
Add an EnvironmentService interface to manage environments, including starting/stopping an environment and refreshing environment status.
Add a command channel on both the nni manager and trial runner sides; it supports different ways to pass messages between them. Currently supported channels are file and web sockets. Supported commands from the nni manager are: start trial, kill trial, and send new parameters; from the runner: initialized (to support channels that don't know which runner connected), trial end, stdout (new type, including metrics as before), version check (new type), and gpu info (new type).
Add a storage service to wrap a storage into standard file operations, e.g. NFS, Azure storage, and so on.
Partially support running multiple trials in parallel on the runner side; this is not yet supported on the trial dispatcher side.
Other minor changes:

Add log_level to the TS UT, so that the UT can show debug-level logs.
Expose platform in start info.
Add RouterTrainingService to keep the original OpenPAI training service, and support dynamic IOC binding.
Add more GPU info for future usage, including GPU memory total/free/used and GPU type.
Make some license information consistent.
Fix async/await problems with Array.forEach; this method doesn't actually support async.
Fix IT errors on downloading data, caused by my #2484.
Accelerate some run-loop patterns by reducing sleep seconds.
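The commands listed above map to two-letter command codes in the new shared command set (see the `TRIAL_COMMANDS` addition in the diff below). A minimal sketch of how a channel can validate incoming command types; `isTrialCommand` is a hypothetical helper, not part of this change:

```typescript
// Two-letter command codes, as added in this change.
// manager -> runner
const NEW_TRIAL_JOB = 'TR';
const SEND_TRIAL_JOB_PARAMETER = 'SP';
const KILL_TRIAL_JOB = 'KI';
// runner -> manager
const INITIALIZED = 'ID';
const TRIAL_END = 'EN';
const GPU_INFO = 'GI';
const STDOUT = 'SO';
const VERSION_CHECK = 'VC';

const TRIAL_COMMANDS: Set<string> = new Set([
    NEW_TRIAL_JOB, SEND_TRIAL_JOB_PARAMETER, KILL_TRIAL_JOB,
    INITIALIZED, TRIAL_END, GPU_INFO, STDOUT, VERSION_CHECK,
]);

// Hypothetical helper: a command channel would reject unknown
// command types before dispatching them.
function isTrialCommand(commandType: string): boolean {
    return TRIAL_COMMANDS.has(commandType);
}
```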
parent 6de15707
...@@ -64,7 +64,8 @@ setuptools.setup(
        'coverage',
        'colorama',
        'scikit-learn>=0.20,<0.22',
        'pkginfo',
        'websockets'
    ],
    classifiers = [
        'Programming Language :: Python :: 3',
......
**Run an Experiment on OpenPAI**
===
NNI supports running an experiment on [OpenPAI](https://github.com/Microsoft/pai), called pai mode. Before starting to use NNI pai mode, you should have an account to access an [OpenPAI](https://github.com/Microsoft/pai) cluster. See [here](https://github.com/Microsoft/pai#how-to-deploy) if you don't have an OpenPAI account and want to deploy an OpenPAI cluster. In pai mode, your trial program runs in a Docker container created by OpenPAI.
## Setup environment
Step 1. Install NNI; follow the install guide [here](../Tutorial/QuickStart.md).
Step 2. Get a token.
Open the web portal of OpenPAI, and click the `My profile` button in the top-right corner.
![](../../img/pai_profile.jpg)
Click the `copy` button on the page to copy a JWT token.
![](../../img/pai_token.jpg)
Step 3. Mount NFS storage to the local machine.
Click the `Submit job` button in the web portal.
![](../../img/pai_job_submission_page.jpg)
Find the data management region on the job submission page.
![](../../img/pai_data_management_page.jpg)
`Preview container paths` shows the NFS host and path that OpenPAI provides. You need to mount the corresponding host and path to your local machine first; then NNI can use OpenPAI's NFS storage.
For example, use the following command:
```bash
sudo mount -t nfs4 gcr-openpai-infra02:/pai/data /local/mnt
```
Then the `/data` folder in the container will be mounted to the `/local/mnt` folder on your local machine.
You can then use the following configuration in your NNI config file:
```yaml
nniManagerNFSMountPath: /local/mnt
```
Step 4. Get OpenPAI's storage config name and nniManagerMountPath.
The `Team share storage` field is the storage configuration used to specify storage in OpenPAI. You can get the `paiStorageConfigName` and `containerNFSMountPath` fields from `Team share storage`, for example:
```yaml
paiStorageConfigName: confignfs-data
containerNFSMountPath: /mnt/confignfs-data
```
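Putting steps 3 and 4 together, the storage-related part of the NNI config file would contain (using the example values above):

```yaml
nniManagerNFSMountPath: /local/mnt
containerNFSMountPath: /mnt/confignfs-data
paiStorageConfigName: confignfs-data
```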
...@@ -73,44 +83,86 @@ paiConfig:
  userName: your_pai_nni_user
  token: your_pai_token
  host: 10.1.1.1
  # optional, experimental feature.
  reuse: true
```
Note: You should set `trainingServicePlatform: pai` in the NNI config YAML file if you want to start the experiment in pai mode.
### Trial configurations
Compared with [LocalMode](LocalMode.md) and [RemoteMachineMode](RemoteMachineMode.md), trial configuration in pai mode has these additional keys:
* cpuNum
Optional key. Should be a positive number based on your trial program's CPU requirement. If it is not set in the trial configuration, it should be set in the config file specified in the `paiConfigPath` field.
* memoryMB
Optional key. Should be a positive number based on your trial program's memory requirement. If it is not set in the trial configuration, it should be set in the config file specified in the `paiConfigPath` field.
* image
Optional key. In pai mode, your trial program will be scheduled by OpenPAI to run in a [Docker container](https://www.docker.com/). This key specifies the Docker image used to create the container in which your trial will run.
We have already built a Docker image [msranni/nni](https://hub.docker.com/r/msranni/nni/) on [Docker Hub](https://hub.docker.com/). It contains the NNI Python packages, Node modules, and JavaScript artifact files required to start an experiment, plus all NNI dependencies. The Dockerfile used to build this image can be found [here](https://github.com/Microsoft/nni/tree/master/deployment/docker/Dockerfile). You can either use this image directly in your config file, or build your own image based on it. If it is not set in the trial configuration, it should be set in the config file specified in the `paiConfigPath` field.
* virtualCluster
Optional key. Set the virtualCluster of OpenPAI. If omitted, the job will run on the default virtual cluster.
* nniManagerNFSMountPath
Required key. Set the mount path on your nniManager machine.
* containerNFSMountPath
Required key. Set the mount path in your container used in OpenPAI.
* paiStorageConfigName
Optional key. Set the storage name used in OpenPAI. If it is not set in the trial configuration, it should be set in the config file specified in the `paiConfigPath` field.
* command
Optional key. Set the commands used in the OpenPAI container.
* paiConfigPath
Optional key. Set the file path of the OpenPAI job configuration; the file is in YAML format.
If users set `paiConfigPath` in NNI's configuration file, there is no need to specify the fields `command`, `paiStorageConfigName`, `virtualCluster`, `image`, `memoryMB`, `cpuNum`, `gpuNum` in the `trial` configuration. These fields will use the values from the config file specified by `paiConfigPath`.
Note:
1. The job name in OpenPAI's configuration file will be replaced by a new job name created by NNI; the name format is nni_exp_${this.experimentId}_trial_${trialJobId}.
2. If users set multiple taskRoles in OpenPAI's configuration file, NNI will wrap all of these taskRoles and start multiple tasks in one trial job. Users should ensure that only one taskRole reports metrics to NNI; otherwise there might be conflict errors.
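For orientation, the file that `paiConfigPath` points at is a regular OpenPAI job config. A rough illustration only: the field names below follow the OpenPAI v2 job protocol as documented by OpenPAI, and the job name, image, command, and resource values are placeholders, not part of this change:

```yaml
protocolVersion: 2
name: demo_nni_trial
type: job
taskRoles:
  taskrole:
    instances: 1
    dockerImage: msranni/nni
    resourcePerInstance:
      cpu: 4
      memoryMB: 8192
      gpu: 1
    commands:
      - python3 mnist.py
```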
### OpenPAI configurations
`paiConfig` includes OpenPAI specific configurations:
* userName
Required key. The user name of the OpenPAI platform.
* token
Required key. The authentication key of the OpenPAI platform.
* host
Required key. The host of the OpenPAI platform. It's the URI of OpenPAI's job submission page, like `10.10.5.1`. The default protocol in NNI is `http`; if your OpenPAI cluster has HTTPS enabled, use the URI in the `https://10.10.5.1` format.
* reuse (experimental feature)
Optional key, default is `false`. If true, NNI will reuse OpenPAI jobs to run as many trials as possible. This can save the time of creating new jobs. Users need to make sure each trial can run independently in the same job; for example, avoid loading checkpoints from previous trials.
Once the NNI experiment config file is complete and saved (for example, as exp_pai.yml), run the following command
```bash
nnictl create --config exp_pai.yml
```
to start the experiment in pai mode. NNI will create an OpenPAI job for each trial, and the job name format is like `nni_exp_{experiment_id}_trial_{trial_id}`.
You can see the jobs created by NNI in the OpenPAI cluster's web portal, like:
![](../../img/nni_pai_joblist.jpg)
...@@ -128,11 +180,12 @@ And you will be redirected to HDFS web portal to browse the output files of that
You can see there are three files in the output folder: stderr, stdout, and trial.log
## data management
Before using NNI to start your experiment, users should set the corresponding mount data path on the nniManager machine. OpenPAI has its own storage (NFS, AzureBlob, ...), and the storage used in OpenPAI will be mounted to the container when it starts a job. Users should set the OpenPAI storage type via the `paiStorageConfigName` field to choose a storage in OpenPAI. Then users should mount the storage to the nniManager machine and set the `nniManagerNFSMountPath` field in the configuration file; NNI will generate bash files and copy data in `codeDir` to the `nniManagerNFSMountPath` folder, then start a trial job. The data in `nniManagerNFSMountPath` will be synced to OpenPAI storage and mounted to OpenPAI's container. The data path in the container is set by `containerNFSMountPath`; NNI enters this folder first, then runs scripts to start the trial job.
## version check
NNI supports a version check feature since version 0.6. It is a policy to ensure that the version of NNIManager is consistent with trialKeeper, and to avoid errors caused by version incompatibility.
Check policy:
1. NNIManager before v0.6 could run any version of trialKeeper; trialKeeper supports backward compatibility.
2. Since version 0.6, the NNIManager version should be the same as the trialKeeper version. For example, if the NNIManager version is 0.6, the trialKeeper version should be 0.6 too.
3. Note that the version check only checks the first two digits of the version. For example, NNIManager v0.6.1 could use trialKeeper v0.6 or trialKeeper v0.6.2, but could not use trialKeeper v0.5.1 or trialKeeper v0.7.
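The check policy above (only the first two digits of the version matter) can be sketched as follows; `versionsCompatible` is a hypothetical helper illustrating the rule, not NNI source:

```typescript
// Version check compares only the first two digits (major.minor),
// per the check policy: v0.6.1 is compatible with v0.6.2, but not v0.7.
function versionsCompatible(managerVersion: string, keeperVersion: string): boolean {
    const majorMinor = (v: string): string => v.split('.').slice(0, 2).join('.');
    return majorMinor(managerVersion) === majorMinor(keeperVersion);
}
```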
......
...@@ -70,6 +70,7 @@ This document describes the rules to write the config file, and provides some ex
- [password](#password)
- [token](#token)
- [host](#host)
- [reuse](#reuse)
* [Examples](#examples)
+ [Local mode](#local-mode)
+ [Remote mode](#remote-mode)
...@@ -656,6 +657,12 @@ Required. String.
The hostname or IP address of PAI.
#### reuse
Optional. Bool. default: `false`. It's an experimental feature.
If true, NNI will reuse OpenPAI jobs to run as many trials as possible. This can save the time of creating new jobs. Users need to make sure each trial can run independently in the same job; for example, avoid loading checkpoints from previous trials.
## Examples
### Local mode
......
...@@ -42,7 +42,8 @@ setup(
        'PythonWebHDFS',
        'colorama',
        'scikit-learn>=0.20,<0.22',
        'pkginfo',
        'websockets'
    ],
    entry_points = {
......
...@@ -17,14 +17,16 @@ class ExperimentStartupInfo {
    private logDir: string = '';
    private logLevel: string = '';
    private readonly: boolean = false;
    private platform: string = '';

    public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean): void {
        assert(!this.initialized);
        assert(experimentId.trim().length > 0);
        this.newExperiment = newExperiment;
        this.experimentId = experimentId;
        this.basePort = basePort;
        this.initialized = true;
        this.platform = platform;

        if (logDir !== undefined && logDir.length > 0) {
            this.logDir = path.join(path.normalize(logDir), this.getExperimentId());
...@@ -59,6 +61,12 @@ class ExperimentStartupInfo {
        return this.newExperiment;
    }

    public getPlatform(): string {
        assert(this.initialized);
        return this.platform;
    }

    public getLogDir(): string {
        assert(this.initialized);
...@@ -90,19 +98,25 @@ function isNewExperiment(): boolean {
    return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).isNewExperiment();
}

function getPlatform(): string {
    return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getPlatform();
}

function getExperimentStartupInfo(): ExperimentStartupInfo {
    return component.get<ExperimentStartupInfo>(ExperimentStartupInfo);
}

function setExperimentStartupInfo(
    newExperiment: boolean, experimentId: string, basePort: number, platform: string, logDir?: string, logLevel?: string, readonly?: boolean): void {
    component.get<ExperimentStartupInfo>(ExperimentStartupInfo)
        .setStartupInfo(newExperiment, experimentId, basePort, platform, logDir, logLevel, readonly);
}

function isReadonly(): boolean {
    return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).isReadonly();
}

export {
    ExperimentStartupInfo, getBasePort, getExperimentId, isNewExperiment, getPlatform, getExperimentStartupInfo,
    setExperimentStartupInfo, isReadonly
};
...@@ -19,6 +19,7 @@ import { Database, DataStore } from './datastore';
import { ExperimentStartupInfo, getExperimentStartupInfo, setExperimentStartupInfo } from './experimentStartupInfo';
import { ExperimentParams, Manager } from './manager';
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';
import { logLevelNameMap } from './log';

function getExperimentRootDir(): string {
    return getExperimentStartupInfo()
...@@ -184,7 +185,12 @@ function prepareUnitTest(): void {
    Container.snapshot(TrainingService);
    Container.snapshot(Manager);

    const logLevel: string = parseArg(['--log_level', '-ll']);
    if (logLevel.length > 0 && !logLevelNameMap.has(logLevel)) {
        console.log(`FATAL: invalid log_level: ${logLevel}`);
    }
    setExperimentStartupInfo(true, 'unittest', 8080, 'unittest', undefined, logLevel);
    mkDirPSync(getLogDir());

    const sqliteFile: string = path.join(getDefaultDatabaseDir(), 'nni.sqlite');
......
...@@ -12,12 +12,30 @@ const TRIAL_END = 'EN';
const TERMINATE = 'TE';
const PING = 'PI';
const GPU_INFO = 'GI';
const STDOUT = 'SO';
const VERSION_CHECK = 'VC';

const INITIALIZED = 'ID';
const NEW_TRIAL_JOB = 'TR';
const SEND_TRIAL_JOB_PARAMETER = 'SP';
const NO_MORE_TRIAL_JOBS = 'NO';
const KILL_TRIAL_JOB = 'KI';

const TRIAL_COMMANDS: Set<string> = new Set([
    // from ctl to node
    NEW_TRIAL_JOB,
    SEND_TRIAL_JOB_PARAMETER,
    KILL_TRIAL_JOB,
    // from node to ctl
    INITIALIZED,
    TRIAL_END,
    GPU_INFO,
    STDOUT,
    VERSION_CHECK,
]);

const TUNER_COMMANDS: Set<string> = new Set([
    INITIALIZE,
    REQUEST_TRIAL_JOBS,
...@@ -53,11 +71,15 @@ export {
    TRIAL_END,
    TERMINATE,
    PING,
    GPU_INFO,
    STDOUT,
    VERSION_CHECK,
    INITIALIZED,
    NEW_TRIAL_JOB,
    NO_MORE_TRIAL_JOBS,
    KILL_TRIAL_JOB,
    TUNER_COMMANDS,
    ASSESSOR_COMMANDS,
    TRIAL_COMMANDS,
    SEND_TRIAL_JOB_PARAMETER
};
...@@ -135,4 +135,4 @@ function createDispatcherInterface(process: ChildProcess): IpcInterface {
    return new IpcInterface(process, new Set([...CommandType.TUNER_COMMANDS, ...CommandType.ASSESSOR_COMMANDS]));
}

export { IpcInterface, createDispatcherInterface, encodeCommand, decodeCommand };
...@@ -21,7 +21,7 @@ import { NNIRestServer } from './rest_server/nniRestServer';
import { FrameworkControllerTrainingService } from './training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService';
import { KubeflowTrainingService } from './training_service/kubernetes/kubeflow/kubeflowTrainingService';
import { LocalTrainingService } from './training_service/local/localTrainingService';
import { RouterTrainingService } from './training_service/reusable/routerTrainingService';
import { PAIYarnTrainingService } from './training_service/pai/paiYarn/paiYarnTrainingService';
import {
    RemoteMachineTrainingService
...@@ -29,11 +29,11 @@ import {
import { DLTSTrainingService } from './training_service/dlts/dltsTrainingService';

function initStartupInfo(
    startExpMode: string, resumeExperimentId: string, basePort: number, platform: string,
    logDirectory: string, experimentLogLevel: string, readonly: boolean): void {
    const createNew: boolean = (startExpMode === ExperimentStartUpMode.NEW);
    const expId: string = createNew ? uniqueString(8) : resumeExperimentId;
    setExperimentStartupInfo(createNew, expId, basePort, platform, logDirectory, experimentLogLevel, readonly);
}

async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> {
...@@ -47,10 +47,10 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
            .scope(Scope.Singleton);
    } else if (platformMode === 'pai') {
        Container.bind(TrainingService)
            .to(RouterTrainingService)
            .scope(Scope.Singleton);
    } else if (platformMode === 'paiYarn') {
        Container.bind(TrainingService)
            .to(PAIYarnTrainingService)
            .scope(Scope.Singleton);
    } else if (platformMode === 'kubeflow') {
...@@ -153,31 +153,31 @@ if (!('true' || 'false').includes(readonlyArg.toLowerCase())) {
}
const readonly = readonlyArg.toLowerCase() == 'true' ? true : false;

initStartupInfo(startMode, experimentId, port, mode, logDir, logLevel, readonly);

mkDirP(getLogDir())
    .then(async () => {
        try {
            await initContainer(foreground, mode);
            const restServer: NNIRestServer = component.get(NNIRestServer);
            await restServer.start();
            const log: Logger = getLogger();
            log.info(`Rest server listening on: ${restServer.endPoint}`);
        } catch (err) {
            const log: Logger = getLogger();
            log.error(`${err.stack}`);
            throw err;
        }
    })
    .catch((err: Error) => {
        console.error(`Failed to create log dir: ${err.stack}`);
    });

function getStopSignal(): any {
    if (process.platform === "win32") {
        return 'SIGBREAK';
    }
    else {
        return 'SIGTERM';
    }
}
......
...@@ -29,7 +29,8 @@
    "ts-deferred": "^1.0.4",
    "typescript-ioc": "^1.2.4",
    "typescript-string-operations": "^1.3.1",
    "webhdfs": "^1.2.0",
    "ws": "^7.3.0"
},
"devDependencies": {
    "@types/chai": "^4.1.4",
...@@ -46,6 +47,7 @@
    "@types/stream-buffers": "^3.0.2",
    "@types/tar": "^4.0.3",
    "@types/tmp": "^0.0.33",
    "@types/ws": "^7.2.5",
    "@typescript-eslint/eslint-plugin": "^2.10.0",
    "@typescript-eslint/parser": "^2.10.0",
    "chai": "^4.1.2",
......
...@@ -104,6 +104,7 @@ export namespace ValidationSchemas {
        passWord: joi.string().min(1),
        token: joi.string().min(1),
        host: joi.string().min(1).required(),
        reuse: joi.boolean(),
    }),
    kubeflow_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
        operator: joi.string().min(1).required(),
......
@@ -16,12 +16,21 @@ export class GPUInfo {
    public gpuUtil: number;
    // the index number of this GPU (starting from 0)
    public readonly index: number;
+   public gpuMemTotal: number;
+   public gpuMemFree: number;
+   public gpuMemUsed: number;
+   public gpuType: string;

-   constructor(activeProcessNum: number, gpuMemUtil: number, gpuUtil: number, index: number) {
+   constructor(activeProcessNum: number, gpuMemUtil: number, gpuUtil: number, index: number,
+       gpuMemTotal: number, gpuMemFree: number, gpuMemUsed: number, gpuType: string) {
        this.activeProcessNum = activeProcessNum;
        this.gpuMemUtil = gpuMemUtil;
        this.gpuUtil = gpuUtil;
        this.index = index;
+       this.gpuMemTotal = gpuMemTotal;
+       this.gpuMemFree = gpuMemFree;
+       this.gpuMemUsed = gpuMemUsed;
+       this.gpuType = gpuType;
    }
}
@@ -44,7 +53,7 @@ export class GPUSummary {
}

export const GPU_INFO_COLLECTOR_FORMAT_WINDOWS: string =
`
$env:METRIC_OUTPUT_DIR="{0}"
$app = Start-Process "python" -ArgumentList "-m nni_gpu_tool.gpu_metrics_collector" -passthru -NoNewWindow
Write $app.ID | Out-File {1} -NoNewline -encoding utf8
...
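The hunk above extends `GPUInfo` with total/free/used memory and a GPU type, described in the PR as "more GPU info for future usage". As a minimal sketch of how those extra fields could be consumed (a hypothetical helper, not code from this PR; only the field names are taken from the diff):

```typescript
// Hypothetical consumer of the new GPUInfo fields: pick the index of the GPU
// with the most free memory. GPUInfoLike mirrors the fields added in the diff.
interface GPUInfoLike {
    index: number;
    gpuMemTotal: number;
    gpuMemFree: number;
    gpuMemUsed: number;
    gpuType: string;
}

function pickFreestGpu(gpus: GPUInfoLike[]): number | undefined {
    let best: GPUInfoLike | undefined;
    for (const gpu of gpus) {
        if (best === undefined || gpu.gpuMemFree > best.gpuMemFree) {
            best = gpu;
        }
    }
    return best === undefined ? undefined : best.index;
}
```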
@@ -10,6 +10,7 @@ export class PAIClusterConfig {
    public readonly passWord?: string;
    public host: string;
    public readonly token?: string;
+   public readonly reuse?: boolean;

    /**
     * Constructor
@@ -17,12 +18,14 @@
     * @param passWord password of PAI Cluster
     * @param host Host IP of PAI Cluster
     * @param token PAI token of PAI Cluster
+    * @param reuse whether the job is reusable for multiple trials
     */
-   constructor(userName: string, host: string, passWord?: string, token?: string) {
+   constructor(userName: string, host: string, passWord?: string, token?: string, reuse?: boolean) {
        this.userName = userName;
        this.passWord = passWord;
        this.host = host;
        this.token = token;
+       this.reuse = reuse;
    }
}
...
-/**
- * Copyright (c) Microsoft Corporation
- * All rights reserved.
- *
- * MIT License
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
- * documentation files (the "Software"), to deal in the Software without restriction, including without limitation
- * the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
- * to permit persons to whom the Software is furnished to do so, subject to the following conditions:
- * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
- * BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
- * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
- * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- */
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.

'use strict';

import {TrialConfig} from '../../common/trialConfig';
...
-/** ...full MIT license header removed (identical to the one above)... */
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.

'use strict';
...
-/** ...full MIT license header removed (identical to the one above)... */
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.

'use strict';
...
@@ -406,12 +406,11 @@ class RemoteMachineTrainingService implements TrainingService {
    private async setupConnections(machineList: string): Promise<void> {
        this.log.debug(`Connecting to remote machines: ${machineList}`);
-       const deferred: Deferred<void> = new Deferred<void>();
        //TO DO: verify if value's format is wrong, and json parse failed, how to handle error
        const rmMetaList: RemoteMachineMeta[] = <RemoteMachineMeta[]>JSON.parse(machineList);
-       let connectedRMNum: number = 0;
-       rmMetaList.forEach(async (rmMeta: RemoteMachineMeta) => {
+       const connectionPromises = [];
+       for (const rmMeta of rmMetaList) {
            rmMeta.occupiedGpuIndexMap = new Map<number, number>();
            const executorManager: ExecutorManager = new ExecutorManager(rmMeta);
            this.log.info(`connecting to ${rmMeta.username}@${rmMeta.ip}:${rmMeta.port}`);
@@ -419,14 +418,11 @@ class RemoteMachineTrainingService implements TrainingService {
            this.log.debug(`reached ${executor.name}`);
            this.machineExecutorManagerMap.set(rmMeta, executorManager);
            this.log.debug(`initializing ${executor.name}`);
-           await this.initRemoteMachineOnConnected(rmMeta, executor);
+           connectionPromises.push(this.initRemoteMachineOnConnected(rmMeta, executor));
            this.log.info(`connected to ${executor.name}`);
-           if (++connectedRMNum === rmMetaList.length) {
-               deferred.resolve();
-           }
-       });
+       }

-       return deferred.promise;
+       await Promise.all(connectionPromises);
    }

    private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, executor: ShellExecutor): Promise<void> {
@@ -460,7 +456,7 @@ class RemoteMachineTrainingService implements TrainingService {
            this.timer.unsubscribe(disposable);
        }
    }
-   if (this.stopping){
+   if (this.stopping) {
        this.timer.unsubscribe(disposable);
        this.log.debug(`Stopped GPU collector on ${rmMeta.ip}, since experiment is exiting.`);
    }
...
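The hunk above is the PR's fix for "async/await problems on Array.forEach, this method doesn't support async actually": `forEach` discards the promises returned by an async callback, so the caller resumes before any callback finishes. A minimal standalone sketch of the broken pattern next to the fixed one (the `delay`/`connectAll` names here are illustrative, not from the PR):

```typescript
// A stand-in for an async connection step.
const delay = (ms: number): Promise<void> => new Promise((resolve) => setTimeout(resolve, ms));

// Fixed pattern (what the hunk above switches to): start every async step,
// collect the promises, and await them all together.
async function connectAll(hosts: string[]): Promise<string[]> {
    const connected: string[] = [];
    const pending: Promise<void>[] = [];
    for (const host of hosts) {
        pending.push(delay(10).then(() => { connected.push(host); }));
    }
    await Promise.all(pending);
    return connected;
}

// Broken pattern (what the hunk removes): forEach ignores the promises its
// async callback returns, so this function resolves before any host connects.
async function brokenConnectAll(hosts: string[]): Promise<string[]> {
    const connected: string[] = [];
    hosts.forEach(async (host) => {
        await delay(10);
        connected.push(host);
    });
    return connected; // still empty at this point
}
```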
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as component from "../../../common/component";
import { delay } from "../../../common/utils";
import { CommandChannel, RunnerConnection } from "../commandChannel";
import { EnvironmentInformation, Channel } from "../environment";
import { StorageService } from "../storageService";
class FileHandler {
public fileName: string;
public offset: number = 0;
constructor(fileName: string) {
this.fileName = fileName;
}
}
class FileRunnerConnection extends RunnerConnection {
public handlers: Map<string, FileHandler> = new Map<string, FileHandler>();
}
export class FileCommandChannel extends CommandChannel {
private readonly commandPath = "commands";
private stopping: boolean = false;
// make sure no concurrent issue when sending commands.
private sendQueues: [EnvironmentInformation, string][] = [];
public get channelName(): Channel {
return "file";
}
public async config(_key: string, _value: any): Promise<void> {
// do nothing
}
public async start(): Promise<void> {
// start command loops
this.receiveLoop();
this.sendLoop();
}
public async stop(): Promise<void> {
this.stopping = true;
}
protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise<void> {
this.sendQueues.push([environment, message]);
}
protected createRunnerConnection(environment: EnvironmentInformation): RunnerConnection {
return new FileRunnerConnection(environment);
}
private async sendLoop(): Promise<void> {
const intervalSeconds = 0.5;
while (!this.stopping) {
const start = new Date();
if (this.sendQueues.length > 0) {
const storageService = component.get<StorageService>(StorageService);
while (this.sendQueues.length > 0) {
const item = this.sendQueues.shift();
if (item === undefined) {
break;
}
const environment = item[0];
const message = `${item[1]}\n`;
const fileName = storageService.joinPath(environment.workingFolder, this.commandPath, `manager_commands.txt`);
await storageService.save(message, fileName, true);
}
}
const end = new Date();
const delayMs = intervalSeconds * 1000 - (end.valueOf() - start.valueOf());
if (delayMs > 0) {
await delay(delayMs);
}
}
}
private async receiveLoop(): Promise<void> {
const intervalSeconds = 2;
const storageService = component.get<StorageService>(StorageService);
while (!this.stopping) {
const start = new Date();
const runnerConnections = [...this.runnerConnections.values()] as FileRunnerConnection[];
for (const runnerConnection of runnerConnections) {
const envCommandFolder = storageService.joinPath(runnerConnection.environment.workingFolder, this.commandPath);
// open new command files
if (runnerConnection.handlers.size < runnerConnection.environment.nodeCount) {
// to find all node commands file
const commandFileNames = await storageService.listDirectory(envCommandFolder);
const toAddedFileNames = [];
for (const commandFileName of commandFileNames) {
if (commandFileName.startsWith("runner_commands") && !runnerConnection.handlers.has(commandFileName)) {
toAddedFileNames.push(commandFileName);
}
}
for (const toAddedFileName of toAddedFileNames) {
const fullPath = storageService.joinPath(envCommandFolder, toAddedFileName);
const fileHandler: FileHandler = new FileHandler(fullPath);
runnerConnection.handlers.set(toAddedFileName, fileHandler);
this.log.debug(`FileCommandChannel: added fileHandler env ${runnerConnection.environment.id} ${toAddedFileName}`);
}
}
// to loop all commands
for (const fileHandler of runnerConnection.handlers.values()) {
const newContent = await storageService.readFileContent(fileHandler.fileName, fileHandler.offset, undefined);
if (newContent.length > 0) {
const commands = newContent.split('\n');
for (const command of commands) {
this.handleCommand(runnerConnection.environment, command);
}
fileHandler.offset += newContent.length;
}
}
}
const end = new Date();
const delayMs = intervalSeconds * 1000 - (end.valueOf() - start.valueOf());
if (delayMs > 0) {
await delay(delayMs);
}
}
}
}
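`FileCommandChannel` above polls each runner's command file and keeps a per-file `offset` so every pass parses only the bytes appended since the previous pass. The same mechanism in isolation, as a sketch with an in-memory stand-in for `StorageService` (both class names here are illustrative, not part of the PR):

```typescript
// In-memory stand-in for the storage service: append-only content,
// readable from an arbitrary offset.
class InMemoryFile {
    private content = "";
    public append(text: string): void { this.content += text; }
    public readFrom(offset: number): string { return this.content.substring(offset); }
}

// Offset-tracking tail reader, as used per command file in receiveLoop():
// each poll returns only the complete lines appended since the last poll,
// then advances the offset past everything it consumed.
class TailReader {
    public offset = 0;
    constructor(private readonly file: InMemoryFile) {}

    public poll(): string[] {
        const newContent = this.file.readFrom(this.offset);
        if (newContent.length === 0) {
            return [];
        }
        this.offset += newContent.length;
        return newContent.split("\n").filter((line) => line.length > 0);
    }
}
```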
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { Server as SocketServer } from "ws";
import { getBasePort, getExperimentId } from "../../../common/experimentStartupInfo";
import { INITIALIZED } from '../../../core/commands';
import { CommandChannel, RunnerConnection } from "../commandChannel";
import { Channel, EnvironmentInformation } from "../environment";
class WebRunnerConnection extends RunnerConnection {
public readonly clients: WebSocket[] = [];
public async close(): Promise<void> {
await super.close();
while (this.clients.length > 0) {
const client = this.clients.shift();
if (client !== undefined) {
client.close();
}
}
}
public AddClient(client: WebSocket): void {
this.clients.push(client);
}
}
export class WebCommandChannel extends CommandChannel {
private readonly expId: string = getExperimentId();
private webSocketServer: SocketServer | undefined;
private clients: Map<WebSocket, WebRunnerConnection | undefined> = new Map<WebSocket, WebRunnerConnection | undefined>();
public get channelName(): Channel {
return "web";
}
public async config(_key: string, _value: any): Promise<void> {
// do nothing
}
public async start(): Promise<void> {
const port = getBasePort() + 1;
this.webSocketServer = new SocketServer({ port });
this.webSocketServer.on('connection', (client: WebSocket) => {
this.log.debug(`WebCommandChannel: received connection`);
client.onerror = (event): void => {
this.log.error(`error on client ${JSON.stringify(event)}`);
}
this.clients.set(client, undefined);
client.onmessage = (message): void => {
this.receivedWebSocketMessage(client, message);
};
}).on('error', (error) => {
this.log.error(`error on websocket server ${error}`);
});
}
public async stop(): Promise<void> {
if (this.webSocketServer !== undefined) {
this.webSocketServer.close();
}
}
protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise<void> {
if (this.webSocketServer === undefined) {
throw new Error(`WebCommandChannel: uninitialized!`)
}
const runnerConnection = this.runnerConnections.get(environment.id) as WebRunnerConnection;
if (runnerConnection !== undefined) {
for (const client of runnerConnection.clients) {
client.send(message);
}
} else {
this.log.warning(`WebCommandChannel: cannot find client for env ${environment.id}, message is ignored.`);
}
}
protected createRunnerConnection(environment: EnvironmentInformation): RunnerConnection {
return new WebRunnerConnection(environment);
}
private receivedWebSocketMessage(client: WebSocket, message: MessageEvent): void {
let connection = this.clients.get(client) as WebRunnerConnection | undefined;
const rawCommands = message.data.toString();
if (connection === undefined) {
// undefined means it's expecting initializing message.
const commands = this.parseCommands(rawCommands);
let isValid = false;
this.log.debug(`WebCommandChannel: received initialize message: ${JSON.stringify(rawCommands)}`);
if (commands.length > 0) {
const commandType = commands[0][0];
const result = commands[0][1];
if (commandType === INITIALIZED &&
result.expId === this.expId &&
this.runnerConnections.has(result.runnerId)
) {
const runnerConnection = this.runnerConnections.get(result.runnerId) as WebRunnerConnection;
this.clients.set(client, runnerConnection);
runnerConnection.AddClient(client);
connection = runnerConnection;
isValid = true;
this.log.debug(`WebCommandChannel: client of env ${runnerConnection.environment.id} initialized`);
} else {
this.log.warning(`WebCommandChannel: client is not initialized, runnerId: ${result.runnerId}, command: ${commandType}, expId: ${this.expId}, exists: ${this.runnerConnections.has(result.runnerId)}`);
}
}
if (!isValid) {
this.log.warning(`WebCommandChannel: rejected client with invalid init message ${rawCommands}`);
client.close();
this.clients.delete(client);
}
}
if (connection !== undefined) {
this.handleCommand(connection.environment, rawCommands);
}
}
}
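`receivedWebSocketMessage` above accepts a socket only if its first message is an `INITIALIZED` command whose experiment id matches this channel and whose runner id maps to a known connection; anything else gets the socket closed. That handshake rule, reduced to a pure function (a hypothetical helper, not part of the PR; the `"IN"` code is an assumed placeholder for the real `INITIALIZED` constant):

```typescript
// Assumed placeholder for the INITIALIZED command code imported from
// core/commands in the real file.
const INITIALIZED = "IN";

interface InitPayload {
    expId: string;
    runnerId: string;
}

// Mirrors the three conditions checked in receivedWebSocketMessage: right
// command type, matching experiment id, and a runner id we already track.
function isValidInit(commandType: string, payload: InitPayload,
                     expId: string, knownRunnerIds: Set<string>): boolean {
    return commandType === INITIALIZED
        && payload.expId === expId
        && knownRunnerIds.has(payload.runnerId);
}
```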
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { EventEmitter } from "events";
import { getLogger, Logger } from "../../common/log";
import { TRIAL_COMMANDS } from "../../core/commands";
import { encodeCommand } from "../../core/ipcInterface";
import { Channel, EnvironmentInformation } from "./environment";
const acceptedCommands: Set<string> = new Set<string>(TRIAL_COMMANDS);
export class Command {
public readonly environment: EnvironmentInformation;
public readonly command: string;
public readonly data: any;
constructor(environment: EnvironmentInformation, command: string, data: any) {
if (!acceptedCommands.has(command)) {
throw new Error(`unaccepted command ${command}`);
}
this.environment = environment;
this.command = command;
this.data = data;
}
}
export class RunnerConnection {
public readonly environment: EnvironmentInformation;
constructor(environment: EnvironmentInformation) {
this.environment = environment;
}
public async open(): Promise<void> {
// do nothing
}
public async close(): Promise<void> {
// do nothing
}
}
export abstract class CommandChannel {
protected readonly log: Logger;
protected runnerConnections: Map<string, RunnerConnection> = new Map<string, RunnerConnection>();
protected readonly commandEmitter: EventEmitter;
private readonly commandPattern: RegExp = /(?<type>[\w]{2})(?<length>[\d]{14})(?<data>.*)\n?/gm;
public constructor(commandEmitter: EventEmitter) {
this.log = getLogger();
this.commandEmitter = commandEmitter;
}
public abstract get channelName(): Channel;
public abstract config(key: string, value: any): Promise<void>;
public abstract start(): Promise<void>;
public abstract stop(): Promise<void>;
protected abstract sendCommandInternal(environment: EnvironmentInformation, message: string): Promise<void>;
protected abstract createRunnerConnection(environment: EnvironmentInformation): RunnerConnection;
    public async sendCommand(environment: EnvironmentInformation, commandType: string, data: any): Promise<void> {
        const command = encodeCommand(commandType, JSON.stringify(data));
this.log.debug(`CommandChannel: env ${environment.id} sending command: ${command}`);
await this.sendCommandInternal(environment, command.toString("utf8"));
}
public async open(environment: EnvironmentInformation): Promise<void> {
if (this.runnerConnections.has(environment.id)) {
throw new Error(`CommandChannel: env ${environment.id} is opened already, shouldn't be opened again.`);
}
const connection = this.createRunnerConnection(environment);
this.runnerConnections.set(environment.id, connection);
await connection.open();
}
public async close(environment: EnvironmentInformation): Promise<void> {
if (this.runnerConnections.has(environment.id)) {
const connection = this.runnerConnections.get(environment.id);
this.runnerConnections.delete(environment.id);
if (connection !== undefined) {
await connection.close();
}
}
}
protected parseCommands(content: string): [string, any][] {
const commands: [string, any][] = [];
let matches = this.commandPattern.exec(content);
while (matches) {
if (undefined !== matches.groups) {
const commandType = matches.groups["type"];
const dataLength = parseInt(matches.groups["length"]);
const data: any = matches.groups["data"];
if (dataLength !== data.length) {
throw new Error(`dataLength ${dataLength} not equal to actual length ${data.length}: ${data}`);
}
try {
const finalData = JSON.parse(data);
// to handle encode('utf8') of Python
commands.push([commandType, finalData]);
} catch (error) {
this.log.error(`CommandChannel: error on parseCommands ${error}, original: ${matches.groups["data"]}`);
throw error;
}
}
matches = this.commandPattern.exec(content);
}
return commands;
}
protected handleCommand(environment: EnvironmentInformation, content: string): void {
const parsedResults = this.parseCommands(content);
for (const parsedResult of parsedResults) {
const commandType = parsedResult[0];
const data = parsedResult[1];
const command = new Command(environment, commandType, data);
this.commandEmitter.emit("command", command);
this.log.trace(`CommandChannel: env ${environment.id} emit command: ${commandType}, ${data}.`);
}
}
}
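The `commandPattern` regex above implies the wire format the channel speaks: a 2-character command type, a 14-digit payload length, then the JSON payload, with commands separated by newlines. A minimal encoder/parser pair matching that regex (a sketch assuming a zero-padded decimal length field; `encode` here is an illustrative stand-in, not NNI's actual `encodeCommand`):

```typescript
// Encode one command as: <2-char type><14-digit zero-padded length><JSON payload>.
function encode(commandType: string, data: any): string {
    const payload = JSON.stringify(data);
    return commandType + payload.length.toString().padStart(14, "0") + payload;
}

// Parse a buffer of newline-separated commands, mirroring parseCommands():
// validate the declared length against the actual payload, then JSON-decode.
function parse(content: string): [string, any][] {
    const pattern = /(?<type>[\w]{2})(?<length>[\d]{14})(?<data>.*)\n?/gm;
    const commands: [string, any][] = [];
    let match = pattern.exec(content);
    while (match !== null && match.groups !== undefined) {
        const declared = parseInt(match.groups["length"], 10);
        const data = match.groups["data"];
        if (declared !== data.length) {
            throw new Error(`declared length ${declared} != actual length ${data.length}: ${data}`);
        }
        commands.push([match.groups["type"], JSON.parse(data)]);
        match = pattern.exec(content);
    }
    return commands;
}
```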