Unverified commit 543239c6, authored by SparkSnail, committed by GitHub

Merge pull request #220 from microsoft/master

merge master
parents 32efaa36 659480f2
@@ -3,7 +3,6 @@
 'use strict';
-// tslint:disable-next-line:no-implicit-dependencies
 import * as request from 'request';
 import { Deferred } from 'ts-deferred';
 import { NNIError, NNIErrorNames } from '../../common/errors';
@@ -16,10 +15,10 @@ import { PAITrialJobDetail } from './paiData';
  * Collector PAI jobs info from PAI cluster, and update pai job status locally
  */
 export class PAIJobInfoCollector {
-    private readonly trialJobsMap : Map<string, PAITrialJobDetail>;
+    private readonly trialJobsMap: Map<string, PAITrialJobDetail>;
     private readonly log: Logger = getLogger();
-    private readonly statusesNeedToCheck : TrialJobStatus[];
-    private readonly finalStatuses : TrialJobStatus[];
+    private readonly statusesNeedToCheck: TrialJobStatus[];
+    private readonly finalStatuses: TrialJobStatus[];
     constructor(jobMap: Map<string, PAITrialJobDetail>) {
         this.trialJobsMap = jobMap;
@@ -27,12 +26,12 @@ export class PAIJobInfoCollector {
         this.finalStatuses = ['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED'];
     }
-    public async retrieveTrialStatus(paiToken? : string, paiClusterConfig?: PAIClusterConfig) : Promise<void> {
+    public async retrieveTrialStatus(paiToken? : string, paiClusterConfig?: PAIClusterConfig): Promise<void> {
         if (paiClusterConfig === undefined || paiToken === undefined) {
             return Promise.resolve();
         }
-        const updatePaiTrialJobs : Promise<void>[] = [];
+        const updatePaiTrialJobs: Promise<void>[] = [];
         for (const [trialJobId, paiTrialJob] of this.trialJobsMap) {
             if (paiTrialJob === undefined) {
                 throw new NNIError(NNIErrorNames.NOT_FOUND, `trial job id ${trialJobId} not found`);
@@ -43,9 +42,8 @@ export class PAIJobInfoCollector {
         await Promise.all(updatePaiTrialJobs);
     }
-    private getSinglePAITrialJobInfo(paiTrialJob : PAITrialJobDetail, paiToken : string, paiClusterConfig: PAIClusterConfig)
-        : Promise<void> {
-        const deferred : Deferred<void> = new Deferred<void>();
+    private getSinglePAITrialJobInfo(paiTrialJob: PAITrialJobDetail, paiToken: string, paiClusterConfig: PAIClusterConfig): Promise<void> {
+        const deferred: Deferred<void> = new Deferred<void>();
         if (!this.statusesNeedToCheck.includes(paiTrialJob.status)) {
             deferred.resolve();
@@ -55,7 +53,6 @@ export class PAIJobInfoCollector {
         // Rest call to get PAI job info and update status
         // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
         const getJobInfoRequest: request.Options = {
-            // tslint:disable-next-line:no-http-string
             uri: `http://${paiClusterConfig.host}/rest-server/api/v1/user/${paiClusterConfig.userName}/jobs/${paiTrialJob.paiJobName}`,
             method: 'GET',
             json: true,
@@ -65,7 +62,6 @@ export class PAIJobInfoCollector {
             }
         };
-        // tslint:disable: no-unsafe-any no-any cyclomatic-complexity
         //TODO : pass in request timeout param?
         request(getJobInfoRequest, (error: Error, response: request.Response, body: any) => {
             if ((error !== undefined && error !== null) || response.statusCode >= 500) {
@@ -129,5 +125,4 @@ export class PAIJobInfoCollector {
         return deferred.promise;
     }
-    // tslint:enable: no-unsafe-any no-any
 }
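Review note: every REST call in this file follows the same ts-deferred pattern — wrap the callback-style `request` in a `Deferred`, resolve it in the callback, and hand `deferred.promise` back so `retrieveTrialStatus` can fan the polls out with `Promise.all`. A minimal standalone sketch of that pattern (host/user/job names are placeholders, not values from this diff):

```typescript
import * as request from 'request';
import { Deferred } from 'ts-deferred';

// Placeholder-name sketch of the deferred-wrapped status poll.
function queryPaiJob(host: string, userName: string, jobName: string, paiToken: string): Promise<any> {
    const deferred: Deferred<any> = new Deferred<any>();
    const options: request.Options = {
        uri: `http://${host}/rest-server/api/v1/user/${userName}/jobs/${jobName}`,
        method: 'GET',
        json: true,
        headers: { Authorization: `Bearer ${paiToken}` }
    };
    request(options, (error: Error, response: request.Response, body: any) => {
        // Resolve (never reject) so one failed poll cannot fail the whole
        // Promise.all() fan-out in retrieveTrialStatus.
        if ((error !== undefined && error !== null) || response.statusCode >= 500) {
            deferred.resolve(undefined);
        } else {
            deferred.resolve(body);
        }
    });
    return deferred.promise;
}
```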
@@ -24,7 +24,7 @@ export class PAIJobRestServer extends ClusterJobRestServer {
     private parameterFileMetaList: ParameterFileMeta[] = [];
     @Inject
-    private readonly paiTrainingService : PAITrainingService;
+    private readonly paiTrainingService: PAITrainingService;
     /**
      * constructor to provide NNIRestServer's own rest property, e.g. port
@@ -34,8 +34,7 @@ export class PAIJobRestServer extends ClusterJobRestServer {
         this.paiTrainingService = component.get(PAITrainingService);
     }
-    // tslint:disable-next-line:no-any
-    protected handleTrialMetrics(jobId : string, metrics : any[]) : void {
+    protected handleTrialMetrics(jobId: string, metrics: any[]): void {
         // Split metrics array into single metric, then emit
         // Warning: If not split metrics into single ones, the behavior will be UNKNOWN
         for (const singleMetric of metrics) {
......
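Review note: the warning above is about payload shape — downstream listeners expect one metric object per event. A minimal sketch of the fan-out, assuming the `{ id, data }` payload shape used by NNI's metric emitter (the helper name is ours):

```typescript
import { EventEmitter } from 'events';

// Emit one event per metric: listeners expect a single metric object, so passing
// the whole array at once is the "UNKNOWN behavior" the comment warns about.
function emitMetricsSingly(emitter: EventEmitter, jobId: string, metrics: any[]): void {
    for (const singleMetric of metrics) {
        emitter.emit('metric', { id: jobId, data: singleMetric });
    }
}
```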
@@ -3,17 +3,14 @@
 'use strict';
-import * as cpp from 'child-process-promise';
 import * as fs from 'fs';
 import * as path from 'path';
-// tslint:disable-next-line:no-implicit-dependencies
 import * as request from 'request';
 import * as component from '../../common/component';
 import { EventEmitter } from 'events';
 import { Deferred } from 'ts-deferred';
 import { String } from 'typescript-string-operations';
-import { MethodNotImplementedError } from '../../common/errors';
 import { getExperimentId } from '../../common/experimentStartupInfo';
 import { getLogger, Logger } from '../../common/log';
 import {
@@ -47,13 +44,12 @@ class PAITrainingService implements TrainingService {
     private paiClusterConfig?: PAIClusterConfig;
     private readonly jobQueue: string[];
     private stopping: boolean = false;
-    // tslint:disable-next-line:no-any
     private hdfsClient: any;
     private paiToken? : string;
     private paiTokenUpdateTime?: number;
     private readonly paiTokenUpdateInterval: number;
-    private readonly experimentId! : string;
-    private readonly paiJobCollector : PAIJobInfoCollector;
+    private readonly experimentId!: string;
+    private readonly paiJobCollector: PAIJobInfoCollector;
     private paiRestServerPort?: number;
     private nniManagerIpConfig?: NNIManagerIpConfig;
     private copyExpCodeDirPromise?: Promise<void>;
@@ -126,7 +122,7 @@ class PAITrainingService implements TrainingService {
         if (this.paiClusterConfig === undefined) {
             throw new Error(`paiClusterConfig not initialized!`);
         }
-        const deferred : Deferred<PAITrialJobDetail> = new Deferred<PAITrialJobDetail>();
+        const deferred: Deferred<PAITrialJobDetail> = new Deferred<PAITrialJobDetail>();
         this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);
@@ -137,7 +133,7 @@ class PAITrainingService implements TrainingService {
         const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
         const hdfsOutputDir: string = unixPathJoin(hdfsCodeDir, 'nnioutput');
-        const hdfsLogPath : string = String.Format(
+        const hdfsLogPath: string = String.Format(
             PAI_LOG_PATH_FORMAT,
             this.paiClusterConfig.host,
             hdfsOutputDir
@@ -173,10 +169,9 @@ class PAITrainingService implements TrainingService {
         return true;
     }
-    // tslint:disable:no-http-string
     public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
-        const trialJobDetail : PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
-        const deferred : Deferred<void> = new Deferred<void>();
+        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
+        const deferred: Deferred<void> = new Deferred<void>();
         if (trialJobDetail === undefined) {
             this.log.error(`cancelTrialJob: trial job id ${trialJobId} not found`);
@@ -205,7 +200,6 @@ class PAITrainingService implements TrainingService {
         // Set trialjobDetail's early stopped field, to mark the job's cancellation source
         trialJobDetail.isEarlyStopped = isEarlyStopped;
-        // tslint:disable-next-line:no-any
         request(stopJobRequest, (error: Error, response: request.Response, body: any) => {
             if ((error !== undefined && error !== null) || response.statusCode >= 400) {
                 this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
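Review note: `cancelTrialJob` stops a trial with a REST call rather than a local kill. A hedged sketch of the request it builds, following the PAI v1 REST API that the file's comments reference; every argument value here is a placeholder:

```typescript
import * as request from 'request';

// Sketch only: endpoint shape follows PAI's v1 REST API (executionType update).
function buildStopJobRequest(host: string, userName: string, paiJobName: string, paiToken: string): request.Options {
    return {
        uri: `http://${host}/rest-server/api/v1/user/${userName}/jobs/${paiJobName}/executionType`,
        method: 'PUT',
        json: true,
        body: { value: 'STOP' },
        headers: {
            'Content-Type': 'application/json',
            Authorization: `Bearer ${paiToken}`
        }
    };
}
```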
@@ -219,10 +213,8 @@ class PAITrainingService implements TrainingService {
         return deferred.promise;
     }
-    // tslint:disable: no-unsafe-any no-any
-    // tslint:disable-next-line:max-func-body-length
     public async setClusterMetadata(key: string, value: string): Promise<void> {
-        const deferred : Deferred<void> = new Deferred<void>();
+        const deferred: Deferred<void> = new Deferred<void>();
         switch (key) {
             case TrialConfigMetadataKey.NNI_MANAGER_IP:
@@ -300,10 +292,9 @@ class PAITrainingService implements TrainingService {
         return deferred.promise;
     }
-    // tslint:enable: no-unsafe-any
     public getClusterMetadata(key: string): Promise<string> {
-        const deferred : Deferred<string> = new Deferred<string>();
+        const deferred: Deferred<string> = new Deferred<string>();
         deferred.resolve();
@@ -314,14 +305,13 @@ class PAITrainingService implements TrainingService {
         this.log.info('Stopping PAI training service...');
         this.stopping = true;
-        const deferred : Deferred<void> = new Deferred<void>();
+        const deferred: Deferred<void> = new Deferred<void>();
         const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
         try {
             await restServer.stop();
             deferred.resolve();
             this.log.info('PAI Training service rest server stopped successfully.');
         } catch (error) {
-            // tslint:disable-next-line: no-unsafe-any
             this.log.error(`PAI Training service rest server stopped failed, error: ${error.message}`);
             deferred.reject(error);
         }
@@ -329,13 +319,12 @@ class PAITrainingService implements TrainingService {
         return deferred.promise;
     }
-    public get MetricsEmitter() : EventEmitter {
+    public get MetricsEmitter(): EventEmitter {
         return this.metricsEmitter;
     }
-    // tslint:disable-next-line:max-func-body-length
     private async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
-        const deferred : Deferred<boolean> = new Deferred<boolean>();
+        const deferred: Deferred<boolean> = new Deferred<boolean>();
         const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
         if (trialJobDetail === undefined) {
@@ -372,7 +361,7 @@ class PAITrainingService implements TrainingService {
         //create tmp trial working folder locally.
         await execMkdir(trialLocalTempFolder);
-        const runScriptContent : string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
+        const runScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
         // Write NNI installation file to local tmp files
         await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
@@ -385,10 +374,9 @@ class PAITrainingService implements TrainingService {
         }
         const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
         const hdfsOutputDir: string = unixPathJoin(hdfsCodeDir, 'nnioutput');
-        // tslint:disable-next-line: strict-boolean-expressions
         const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
         const version: string = this.versionCheck ? await getVersion() : '';
-        const nniPaiTrialCommand : string = String.Format(
+        const nniPaiTrialCommand: string = String.Format(
             PAI_TRIAL_COMMAND_FORMAT,
             // PAI will copy job's codeDir into /root directory
             `$PWD/${trialJobId}`,
@@ -409,9 +397,8 @@ class PAITrainingService implements TrainingService {
         )
         .replace(/\r\n|\n|\r/gm, '');
-        // tslint:disable-next-line:no-console
         this.log.info(`nniPAItrial command is ${nniPaiTrialCommand.trim()}`);
-        const paiTaskRoles : PAITaskRole[] = [
+        const paiTaskRoles: PAITaskRole[] = [
             new PAITaskRole(
                 `nni_trail_${trialJobId}`,
                 // Task role number
@@ -431,7 +418,7 @@ class PAITrainingService implements TrainingService {
             )
         ];
-        const paiJobConfig : PAIJobConfig = new PAIJobConfig(
+        const paiJobConfig: PAIJobConfig = new PAIJobConfig(
             // Job name
             trialJobDetail.paiJobName,
             // Docker image
@@ -451,7 +438,7 @@ class PAITrainingService implements TrainingService {
             await HDFSClientUtility.copyDirectoryToHdfs(trialLocalTempFolder, hdfsCodeDir, this.hdfsClient);
         } catch (error) {
             this.log.error(`PAI Training service: copy ${this.paiTrialConfig.codeDir} to HDFS ${hdfsCodeDir} failed, error is ${error}`);
-            trialJobDetail.status = 'FAILED';
+            trialJobDetail.status = 'FAILED'; // eslint-disable-line require-atomic-updates
             deferred.resolve(true);
             return deferred.promise;
@@ -469,10 +456,9 @@ class PAITrainingService implements TrainingService {
                 Authorization: `Bearer ${this.paiToken}`
             }
         };
-        // tslint:disable:no-any no-unsafe-any
         request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
             if ((error !== undefined && error !== null) || response.statusCode >= 400) {
-                const errorMessage : string = (error !== undefined && error !== null) ? error.message :
+                const errorMessage: string = (error !== undefined && error !== null) ? error.message :
                     `Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${response.body.message}`;
                 trialJobDetail.status = 'FAILED';
                 deferred.resolve(true);
@@ -527,7 +513,7 @@ class PAITrainingService implements TrainingService {
      * Update pai token by the interval time or initialize the pai token
      */
     private async updatePaiToken(): Promise<void> {
-        const deferred : Deferred<void> = new Deferred<void>();
+        const deferred: Deferred<void> = new Deferred<void>();
         const currentTime: number = new Date().getTime();
         //If pai token initialized and not reach the interval time, do not update
@@ -603,7 +589,7 @@ class PAITrainingService implements TrainingService {
     }
     private postParameterFileMeta(parameterFileMeta: ParameterFileMeta): Promise<void> {
-        const deferred : Deferred<void> = new Deferred<void>();
+        const deferred: Deferred<void> = new Deferred<void>();
         const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
         const req: request.Options = {
             uri: `${restServer.endPoint}${restServer.apiRootUrl}/parameter-file-meta`,
......
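Review note: `updatePaiToken` (touched above) re-authenticates against PAI only when a configured interval has elapsed. The gate reduces to a one-line predicate; a standalone sketch assuming `paiTokenUpdateTime` holds a millisecond epoch, as the `new Date().getTime()` comparison above implies:

```typescript
// Illustrative standalone version; in the service these are instance fields.
function shouldRefreshPaiToken(paiTokenUpdateTime: number | undefined, paiTokenUpdateInterval: number): boolean {
    const currentTime: number = new Date().getTime();
    // Refresh when the token was never fetched, or the update interval elapsed.
    return paiTokenUpdateTime === undefined
        || (currentTime - paiTokenUpdateTime) >= paiTokenUpdateInterval;
}
```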
@@ -15,7 +15,7 @@ export class PAITrialConfig extends TrialConfig {
     public readonly dataDir: string;
     public readonly outputDir: string;
-    constructor(command : string, codeDir : string, gpuNum : number, cpuNum: number, memoryMB: number,
+    constructor(command: string, codeDir: string, gpuNum: number, cpuNum: number, memoryMB: number,
                 image: string, dataDir: string, outputDir: string) {
         super(command, codeDir, gpuNum);
         this.cpuNum = cpuNum;
......
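Review note: for readers who have not seen the call site, the retouched constructor is invoked with the trial's command, resources, and image. An illustrative instantiation — every literal below is made up, and the import path assumes the repo layout:

```typescript
import { PAITrialConfig } from './paiConfig';

// Placeholder values only; nothing here comes from this commit.
const trialConfig: PAITrialConfig = new PAITrialConfig(
    'python3 mnist.py',                       // command
    '/home/user/nni/examples/trials/mnist',   // codeDir
    1,                                        // gpuNum
    2,                                        // cpuNum
    8192,                                     // memoryMB
    'msranni/nni:latest',                     // Docker image
    'hdfs://10.0.0.1:9000/data',              // dataDir
    'hdfs://10.0.0.1:9000/output'             // outputDir
);
```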
@@ -5,7 +5,6 @@
 import * as assert from 'assert';
 import { getLogger, Logger } from '../../common/log';
-import { TrialJobDetail } from '../../common/trainingService';
 import { randomSelect } from '../../common/utils';
 import { GPUInfo } from '../common/gpuData';
 import {
@@ -19,7 +18,7 @@ type SCHEDULE_POLICY_NAME = 'random' | 'round-robin';
  */
 export class GPUScheduler {
-    private readonly machineSSHClientMap : Map<RemoteMachineMeta, SSHClientManager>;
+    private readonly machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>;
     private readonly log: Logger = getLogger();
     private readonly policyName: SCHEDULE_POLICY_NAME = 'round-robin';
     private roundRobinIndex: number = 0;
@@ -29,7 +28,7 @@ export class GPUScheduler {
      * Constructor
      * @param machineSSHClientMap map from remote machine to sshClient
      */
-    constructor(machineSSHClientMap : Map<RemoteMachineMeta, SSHClientManager>) {
+    constructor(machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>) {
         assert(machineSSHClientMap.size > 0);
         this.machineSSHClientMap = machineSSHClientMap;
         this.configuredRMs = Array.from(machineSSHClientMap.keys());
@@ -39,7 +38,7 @@ export class GPUScheduler {
      * Schedule a machine according to the constraints (requiredGPUNum)
      * @param requiredGPUNum required GPU number
      */
-    public scheduleMachine(requiredGPUNum: number | undefined, trialJobDetail : RemoteMachineTrialJobDetail) : RemoteMachineScheduleResult {
+    public scheduleMachine(requiredGPUNum: number | undefined, trialJobDetail: RemoteMachineTrialJobDetail): RemoteMachineScheduleResult {
         if(requiredGPUNum === undefined) {
             requiredGPUNum = 0;
         }
@@ -48,7 +47,7 @@ export class GPUScheduler {
         assert(allRMs.length > 0);
         // Step 1: Check if required GPU number not exceeds the total GPU number in all machines
-        const eligibleRM: RemoteMachineMeta[] = allRMs.filter((rmMeta : RemoteMachineMeta) =>
+        const eligibleRM: RemoteMachineMeta[] = allRMs.filter((rmMeta: RemoteMachineMeta) =>
             rmMeta.gpuSummary === undefined || requiredGPUNum === 0 || (requiredGPUNum !== undefined && rmMeta.gpuSummary.gpuCount >= requiredGPUNum));
         if (eligibleRM.length === 0) {
             // If the required gpu number exceeds the upper limit of all machine's GPU number
@@ -134,8 +133,8 @@ export class GPUScheduler {
      * @param availableGPUMap available GPU resource filled by this detection
      * @returns Available GPU number on this remote machine
      */
-    private gpuResourceDetection() : Map<RemoteMachineMeta, GPUInfo[]> {
-        const totalResourceMap : Map<RemoteMachineMeta, GPUInfo[]> = new Map<RemoteMachineMeta, GPUInfo[]>();
+    private gpuResourceDetection(): Map<RemoteMachineMeta, GPUInfo[]> {
+        const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = new Map<RemoteMachineMeta, GPUInfo[]>();
         this.machineSSHClientMap.forEach((sshClientManager: SSHClientManager, rmMeta: RemoteMachineMeta) => {
             // Assgin totoal GPU count as init available GPU number
             if (rmMeta.gpuSummary !== undefined) {
@@ -149,7 +148,6 @@ export class GPUScheduler {
                     }
                 }
                 this.log.debug(`designated gpu indices: ${designatedGpuIndices}`);
-                // tslint:disable: strict-boolean-expressions
                 rmMeta.gpuSummary.gpuInfos.forEach((gpuInfo: GPUInfo) => {
                     // if the GPU has active process, OR be reserved by a job,
                     // or index not in gpuIndices configuration in machineList,
@@ -175,7 +173,6 @@ export class GPUScheduler {
         return totalResourceMap;
     }
-    // tslint:enable: strict-boolean-expressions
     private selectMachine(rmMetas: RemoteMachineMeta[]): RemoteMachineMeta {
         assert(rmMetas !== undefined && rmMetas.length > 0);
@@ -224,11 +221,11 @@ export class GPUScheduler {
             resultType: ScheduleResultType.SUCCEED,
             scheduleInfo: {
                 rmMeta: rmMeta,
-                cuda_visible_device: allocatedGPUs
+                cudaVisibleDevice: allocatedGPUs
                     .map((gpuInfo: GPUInfo) => {
                         return gpuInfo.index;
                     })
                     .join(',')
             }
         };
     }
......
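Review note: the `cuda_visible_device` → `cudaVisibleDevice` rename above is cosmetic; the value is still the comma-joined indices of the allocated GPUs. A self-contained sketch with a minimal stand-in for the repo's `GPUInfo`:

```typescript
// Minimal stand-in for ../common/gpuData's GPUInfo; only `index` matters here.
interface GPUInfoLike {
    index: number;
}

// e.g. toCudaVisibleDevice([{ index: 0 }, { index: 2 }]) === '0,2'
function toCudaVisibleDevice(allocatedGPUs: GPUInfoLike[]): string {
    return allocatedGPUs
        .map((gpuInfo: GPUInfoLike) => gpuInfo.index)
        .join(',');
}
```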
@@ -13,13 +13,13 @@ import { GPUInfo, GPUSummary } from '../common/gpuData';
  * Metadata of remote machine for configuration and statuc query
  */
 export class RemoteMachineMeta {
-    public readonly ip : string = '';
-    public readonly port : number = 22;
-    public readonly username : string = '';
+    public readonly ip: string = '';
+    public readonly port: number = 22;
+    public readonly username: string = '';
     public readonly passwd: string = '';
     public readonly sshKeyPath?: string;
     public readonly passphrase?: string;
-    public gpuSummary : GPUSummary | undefined;
+    public gpuSummary: GPUSummary | undefined;
     public readonly gpuIndices?: string;
     public readonly maxTrialNumPerGpu?: number;
     //TODO: initialize varialbe in constructor
@@ -43,11 +43,11 @@ export function parseGpuIndices(gpuIndices?: string): Set<number> | undefined {
  * The execution result for command executed on remote machine
  */
 export class RemoteCommandResult {
-    public readonly stdout : string;
-    public readonly stderr : string;
-    public readonly exitCode : number;
-    constructor(stdout : string, stderr : string, exitCode : number) {
+    public readonly stdout: string;
+    public readonly stderr: string;
+    public readonly exitCode: number;
+    constructor(stdout: string, stderr: string, exitCode: number) {
         this.stdout = stdout;
         this.stderr = stderr;
         this.exitCode = exitCode;
@@ -186,7 +186,6 @@ export class SSHClientManager {
     /**
      * Create a new ssh connection client and initialize it
      */
-    // tslint:disable:non-literal-fs-path
     private initNewSSHClient(): Promise<Client> {
         const deferred: Deferred<Client> = new Deferred<Client>();
         const conn: Client = new Client();
@@ -225,9 +224,9 @@ export class SSHClientManager {
     }
 }
-export type RemoteMachineScheduleResult = { scheduleInfo : RemoteMachineScheduleInfo | undefined; resultType : ScheduleResultType};
-export type RemoteMachineScheduleInfo = { rmMeta : RemoteMachineMeta; cuda_visible_device : string};
+export type RemoteMachineScheduleResult = { scheduleInfo: RemoteMachineScheduleInfo | undefined; resultType: ScheduleResultType};
+export type RemoteMachineScheduleInfo = { rmMeta: RemoteMachineMeta; cudaVisibleDevice: string};
 export enum ScheduleResultType {
     // Schedule succeeded
......
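Review note: consumers of the renamed schedule types switch on `resultType` and read `scheduleInfo` only on success. A sketch with simplified local stand-ins (the real `RemoteMachineMeta` and enum in remoteMachineData.ts are richer than the placeholders here):

```typescript
// Simplified stand-ins; the diff's real types live in remoteMachineData.ts.
enum ScheduleResultType { SUCCEED, TMP_NO_AVAILABLE_GPU, REQUIRE_EXCEED_TOTAL }
type ScheduleInfoLike = { rmMeta: { ip: string }; cudaVisibleDevice: string };
type ScheduleResultLike = { scheduleInfo: ScheduleInfoLike | undefined; resultType: ScheduleResultType };

function describeSchedule(result: ScheduleResultLike): string {
    if (result.resultType === ScheduleResultType.SUCCEED && result.scheduleInfo !== undefined) {
        return `run on ${result.scheduleInfo.rmMeta.ip} with CUDA_VISIBLE_DEVICES=${result.scheduleInfo.cudaVisibleDevice}`;
    }
    return 'no machine allocated';
}
```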
@@ -15,7 +15,7 @@ import { RemoteMachineTrainingService } from './remoteMachineTrainingService';
 @component.Singleton
 export class RemoteMachineJobRestServer extends ClusterJobRestServer {
     @Inject
-    private readonly remoteMachineTrainingService : RemoteMachineTrainingService;
+    private readonly remoteMachineTrainingService: RemoteMachineTrainingService;
     /**
      * constructor to provide NNIRestServer's own rest property, e.g. port
@@ -25,8 +25,7 @@ export class RemoteMachineJobRestServer extends ClusterJobRestServer {
         this.remoteMachineTrainingService = component.get(RemoteMachineTrainingService);
     }
-    // tslint:disable-next-line:no-any
-    protected handleTrialMetrics(jobId : string, metrics : any[]) : void {
+    protected handleTrialMetrics(jobId: string, metrics: any[]): void {
         // Split metrics array into single metric, then emit
         // Warning: If not split metrics into single ones, the behavior will be UNKNOWNls
         for (const singleMetric of metrics) {
......
@@ -4,12 +4,10 @@
 'use strict';
 import * as assert from 'assert';
-import * as cpp from 'child-process-promise';
 import { EventEmitter } from 'events';
 import * as fs from 'fs';
-import * as os from 'os';
 import * as path from 'path';
-import { Client, ConnectConfig } from 'ssh2';
+import { Client } from 'ssh2';
 import { Deferred } from 'ts-deferred';
 import { String } from 'typescript-string-operations';
 import * as component from '../../common/component';
@@ -29,12 +27,12 @@ import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
 import { GPUSummary } from '../common/gpuData';
 import { TrialConfig } from '../common/trialConfig';
 import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
-import { execCopydir, execMkdir, execRemove, validateCodeDir, getGpuMetricsCollectorBashScriptContent } from '../common/util';
+import { execCopydir, execMkdir, validateCodeDir, getGpuMetricsCollectorBashScriptContent } from '../common/util';
 import { GPUScheduler } from './gpuScheduler';
 import {
-    HOST_JOB_SHELL_FORMAT, RemoteCommandResult, REMOTEMACHINE_TRIAL_COMMAND_FORMAT, RemoteMachineMeta,
+    RemoteCommandResult, REMOTEMACHINE_TRIAL_COMMAND_FORMAT, RemoteMachineMeta,
     RemoteMachineScheduleInfo, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail,
-    ScheduleResultType, SSHClient, SSHClientManager
+    ScheduleResultType, SSHClientManager
 } from './remoteMachineData';
 import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer';
 import { SSHClientUtility } from './sshClientUtility';
@@ -93,7 +91,7 @@ class RemoteMachineTrainingService implements TrainingService {
         while (this.jobQueue.length > 0) {
             this.updateGpuReservation();
             const trialJobId: string = this.jobQueue[0];
-            const prepareResult : boolean = await this.prepareTrialJob(trialJobId);
+            const prepareResult: boolean = await this.prepareTrialJob(trialJobId);
             if (prepareResult) {
                 // Remove trial job with trialJobId from job queue
                 this.jobQueue.shift();
@@ -208,7 +206,6 @@ class RemoteMachineTrainingService implements TrainingService {
      * Submit trial job
      * @param form trial job description form
      */
-    // tslint:disable-next-line:informative-docs
     public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
         if (this.trialConfig === undefined) {
             throw new Error('trial config is not initialized');
@@ -241,12 +238,7 @@ class RemoteMachineTrainingService implements TrainingService {
         if (trialJobDetail === undefined) {
             throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
         }
-        const rmMeta: RemoteMachineMeta | undefined = (<RemoteMachineTrialJobDetail>trialJobDetail).rmMeta;
-        if (rmMeta !== undefined) {
-            await this.writeParameterFile(trialJobId, form.hyperParameters, rmMeta);
-        } else {
-            throw new Error(`updateTrialJob failed: ${trialJobId} rmMeta not found`);
-        }
+        await this.writeParameterFile(trialJobId, form.hyperParameters);
         return trialJobDetail;
     }
@@ -262,7 +254,6 @@ class RemoteMachineTrainingService implements TrainingService {
      * Cancel trial job
      * @param trialJobId ID of trial job
      */
-    // tslint:disable:informative-docs no-unsafe-any
     public async cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
         const deferred: Deferred<void> = new Deferred<void>();
         const trialJob: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
@@ -272,7 +263,7 @@ class RemoteMachineTrainingService implements TrainingService {
         }
         // Remove the job with trialJobId from job queue
-        const index : number = this.jobQueue.indexOf(trialJobId);
+        const index: number = this.jobQueue.indexOf(trialJobId);
         if (index >= 0) {
             this.jobQueue.splice(index, 1);
         }
@@ -319,14 +310,13 @@ class RemoteMachineTrainingService implements TrainingService {
                 await this.setupConnections(value);
                 this.gpuScheduler = new GPUScheduler(this.machineSSHClientMap);
                 break;
-            case TrialConfigMetadataKey.TRIAL_CONFIG:
+            case TrialConfigMetadataKey.TRIAL_CONFIG: {
                 const remoteMachineTrailConfig: TrialConfig = <TrialConfig>JSON.parse(value);
                 // Parse trial config failed, throw Error
                 if (remoteMachineTrailConfig === undefined) {
                     throw new Error('trial config parsed failed');
                 }
                 // codeDir is not a valid directory, throw Error
-                // tslint:disable-next-line:non-literal-fs-path
                 if (!fs.lstatSync(remoteMachineTrailConfig.codeDir)
                     .isDirectory()) {
                     throw new Error(`codeDir ${remoteMachineTrailConfig.codeDir} is not a directory`);
@@ -343,6 +333,7 @@ class RemoteMachineTrainingService implements TrainingService {
                 this.trialConfig = remoteMachineTrailConfig;
                 break;
+            }
             case TrialConfigMetadataKey.MULTI_PHASE:
                 this.isMultiPhase = (value === 'true' || value === 'True');
                 break;
@@ -444,7 +435,6 @@ class RemoteMachineTrainingService implements TrainingService {
         await SSHClientUtility.remoteExeCommand(`chmod 777 ${nniRootDir} ${nniRootDir}/* ${nniRootDir}/scripts/*`, conn);
         //Begin to execute gpu_metrics_collection scripts
-        // tslint:disable-next-line: no-floating-promises
         const script = getGpuMetricsCollectorBashScriptContent(remoteGpuScriptCollectorDir);
         SSHClientUtility.remoteExeCommand(`bash -c '${script}'`, conn);
@@ -464,7 +454,7 @@ class RemoteMachineTrainingService implements TrainingService {
     }
     private async prepareTrialJob(trialJobId: string): Promise<boolean> {
-        const deferred : Deferred<boolean> = new Deferred<boolean>();
+        const deferred: Deferred<boolean> = new Deferred<boolean>();
         if (this.trialConfig === undefined) {
             throw new Error('trial config is not initialized');
@@ -485,13 +475,13 @@ class RemoteMachineTrainingService implements TrainingService {
         // get an ssh client from scheduler
         const rmScheduleResult: RemoteMachineScheduleResult = this.gpuScheduler.scheduleMachine(this.trialConfig.gpuNum, trialJobDetail);
         if (rmScheduleResult.resultType === ScheduleResultType.REQUIRE_EXCEED_TOTAL) {
-            const errorMessage : string = `Required GPU number ${this.trialConfig.gpuNum} is too large, no machine can meet`;
+            const errorMessage: string = `Required GPU number ${this.trialConfig.gpuNum} is too large, no machine can meet`;
             this.log.error(errorMessage);
             deferred.reject();
             throw new NNIError(NNIErrorNames.RESOURCE_NOT_AVAILABLE, errorMessage);
         } else if (rmScheduleResult.resultType === ScheduleResultType.SUCCEED
             && rmScheduleResult.scheduleInfo !== undefined) {
-            const rmScheduleInfo : RemoteMachineScheduleInfo = rmScheduleResult.scheduleInfo;
+            const rmScheduleInfo: RemoteMachineScheduleInfo = rmScheduleResult.scheduleInfo;
             const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId);
             trialJobDetail.rmMeta = rmScheduleInfo.rmMeta;
@@ -521,7 +511,7 @@ class RemoteMachineTrainingService implements TrainingService {
         if (this.trialConfig === undefined) {
             throw new Error('trial config is not initialized');
         }
-        const cuda_visible_device: string = rmScheduleInfo.cuda_visible_device;
+        const cudaVisibleDevice: string = rmScheduleInfo.cudaVisibleDevice;
         const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJobId);
         if (sshClient === undefined) {
             assert(false, 'sshClient is undefined.');
@@ -543,19 +533,18 @@ class RemoteMachineTrainingService implements TrainingService {
         // See definition in remoteMachineData.ts
         let command: string;
-        // Set CUDA_VISIBLE_DEVICES environment variable based on cuda_visible_device
-        // If no valid cuda_visible_device is defined, set CUDA_VISIBLE_DEVICES to empty string to hide GPU device
+        // Set CUDA_VISIBLE_DEVICES environment variable based on cudaVisibleDevice
+        // If no valid cudaVisibleDevice is defined, set CUDA_VISIBLE_DEVICES to empty string to hide GPU device
         // If gpuNum is undefined, will not set CUDA_VISIBLE_DEVICES in script
         if (this.trialConfig.gpuNum === undefined) {
             command = this.trialConfig.command;
         } else {
-            if (typeof cuda_visible_device === 'string' && cuda_visible_device.length > 0) {
-                command = `CUDA_VISIBLE_DEVICES=${cuda_visible_device} ${this.trialConfig.command}`;
+            if (typeof cudaVisibleDevice === 'string' && cudaVisibleDevice.length > 0) {
+                command = `CUDA_VISIBLE_DEVICES=${cudaVisibleDevice} ${this.trialConfig.command}`;
             } else {
                 command = `CUDA_VISIBLE_DEVICES=" " ${this.trialConfig.command}`;
             }
         }
-        // tslint:disable-next-line: strict-boolean-expressions
         const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
         if (this.remoteRestServerPort === undefined) {
             const restServer: RemoteMachineJobRestServer = component.get(RemoteMachineJobRestServer);
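Review note: the CUDA_VISIBLE_DEVICES branching above is easier to scan as a pure function. The wrapper name below is ours; the three branches are exactly the diff's:

```typescript
// A pure-function restatement of the branch above (illustrative wrapper name).
function buildTrialCommand(baseCommand: string, gpuNum: number | undefined, cudaVisibleDevice: string): string {
    if (gpuNum === undefined) {
        // gpuNum unset: do not touch CUDA_VISIBLE_DEVICES at all
        return baseCommand;
    }
    if (cudaVisibleDevice.length > 0) {
        // expose exactly the scheduled devices, e.g. "0,2"
        return `CUDA_VISIBLE_DEVICES=${cudaVisibleDevice} ${baseCommand}`;
    }
    // a blank value hides every GPU from the trial
    return `CUDA_VISIBLE_DEVICES=" " ${baseCommand}`;
}
```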
@@ -584,16 +573,15 @@ class RemoteMachineTrainingService implements TrainingService {
         //create tmp trial working folder locally.
         await execCopydir(this.trialConfig.codeDir, trialLocalTempFolder);
-        const installScriptContent : string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
+        const installScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
         // Write NNI installation file to local tmp files
         await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), installScriptContent, { encoding: 'utf8' });
         // Write file content ( run.sh and parameter.cfg ) to local tmp files
         await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run.sh'), runScriptTrialContent, { encoding: 'utf8' });
-        await this.writeParameterFile(trialJobId, form.hyperParameters, rmScheduleInfo.rmMeta);
+        await this.writeParameterFile(trialJobId, form.hyperParameters);
         // Copy files in codeDir to remote working directory
         await SSHClientUtility.copyDirectoryToRemote(trialLocalTempFolder, trialWorkingFolder, sshClient, this.remoteOS);
         // Execute command in remote machine
-        // tslint:disable-next-line: no-floating-promises
         SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(trialWorkingFolder, 'run.sh')}`, sshClient);
     }
@@ -610,6 +598,7 @@ class RemoteMachineTrainingService implements TrainingService {
         const deferred: Deferred<TrialJobDetail> = new Deferred<TrialJobDetail>();
         const jobpidPath: string = this.getJobPidPath(trialJob.id);
         const trialReturnCodeFilePath: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJob.id, '.nni', 'code');
+        /* eslint-disable require-atomic-updates */
         try {
             const killResult: number = (await SSHClientUtility.remoteExeCommand(`kill -0 \`cat ${jobpidPath}\``, sshClient)).exitCode;
             // if the process of jobpid is not alive any more
@@ -646,7 +635,7 @@ class RemoteMachineTrainingService implements TrainingService {
                 deferred.resolve(trialJob);
             }
         }
+        /* eslint-enable require-atomic-updates */
         return deferred.promise;
     }
@@ -662,7 +651,7 @@ class RemoteMachineTrainingService implements TrainingService {
         return unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni', 'experiments', getExperimentId());
     }
-    public get MetricsEmitter() : EventEmitter {
+    public get MetricsEmitter(): EventEmitter {
         return this.metricsEmitter;
     }
@@ -672,13 +661,10 @@ class RemoteMachineTrainingService implements TrainingService {
             throw new NNIError(NNIErrorNames.INVALID_JOB_DETAIL, `Invalid job detail information for trial job ${jobId}`);
         }
-        let jobpidPath: string;
-        jobpidPath = unixPathJoin(trialJobDetail.workingDirectory, '.nni', 'jobpid');
-        return jobpidPath;
+        return unixPathJoin(trialJobDetail.workingDirectory, '.nni', 'jobpid');
     }
-    private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters, rmMeta: RemoteMachineMeta): Promise<void> {
+    private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> {
         const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJobId);
         if (sshClient === undefined) {
             throw new Error('sshClient is undefined.');
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
'use strict'; 'use strict';
import * as assert from 'assert'; import * as assert from 'assert';
import * as cpp from 'child-process-promise';
import * as os from 'os'; import * as os from 'os';
import * as path from 'path'; import * as path from 'path';
import { Client, ClientChannel, SFTPWrapper } from 'ssh2'; import { Client, ClientChannel, SFTPWrapper } from 'ssh2';
...@@ -22,44 +21,18 @@ import { RemoteCommandResult } from './remoteMachineData'; ...@@ -22,44 +21,18 @@ import { RemoteCommandResult } from './remoteMachineData';
* *
*/ */
export namespace SSHClientUtility { export namespace SSHClientUtility {
/**
* Copy files and directories in local directory recursively to remote directory
* @param localDirectory local diretory
* @param remoteDirectory remote directory
* @param sshClient SSH client
*/
export async function copyDirectoryToRemote(localDirectory : string, remoteDirectory : string, sshClient : Client, remoteOS: string)
: Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
const tmpTarName: string = `${uniqueString(10)}.tar.gz`;
const localTarPath: string = path.join(os.tmpdir(), tmpTarName);
const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), tmpTarName);
// Compress files in local directory to experiment root directory
await tarAdd(localTarPath, localDirectory);
// Copy the compressed file to remoteDirectory and delete it
await copyFileToRemote(localTarPath, remoteTarPath, sshClient);
await execRemove(localTarPath);
// Decompress the remote compressed file in and delete it
await remoteExeCommand(`tar -oxzf ${remoteTarPath} -C ${remoteDirectory}`, sshClient);
await remoteExeCommand(`rm ${remoteTarPath}`, sshClient);
deferred.resolve();
return deferred.promise;
}
/** /**
* Copy local file to remote path * Copy local file to remote path
* @param localFilePath the path of local file * @param localFilePath the path of local file
* @param remoteFilePath the target path in remote machine * @param remoteFilePath the target path in remote machine
* @param sshClient SSH Client * @param sshClient SSH Client
*/ */
export function copyFileToRemote(localFilePath : string, remoteFilePath : string, sshClient : Client) : Promise<boolean> { export function copyFileToRemote(localFilePath: string, remoteFilePath: string, sshClient: Client): Promise<boolean> {
const log: Logger = getLogger(); const log: Logger = getLogger();
log.debug(`copyFileToRemote: localFilePath: ${localFilePath}, remoteFilePath: ${remoteFilePath}`); log.debug(`copyFileToRemote: localFilePath: ${localFilePath}, remoteFilePath: ${remoteFilePath}`);
assert(sshClient !== undefined); assert(sshClient !== undefined);
const deferred: Deferred<boolean> = new Deferred<boolean>(); const deferred: Deferred<boolean> = new Deferred<boolean>();
sshClient.sftp((err : Error, sftp : SFTPWrapper) => { sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
if (err !== undefined && err !== null) { if (err !== undefined && err !== null) {
log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`); log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`);
deferred.reject(err); deferred.reject(err);
...@@ -67,7 +40,7 @@ export namespace SSHClientUtility { ...@@ -67,7 +40,7 @@ export namespace SSHClientUtility {
return; return;
} }
assert(sftp !== undefined); assert(sftp !== undefined);
sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr : Error) => { sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr: Error) => {
sftp.end(); sftp.end();
if (fastPutErr !== undefined && fastPutErr !== null) { if (fastPutErr !== undefined && fastPutErr !== null) {
deferred.reject(fastPutErr); deferred.reject(fastPutErr);
...@@ -85,16 +58,15 @@ export namespace SSHClientUtility { ...@@ -85,16 +58,15 @@ export namespace SSHClientUtility {
* @param command the command to execute remotely * @param command the command to execute remotely
* @param client SSH Client * @param client SSH Client
*/ */
// tslint:disable:no-unsafe-any no-any export function remoteExeCommand(command: string, client: Client): Promise<RemoteCommandResult> {
export function remoteExeCommand(command : string, client : Client): Promise<RemoteCommandResult> {
    const log: Logger = getLogger();
    log.debug(`remoteExeCommand: command: [${command}]`);
    const deferred: Deferred<RemoteCommandResult> = new Deferred<RemoteCommandResult>();
    let stdout: string = '';
    let stderr: string = '';
    let exitCode: number;

    client.exec(command, (err: Error, channel: ClientChannel) => {
        if (err !== undefined && err !== null) {
            log.error(`remoteExeCommand: ${err.message}`);
            deferred.reject(err);
...@@ -102,14 +74,14 @@ export namespace SSHClientUtility {
            return;
        }
        channel.on('data', (data: any, dataStderr: any) => {
            if (dataStderr !== undefined && dataStderr !== null) {
                stderr += data.toString();
            } else {
                stdout += data.toString();
            }
        })
        .on('exit', (code: any, signal: any) => {
            exitCode = <number>code;
            deferred.resolve({
                stdout : stdout,
...@@ -122,9 +94,34 @@ export namespace SSHClientUtility {
    return deferred.promise;
}
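The same run-the-command-and-collect pattern, sketched in Python with paramiko for illustration (paramiko is an assumption here, not a dependency of this code; connection setup omitted)::

    import paramiko

    def remote_exe_command(client: paramiko.SSHClient, command: str) -> dict:
        # exec_command returns file-like stdin/stdout/stderr for the channel
        _, stdout, stderr = client.exec_command(command)
        exit_code = stdout.channel.recv_exit_status()  # blocks until the remote command exits
        return {'stdout': stdout.read().decode(),
                'stderr': stderr.read().decode(),
                'exitCode': exit_code}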
/**
* Copy files and directories in local directory recursively to remote directory
* @param localDirectory local directory
* @param remoteDirectory remote directory
* @param sshClient SSH client
*/
export async function copyDirectoryToRemote(localDirectory: string, remoteDirectory: string, sshClient: Client, remoteOS: string): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
const tmpTarName: string = `${uniqueString(10)}.tar.gz`;
const localTarPath: string = path.join(os.tmpdir(), tmpTarName);
const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), tmpTarName);
// Compress files in the local directory into a tarball under the OS temp directory
await tarAdd(localTarPath, localDirectory);
// Copy the compressed file to remoteDirectory and delete it
await copyFileToRemote(localTarPath, remoteTarPath, sshClient);
await execRemove(localTarPath);
// Decompress the remote compressed file and delete it
await remoteExeCommand(`tar -oxzf ${remoteTarPath} -C ${remoteDirectory}`, sshClient);
await remoteExeCommand(`rm ${remoteTarPath}`, sshClient);
deferred.resolve();
return deferred.promise;
}
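The compress-upload-extract pattern above avoids one SFTP round trip per file. A minimal Python sketch of the same idea, assuming paramiko and an already-connected client (all names are illustrative)::

    import os
    import tarfile
    import tempfile
    import uuid

    def copy_directory_to_remote(client, local_dir, remote_dir, remote_tmp='/tmp'):
        tar_name = uuid.uuid4().hex[:10] + '.tar.gz'
        local_tar = os.path.join(tempfile.gettempdir(), tar_name)
        remote_tar = remote_tmp + '/' + tar_name
        # compress the whole local directory into one temporary tarball
        with tarfile.open(local_tar, 'w:gz') as tar:
            tar.add(local_dir, arcname='.')
        sftp = client.open_sftp()
        sftp.put(local_tar, remote_tar)   # a single upload instead of per-file copies
        sftp.close()
        os.remove(local_tar)
        # extract on the remote side, waiting for tar to finish before cleanup
        _, out, _ = client.exec_command('tar -oxzf %s -C %s' % (remote_tar, remote_dir))
        out.channel.recv_exit_status()
        client.exec_command('rm ' + remote_tar)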
export function getRemoteFileContent(filePath: string, sshClient: Client): Promise<string> {
    const deferred: Deferred<string> = new Deferred<string>();
    sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
        if (err !== undefined && err !== null) {
            getLogger()
                .error(`getRemoteFileContent: ${err.message}`);
...@@ -133,10 +130,10 @@ export namespace SSHClientUtility {
            return;
        }
        try {
            const sftpStream: stream.Readable = sftp.createReadStream(filePath);
            let dataBuffer: string = '';
            sftpStream.on('data', (data: Buffer | string) => {
                dataBuffer += data;
            })
            .on('error', (streamErr: Error) => {
...@@ -158,5 +155,4 @@ export namespace SSHClientUtility {
    return deferred.promise;
}
}
...@@ -703,7 +703,7 @@ buffer-stream-reader@^0.1.1:
version "0.1.1"
resolved "https://registry.yarnpkg.com/buffer-stream-reader/-/buffer-stream-reader-0.1.1.tgz#ca8bf93631deedd8b8f8c3bb44991cc30951e259"
builtin-modules@^1.0.0:
version "1.1.1"
resolved "https://registry.yarnpkg.com/builtin-modules/-/builtin-modules-1.1.1.tgz#270f076c5a72c02f5b65a47df94c5fe3a278892f"
...@@ -841,7 +841,7 @@ chalk@^1.0.0:
strip-ansi "^3.0.0"
supports-color "^2.0.0"
chalk@^2.0.0:
version "2.4.1"
resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.1.tgz#18c49ab16a037b6eb0152cc83e3471338215b66e"
dependencies:
...@@ -971,10 +971,6 @@ commander@2.15.1:
version "2.15.1"
resolved "https://registry.yarnpkg.com/commander/-/commander-2.15.1.tgz#df46e867d0fc2aec66a34662b406a9ccafff5b0f"
commander@^2.12.1:
version "2.16.0"
resolved "https://registry.yarnpkg.com/commander/-/commander-2.16.0.tgz#f16390593996ceb4f3eeb020b31d78528f7f8a50"
commander@~2.17.1:
version "2.17.1"
resolved "https://registry.yarnpkg.com/commander/-/commander-2.17.1.tgz#bd77ab7de6de94205ceacc72f1716d29f20a77bf"
...@@ -1134,7 +1130,7 @@ debug@^4.0.1, debug@^4.1.0, debug@^4.1.1:
dependencies:
ms "^2.1.1"
debuglog@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/debuglog/-/debuglog-1.0.1.tgz#aa24ffb9ac3df9a2351837cfb2d279360cd78492"
...@@ -1217,7 +1213,7 @@ dezalgo@^1.0.0, dezalgo@~1.0.3:
asap "^2.0.0"
wrappy "1"
diff@3.5.0, diff@^3.1.0:
version "3.5.0"
resolved "https://registry.yarnpkg.com/diff/-/diff-3.5.0.tgz#800c0dd1e0a8bfbc95835c202ad220fe317e5a12"
...@@ -2080,7 +2076,7 @@ import-lazy@^2.1.0:
version "2.1.0"
resolved "https://registry.yarnpkg.com/import-lazy/-/import-lazy-2.1.0.tgz#05698e3d45c88e8d7e9d92cb0584e77f096f3e43"
imurmurhash@^0.1.4:
version "0.1.4"
resolved "https://registry.yarnpkg.com/imurmurhash/-/imurmurhash-0.1.4.tgz#9218b9b2b928a238b13dc4fb6b6d576f231453ea"
...@@ -2519,10 +2515,6 @@ lockfile@~1.0.3:
dependencies:
signal-exit "^3.0.2"
lodash._baseindexof@*:
version "3.1.0"
resolved "https://registry.yarnpkg.com/lodash._baseindexof/-/lodash._baseindexof-3.1.0.tgz#fe52b53a1c6761e42618d654e4a25789ed61822c"
lodash._baseuniq@~4.6.0:
version "4.6.0"
resolved "https://registry.yarnpkg.com/lodash._baseuniq/-/lodash._baseuniq-4.6.0.tgz#0ebb44e456814af7905c6212fa2c9b2d51b841e8"
...@@ -2530,28 +2522,10 @@ lodash._baseuniq@~4.6.0:
lodash._createset "~4.0.0"
lodash._root "~3.0.0"
lodash._bindcallback@*:
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._bindcallback/-/lodash._bindcallback-3.0.1.tgz#e531c27644cf8b57a99e17ed95b35c748789392e"
lodash._cacheindexof@*:
version "3.0.2"
resolved "https://registry.yarnpkg.com/lodash._cacheindexof/-/lodash._cacheindexof-3.0.2.tgz#3dc69ac82498d2ee5e3ce56091bafd2adc7bde92"
lodash._createcache@*:
version "3.1.2"
resolved "https://registry.yarnpkg.com/lodash._createcache/-/lodash._createcache-3.1.2.tgz#56d6a064017625e79ebca6b8018e17440bdcf093"
dependencies:
lodash._getnative "^3.0.0"
lodash._createset@~4.0.0:
version "4.0.3"
resolved "https://registry.yarnpkg.com/lodash._createset/-/lodash._createset-4.0.3.tgz#0f4659fbb09d75194fa9e2b88a6644d363c9fe26"
lodash._getnative@*, lodash._getnative@^3.0.0:
version "3.9.1"
resolved "https://registry.yarnpkg.com/lodash._getnative/-/lodash._getnative-3.9.1.tgz#570bc7dede46d61cdcde687d65d3eecbaa3aaff5"
lodash._root@~3.0.0:
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._root/-/lodash._root-3.0.1.tgz#fba1c4524c19ee9a5f8136b4609f017cf4ded692"
...@@ -2600,10 +2574,6 @@ lodash.pick@^4.4.0:
version "4.4.0"
resolved "https://registry.yarnpkg.com/lodash.pick/-/lodash.pick-4.4.0.tgz#52f05610fff9ded422611441ed1fc123a03001b3"
lodash.restparam@*:
version "3.6.1"
resolved "https://registry.yarnpkg.com/lodash.restparam/-/lodash.restparam-3.6.1.tgz#936a4e309ef330a7645ed4145986c85ae5b20805"
lodash.unescape@4.0.1:
version "4.0.1"
resolved "https://registry.yarnpkg.com/lodash.unescape/-/lodash.unescape-4.0.1.tgz#bf2249886ce514cda112fae9218cdc065211fc9c"
...@@ -3519,10 +3489,6 @@ path-key@^2.0.0, path-key@^2.0.1:
version "2.0.1"
resolved "https://registry.yarnpkg.com/path-key/-/path-key-2.0.1.tgz#411cadb574c5a140d3a4b1910d40d80cc9f40b40"
path-parse@^1.0.5:
version "1.0.5"
resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.5.tgz#3c1adf871ea9cd6c9431b6ea2bd74a0ff055c4c1"
path-parse@^1.0.6:
version "1.0.6"
resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.6.tgz#d62dbb5679405d72c4737ec58600e9ddcf06d24c"
...@@ -3834,7 +3800,7 @@ readable-stream@~2.0.0:
string_decoder "~0.10.x"
util-deprecate "~1.0.1"
readdir-scoped-modules@^1.0.0:
version "1.1.0"
resolved "https://registry.yarnpkg.com/readdir-scoped-modules/-/readdir-scoped-modules-1.1.0.tgz#8d45407b4f870a0dcaebc0e28670d18e74514309"
dependencies:
...@@ -3977,12 +3943,6 @@ resolve@^1.10.0:
dependencies:
path-parse "^1.0.6"
resolve@^1.3.2:
version "1.8.1"
resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.8.1.tgz#82f1ec19a423ac1fbd080b0bab06ba36e84a7a26"
dependencies:
path-parse "^1.0.5"
responselike@1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/responselike/-/responselike-1.0.2.tgz#918720ef3b631c5642be068f15ade5a46f4ba1e7"
...@@ -4599,7 +4559,7 @@ ts-node@^7.0.0:
source-map-support "^0.5.6"
yn "^2.0.0"
tslib@^1.8.1:
version "1.9.3"
resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.9.3.tgz#d7e4dd79245d85428c4d7e4822a79917954ca286"
...@@ -4607,42 +4567,6 @@ tslib@^1.9.0:
version "1.10.0"
resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.10.0.tgz#c3c19f95973fb0a62973fb09d90d961ee43e5c8a"
tslint-microsoft-contrib@^6.0.0:
version "6.2.0"
resolved "https://registry.yarnpkg.com/tslint-microsoft-contrib/-/tslint-microsoft-contrib-6.2.0.tgz#8aa0f40584d066d05e6a5e7988da5163b85f2ad4"
dependencies:
tsutils "^2.27.2 <2.29.0"
tslint@^5.12.0:
version "5.18.0"
resolved "https://registry.yarnpkg.com/tslint/-/tslint-5.18.0.tgz#f61a6ddcf372344ac5e41708095bbf043a147ac6"
dependencies:
"@babel/code-frame" "^7.0.0"
builtin-modules "^1.1.1"
chalk "^2.3.0"
commander "^2.12.1"
diff "^3.2.0"
glob "^7.1.1"
js-yaml "^3.13.1"
minimatch "^3.0.4"
mkdirp "^0.5.1"
resolve "^1.3.2"
semver "^5.3.0"
tslib "^1.8.0"
tsutils "^2.29.0"
"tsutils@^2.27.2 <2.29.0":
version "2.28.0"
resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-2.28.0.tgz#6bd71e160828f9d019b6f4e844742228f85169a1"
dependencies:
tslib "^1.8.1"
tsutils@^2.29.0:
version "2.29.0"
resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-2.29.0.tgz#32b488501467acbedd4b85498673a0812aca0b99"
dependencies:
tslib "^1.8.1"
tsutils@^3.17.1:
version "3.17.1"
resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-3.17.1.tgz#ed719917f11ca0dee586272b2ac49e015a2dd759"
...@@ -4818,7 +4742,7 @@ v8-compile-cache@^2.0.3:
version "2.1.0"
resolved "https://registry.yarnpkg.com/v8-compile-cache/-/v8-compile-cache-2.1.0.tgz#e14de37b31a6d194f5690d67efc4e7f6fc6ab30e"
validate-npm-package-license@^3.0.1:
version "3.0.4"
resolved "https://registry.yarnpkg.com/validate-npm-package-license/-/validate-npm-package-license-3.0.4.tgz#fc91f6b9c7ba15c857f4cb2c5defeec39d4f410a"
dependencies:
...
...@@ -24,13 +24,15 @@ class BatchTuner(Tuner):
    Examples
    --------
    Only search spaces of the following form are accepted:

    ::

        {'combine_params':
            { '_type': 'choice',
              '_value': '[{...}, {...}, {...}]',
            }
        }
    """

    def __init__(self):
...
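Concretely, each element of ``_value`` is one complete parameter set that BatchTuner hands to a trial verbatim; a hypothetical three-trial search space could look like::

    {
        'combine_params': {
            '_type': 'choice',
            '_value': [{'lr': 0.1, 'batch_size': 32},
                       {'lr': 0.01, 'batch_size': 64},
                       {'lr': 0.001, 'batch_size': 128}]
        }
    }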
...@@ -5,7 +5,7 @@ import logging

import torch
from .compressor import Pruner

__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']

logger = logging.getLogger('torch pruner')

...@@ -166,119 +166,132 @@ class AGP_Pruner(Pruner):
        self.if_init_list[k] = True
class SlimPruner(Pruner):
    """
    A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
    Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
    "Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
    https://arxiv.org/pdf/1708.06519.pdf
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        config_list : list
            support key for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        """
        super().__init__(model, config_list)
        self.mask_calculated_ops = set()
        weight_list = []
        if len(config_list) > 1:
            logger.warning('Slim pruner only supports 1 configuration')
        config = config_list[0]
        for (layer, config) in self.detect_modules_to_compress():
            assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
            weight_list.append(layer.module.weight.data.abs().clone())
        all_bn_weights = torch.cat(weight_list)
        k = int(all_bn_weights.shape[0] * config['sparsity'])
        self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()

    def calc_mask(self, layer, config):
        """
        Calculate the mask of given layer.
        Scale factors with the smallest absolute value in the BN layer are masked.

        Parameters
        ----------
        layer : LayerInfo
            the layer to instrument the compression operation
        config : dict
            layer's pruning config
        """
        weight = layer.module.weight.data
        op_name = layer.name
        op_type = layer.type
        assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
        if op_name in self.mask_calculated_ops:
            assert op_name in self.mask_dict
            return self.mask_dict.get(op_name)
        mask = torch.ones(weight.size()).type_as(weight)
        try:
            w_abs = weight.abs()
            mask = torch.gt(w_abs, self.global_threshold).type_as(weight)
        finally:
            self.mask_dict.update({layer.name: mask})
            self.mask_calculated_ops.add(layer.name)

        return mask
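The global threshold computed in ``__init__`` is simply the k-th smallest BN scale factor pooled across every pruned layer; a standalone sketch of that step::

    import torch

    bn_weights = torch.tensor([0.5, 0.1, 0.9, 0.05, 0.3, 0.7])  # |gamma| pooled from all BN layers
    sparsity = 0.5
    k = int(bn_weights.shape[0] * sparsity)                      # 3 channels to prune
    threshold = torch.topk(bn_weights, k, largest=False)[0].max()
    mask = torch.gt(bn_weights, threshold).float()               # -> tensor([1., 0., 1., 0., 0., 1.])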
class RankFilterPruner(Pruner):
    """
    A structured pruning base class that prunes the filters with the smallest
    importance criterion in convolution layers to achieve a preset level of network sparsity.
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be pruned
        config_list : list
            support key for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        """
        super().__init__(model, config_list)
        self.mask_calculated_ops = set()

    def _get_mask(self, base_mask, weight, num_prune):
        return torch.ones(weight.size()).type_as(weight)

    def calc_mask(self, layer, config):
        """
        Calculate the mask of given layer.
        Filters with the smallest importance criterion of the kernel weights are masked.

        Parameters
        ----------
        layer : LayerInfo
            the layer to instrument the compression operation
        config : dict
            layer's pruning config

        Returns
        -------
        torch.Tensor
            mask of the layer's weight
        """
        weight = layer.module.weight.data
        op_name = layer.name
        op_type = layer.type
        assert 0 <= config.get('sparsity') < 1
        assert op_type in ['Conv1d', 'Conv2d']
        assert op_type in config.get('op_types')
        if op_name in self.mask_calculated_ops:
            assert op_name in self.mask_dict
            return self.mask_dict.get(op_name)
        mask = torch.ones(weight.size()).type_as(weight)
        try:
            filters = weight.size(0)
            num_prune = int(filters * config.get('sparsity'))
            if filters < 2 or num_prune < 1:
                return mask
            mask = self._get_mask(mask, weight, num_prune)
        finally:
            self.mask_dict.update({op_name: mask})
            self.mask_calculated_ops.add(op_name)

        return mask.detach()
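``calc_mask`` handles caching and sparsity bookkeeping once, so a subclass only has to supply a ranking rule via ``_get_mask``. A toy, purely hypothetical subclass shows how small that extension point is::

    import torch

    class RandomFilterPruner(RankFilterPruner):
        """Illustrative only: masks num_prune randomly chosen filters."""

        def _get_mask(self, base_mask, weight, num_prune):
            for idx in torch.randperm(weight.size(0))[:num_prune]:
                base_mask[idx] = 0.
            return base_mask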
class L1FilterPruner(RankFilterPruner):
    """
    A structured pruning algorithm that prunes the filters of smallest magnitude
    weights sum in the convolution layers to achieve a preset level of network sparsity.
...@@ -299,107 +312,162 @@ class L1FilterPruner(Pruner):
        """
        super().__init__(model, config_list)

    def _get_mask(self, base_mask, weight, num_prune):
        """
        Calculate the mask of given layer.
        Filters with the smallest sum of absolute kernel weights are masked.

        Parameters
        ----------
        base_mask : torch.Tensor
            The basic mask with the same shape of weight, all item in the basic mask is 1.
        weight : torch.Tensor
            Layer's weight
        num_prune : int
            Num of filters to prune

        Returns
        -------
        torch.Tensor
            Mask of the layer's weight
        """
        filters = weight.shape[0]
        w_abs = weight.abs()
        w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
        threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
        mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)

        return mask
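A minimal usage sketch (the model and sparsity value are placeholders; ``compress()`` is the standard compressor entry point also exercised by the tests further down)::

    import torch

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3))   # placeholder model
    config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
    pruner = L1FilterPruner(model, config_list)
    model = pruner.compress()   # wraps matched Conv2d layers so masks apply on forward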
class L2FilterPruner(RankFilterPruner):
    """
    A structured pruning algorithm that prunes the filters with the
    smallest L2 norm of the kernel weights.
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be pruned
        config_list : list
            support key for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        """
        super().__init__(model, config_list)

    def _get_mask(self, base_mask, weight, num_prune):
        """
        Calculate the mask of given layer.
        Filters with the smallest L2 norm of the kernel weights are masked.

        Parameters
        ----------
        base_mask : torch.Tensor
            The basic mask with the same shape of weight, all item in the basic mask is 1.
        weight : torch.Tensor
            Layer's weight
        num_prune : int
            Num of filters to prune

        Returns
        -------
        torch.Tensor
            Mask of the layer's weight
        """
        filters = weight.shape[0]
        w = weight.view(filters, -1)
        w_l2_norm = torch.sqrt((w ** 2).sum(dim=1))
        threshold = torch.topk(w_l2_norm.view(-1), num_prune, largest=False)[0].max()
        mask = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)

        return mask
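L1 and L2 rankings can disagree, which is the point of offering both criteria; a quick check on a toy weight tensor::

    import torch

    w = torch.tensor([[3.0, 0.0, 0.0],    # L1 = 3.0, L2 = 3.0
                      [1.2, 1.2, 1.2]])   # L1 = 3.6, L2 ~ 2.08
    l1 = w.abs().view(w.size(0), -1).sum(dim=1)
    l2 = torch.sqrt((w.view(w.size(0), -1) ** 2).sum(dim=1))
    print(l1.argmin().item(), l2.argmin().item())  # filter 0 by L1, filter 1 by L2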
class FPGMPruner(RankFilterPruner):
    """
    A filter pruner via geometric median.
    "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
    https://arxiv.org/pdf/1811.00250.pdf
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        model : pytorch model
            the model user wants to compress
        config_list: list
            support key for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        """
        super().__init__(model, config_list)

    def _get_mask(self, base_mask, weight, num_prune):
        """
        Calculate the mask of given layer.
        Filters closest to the geometric median of all filters (smallest total distance to the others) are masked.

        Parameters
        ----------
        base_mask : torch.Tensor
            The basic mask with the same shape of weight, all item in the basic mask is 1.
        weight : torch.Tensor
            Layer's weight
        num_prune : int
            Num of filters to prune

        Returns
        -------
        torch.Tensor
            Mask of the layer's weight
        """
        min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
        for idx in min_gm_idx:
            base_mask[idx] = 0.
        return base_mask

    def _get_min_gm_kernel_idx(self, weight, n):
        assert len(weight.size()) in [3, 4]
        dist_list = []
        for out_i in range(weight.size(0)):
            dist_sum = self._get_distance_sum(weight, out_i)
            dist_list.append((dist_sum, out_i))
        min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
        return [x[1] for x in min_gm_kernels]

    def _get_distance_sum(self, weight, out_idx):
        """
        Calculate the total distance between a specified filter (by out_idx) and
        all other filters.
        Optimized version of the following naive implementation:

        def _get_distance_sum(self, weight, in_idx, out_idx):
            w = weight.view(-1, weight.size(-2), weight.size(-1))
            dist_sum = 0.
            for k in w:
                dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2)
            return dist_sum

        Parameters
        ----------
        weight: Tensor
            convolutional filter weight
        out_idx: int
            output channel index of specified filter, this method calculates the total distance
            between this specified filter and all other filters.

        Returns
        -------
        float32
            The total distance
        """
        logger.debug('weight size: %s', weight.size())
        assert len(weight.size()) in [3, 4], 'unsupported weight shape'

        w = weight.view(weight.size(0), -1)
        anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
        x = w - anchor_w
        x = (x * x).sum(-1)
        x = torch.sqrt(x)

        return x.sum()

    def update_epoch(self, epoch):
        self.mask_calculated_ops = set()
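The vectorized distance sum can be sanity-checked against the naive pairwise loop it replaces::

    import torch

    weight = torch.randn(6, 4, 3, 3)  # 6 output filters
    out_idx = 2
    w = weight.view(weight.size(0), -1)
    vectorized = torch.sqrt(((w - w[out_idx]) ** 2).sum(-1)).sum()
    naive = sum(torch.dist(w[i], w[out_idx], p=2) for i in range(w.size(0)))
    assert torch.isclose(vectorized, naive)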
...@@ -163,19 +163,23 @@ class MsgDispatcherBase(Recoverable):
        raise NotImplementedError('handle_initialize not implemented')

    def handle_request_trial_jobs(self, data):
        """The message dispatcher is asked to generate ``data`` trial jobs.
        These trial jobs should be sent via ``send(CommandType.NewTrialJob, json_tricks.dumps(parameter))``,
        where ``parameter`` will be received by NNI Manager and eventually accessible to trial jobs as "next parameter".
        Semantically, the message dispatcher should do this ``send`` exactly ``data`` times.

        The JSON sent by this method should follow the format of

        ::

            {
                "parameter_id": 42
                "parameters": {
                    // this will be received by trial
                },
                "parameter_source": "algorithm" // optional
            }

        Parameters
        ----------
        data: int
...@@ -211,6 +215,7 @@ class MsgDispatcherBase(Recoverable):
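A hedged sketch of what a concrete implementation of ``handle_request_trial_jobs`` might do (the import path of the IPC helpers and the parameter values are assumptions, based on the v1.x layout)::

    import json_tricks
    from nni.protocol import CommandType, send   # assumed location of the IPC helpers

    def handle_request_trial_jobs(self, data):
        for i in range(data):                    # do the send exactly `data` times
            parameter = {
                'parameter_id': i,
                'parameters': {'lr': 0.01},      # this dict reaches the trial
                'parameter_source': 'algorithm'
            }
            send(CommandType.NewTrialJob, json_tricks.dumps(parameter))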
    def handle_report_metric_data(self, data):
        """Called when metric data is reported or new parameters are requested (for multiphase).
        When new parameters are requested, this method should send a new parameter.

        Parameters
        ----------
        data: dict
...@@ -219,6 +224,7 @@ class MsgDispatcherBase(Recoverable):
            `REQUEST_PARAMETER` is used to request new parameters for multiphase trial job. In this case,
            the dict will contain additional keys: `trial_job_id`, `parameter_index`. Refer to `msg_dispatcher.py`
            as an example.

        Raises
        ------
        ValueError
...@@ -228,6 +234,7 @@ class MsgDispatcherBase(Recoverable):
    def handle_trial_end(self, data):
        """Called when the state of one of the trials is changed

        Parameters
        ----------
        data: dict
...@@ -235,5 +242,6 @@ class MsgDispatcherBase(Recoverable):
            trial_job_id: the id generated by training service.
            event: the job's state.
            hyper_params: the string that is sent by message dispatcher during the creation of trials.
        """
        raise NotImplementedError('handle_trial_end not implemented')
...@@ -58,8 +58,9 @@ def tf2(func):
    return test_tf2_func


# for fpgm filter pruner test
w = np.array([[[[i + 1] * 3] * 3] * 5 for i in range(10)])


class CompressorTestCase(TestCase):
...@@ -69,19 +70,19 @@ class CompressorTestCase(TestCase):
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]
        model.relu = torch.nn.ReLU()
        quantizer = torch_compressor.QAT_Quantizer(model, config_list)
        quantizer.compress()
        modules_to_compress = quantizer.get_modules_to_compress()
        modules_to_compress_name = [t[0].name for t in modules_to_compress]
        assert "conv1" in modules_to_compress_name
        assert "conv2" in modules_to_compress_name
        assert "fc1" in modules_to_compress_name
...@@ -179,7 +180,8 @@ class CompressorTestCase(TestCase):
        w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
                      np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
        model = TorchModel()
        config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
                       {'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]
        pruner = torch_compressor.L1FilterPruner(model, config_list)

        model.conv1.weight.data = torch.tensor(w).float()
...@@ -236,12 +238,12 @@ class CompressorTestCase(TestCase):
        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear']
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 0,
            'op_types': ['ReLU']
        }]
        model.relu = torch.nn.ReLU()
        quantizer = torch_compressor.QAT_Quantizer(model, config_list)
...@@ -253,7 +255,7 @@ class CompressorTestCase(TestCase):
        quantize_weight = quantizer.quantize_weight(weight, config_list[0], model.conv2)
        assert math.isclose(model.conv2.scale, 5 / 255, abs_tol=eps)
        assert model.conv2.zero_point == 0

        # range including 0
        weight = torch.tensor([[-1, 2], [3, 5]]).float()
        quantize_weight = quantizer.quantize_weight(weight, config_list[0], model.conv2)
        assert math.isclose(model.conv2.scale, 6 / 255, abs_tol=eps)
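Both scale assertions are consistent with plain min-max 8-bit quantization where the observed range is widened to include zero, i.e. scale = (max(w, 0) - min(w, 0)) / 255; for the second weight tensor::

    w = [-1.0, 2.0, 3.0, 5.0]
    scale = (max(max(w), 0.0) - min(min(w), 0.0)) / 255
    assert abs(scale - 6 / 255) < 1e-12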
...@@ -271,5 +273,6 @@ class CompressorTestCase(TestCase):
        assert math.isclose(model.relu.tracked_min_biased, 0.002, abs_tol=eps)
        assert math.isclose(model.relu.tracked_max_biased, 0.00998, abs_tol=eps)


if __name__ == '__main__':
    main()