Unverified Commit 543239c6 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #220 from microsoft/master

merge master
parents 32efaa36 659480f2
......@@ -3,7 +3,6 @@
'use strict';
// tslint:disable-next-line:no-implicit-dependencies
import * as request from 'request';
import { Deferred } from 'ts-deferred';
import { NNIError, NNIErrorNames } from '../../common/errors';
......@@ -16,10 +15,10 @@ import { PAITrialJobDetail } from './paiData';
 * Collect PAI job info from the PAI cluster, and update PAI job status locally
*/
export class PAIJobInfoCollector {
private readonly trialJobsMap : Map<string, PAITrialJobDetail>;
private readonly trialJobsMap: Map<string, PAITrialJobDetail>;
private readonly log: Logger = getLogger();
private readonly statusesNeedToCheck : TrialJobStatus[];
private readonly finalStatuses : TrialJobStatus[];
private readonly statusesNeedToCheck: TrialJobStatus[];
private readonly finalStatuses: TrialJobStatus[];
constructor(jobMap: Map<string, PAITrialJobDetail>) {
this.trialJobsMap = jobMap;
......@@ -27,12 +26,12 @@ export class PAIJobInfoCollector {
this.finalStatuses = ['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED'];
}
public async retrieveTrialStatus(paiToken? : string, paiClusterConfig?: PAIClusterConfig) : Promise<void> {
public async retrieveTrialStatus(paiToken? : string, paiClusterConfig?: PAIClusterConfig): Promise<void> {
if (paiClusterConfig === undefined || paiToken === undefined) {
return Promise.resolve();
}
const updatePaiTrialJobs : Promise<void>[] = [];
const updatePaiTrialJobs: Promise<void>[] = [];
for (const [trialJobId, paiTrialJob] of this.trialJobsMap) {
if (paiTrialJob === undefined) {
throw new NNIError(NNIErrorNames.NOT_FOUND, `trial job id ${trialJobId} not found`);
......@@ -43,9 +42,8 @@ export class PAIJobInfoCollector {
await Promise.all(updatePaiTrialJobs);
}
private getSinglePAITrialJobInfo(paiTrialJob : PAITrialJobDetail, paiToken : string, paiClusterConfig: PAIClusterConfig)
: Promise<void> {
const deferred : Deferred<void> = new Deferred<void>();
private getSinglePAITrialJobInfo(paiTrialJob: PAITrialJobDetail, paiToken: string, paiClusterConfig: PAIClusterConfig): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
if (!this.statusesNeedToCheck.includes(paiTrialJob.status)) {
deferred.resolve();
......@@ -55,7 +53,6 @@ export class PAIJobInfoCollector {
// Rest call to get PAI job info and update status
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const getJobInfoRequest: request.Options = {
// tslint:disable-next-line:no-http-string
uri: `http://${paiClusterConfig.host}/rest-server/api/v1/user/${paiClusterConfig.userName}/jobs/${paiTrialJob.paiJobName}`,
method: 'GET',
json: true,
......@@ -65,7 +62,6 @@ export class PAIJobInfoCollector {
}
};
// tslint:disable: no-unsafe-any no-any cyclomatic-complexity
//TODO : pass in request timeout param?
request(getJobInfoRequest, (error: Error, response: request.Response, body: any) => {
if ((error !== undefined && error !== null) || response.statusCode >= 500) {
......@@ -129,5 +125,4 @@ export class PAIJobInfoCollector {
return deferred.promise;
}
// tslint:enable: no-unsafe-any no-any
}
......@@ -24,7 +24,7 @@ export class PAIJobRestServer extends ClusterJobRestServer {
private parameterFileMetaList: ParameterFileMeta[] = [];
@Inject
private readonly paiTrainingService : PAITrainingService;
private readonly paiTrainingService: PAITrainingService;
/**
* constructor to provide NNIRestServer's own rest property, e.g. port
......@@ -34,8 +34,7 @@ export class PAIJobRestServer extends ClusterJobRestServer {
this.paiTrainingService = component.get(PAITrainingService);
}
// tslint:disable-next-line:no-any
protected handleTrialMetrics(jobId : string, metrics : any[]) : void {
protected handleTrialMetrics(jobId: string, metrics: any[]): void {
// Split metrics array into single metric, then emit
// Warning: If not split metrics into single ones, the behavior will be UNKNOWN
for (const singleMetric of metrics) {
......
......@@ -3,17 +3,14 @@
'use strict';
import * as cpp from 'child-process-promise';
import * as fs from 'fs';
import * as path from 'path';
// tslint:disable-next-line:no-implicit-dependencies
import * as request from 'request';
import * as component from '../../common/component';
import { EventEmitter } from 'events';
import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
import { MethodNotImplementedError } from '../../common/errors';
import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import {
......@@ -47,13 +44,12 @@ class PAITrainingService implements TrainingService {
private paiClusterConfig?: PAIClusterConfig;
private readonly jobQueue: string[];
private stopping: boolean = false;
// tslint:disable-next-line:no-any
private hdfsClient: any;
private paiToken? : string;
private paiTokenUpdateTime?: number;
private readonly paiTokenUpdateInterval: number;
private readonly experimentId! : string;
private readonly paiJobCollector : PAIJobInfoCollector;
private readonly experimentId!: string;
private readonly paiJobCollector: PAIJobInfoCollector;
private paiRestServerPort?: number;
private nniManagerIpConfig?: NNIManagerIpConfig;
private copyExpCodeDirPromise?: Promise<void>;
......@@ -126,7 +122,7 @@ class PAITrainingService implements TrainingService {
if (this.paiClusterConfig === undefined) {
throw new Error(`paiClusterConfig not initialized!`);
}
const deferred : Deferred<PAITrialJobDetail> = new Deferred<PAITrialJobDetail>();
const deferred: Deferred<PAITrialJobDetail> = new Deferred<PAITrialJobDetail>();
this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);
......@@ -137,7 +133,7 @@ class PAITrainingService implements TrainingService {
const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
const hdfsOutputDir: string = unixPathJoin(hdfsCodeDir, 'nnioutput');
const hdfsLogPath : string = String.Format(
const hdfsLogPath: string = String.Format(
PAI_LOG_PATH_FORMAT,
this.paiClusterConfig.host,
hdfsOutputDir
......@@ -173,10 +169,9 @@ class PAITrainingService implements TrainingService {
return true;
}
// tslint:disable:no-http-string
public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
const trialJobDetail : PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
const deferred : Deferred<void> = new Deferred<void>();
const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
const deferred: Deferred<void> = new Deferred<void>();
if (trialJobDetail === undefined) {
this.log.error(`cancelTrialJob: trial job id ${trialJobId} not found`);
......@@ -205,7 +200,6 @@ class PAITrainingService implements TrainingService {
// Set trialjobDetail's early stopped field, to mark the job's cancellation source
trialJobDetail.isEarlyStopped = isEarlyStopped;
// tslint:disable-next-line:no-any
request(stopJobRequest, (error: Error, response: request.Response, body: any) => {
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
......@@ -219,10 +213,8 @@ class PAITrainingService implements TrainingService {
return deferred.promise;
}
// tslint:disable: no-unsafe-any no-any
// tslint:disable-next-line:max-func-body-length
public async setClusterMetadata(key: string, value: string): Promise<void> {
const deferred : Deferred<void> = new Deferred<void>();
const deferred: Deferred<void> = new Deferred<void>();
switch (key) {
case TrialConfigMetadataKey.NNI_MANAGER_IP:
......@@ -300,10 +292,9 @@ class PAITrainingService implements TrainingService {
return deferred.promise;
}
// tslint:enable: no-unsafe-any
public getClusterMetadata(key: string): Promise<string> {
const deferred : Deferred<string> = new Deferred<string>();
const deferred: Deferred<string> = new Deferred<string>();
deferred.resolve();
......@@ -314,14 +305,13 @@ class PAITrainingService implements TrainingService {
this.log.info('Stopping PAI training service...');
this.stopping = true;
const deferred : Deferred<void> = new Deferred<void>();
const deferred: Deferred<void> = new Deferred<void>();
const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
try {
await restServer.stop();
deferred.resolve();
this.log.info('PAI Training service rest server stopped successfully.');
} catch (error) {
// tslint:disable-next-line: no-unsafe-any
this.log.error(`PAI Training service rest server stopped failed, error: ${error.message}`);
deferred.reject(error);
}
......@@ -329,13 +319,12 @@ class PAITrainingService implements TrainingService {
return deferred.promise;
}
public get MetricsEmitter() : EventEmitter {
public get MetricsEmitter(): EventEmitter {
return this.metricsEmitter;
}
// tslint:disable-next-line:max-func-body-length
private async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
const deferred : Deferred<boolean> = new Deferred<boolean>();
const deferred: Deferred<boolean> = new Deferred<boolean>();
const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
......@@ -372,7 +361,7 @@ class PAITrainingService implements TrainingService {
//create tmp trial working folder locally.
await execMkdir(trialLocalTempFolder);
const runScriptContent : string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
const runScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
// Write NNI installation file to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
......@@ -385,10 +374,9 @@ class PAITrainingService implements TrainingService {
}
const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
const hdfsOutputDir: string = unixPathJoin(hdfsCodeDir, 'nnioutput');
// tslint:disable-next-line: strict-boolean-expressions
const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
const version: string = this.versionCheck ? await getVersion() : '';
const nniPaiTrialCommand : string = String.Format(
const nniPaiTrialCommand: string = String.Format(
PAI_TRIAL_COMMAND_FORMAT,
// PAI will copy job's codeDir into /root directory
`$PWD/${trialJobId}`,
......@@ -409,9 +397,8 @@ class PAITrainingService implements TrainingService {
)
.replace(/\r\n|\n|\r/gm, '');
// tslint:disable-next-line:no-console
this.log.info(`nniPAItrial command is ${nniPaiTrialCommand.trim()}`);
const paiTaskRoles : PAITaskRole[] = [
const paiTaskRoles: PAITaskRole[] = [
new PAITaskRole(
`nni_trail_${trialJobId}`,
// Task role number
......@@ -431,7 +418,7 @@ class PAITrainingService implements TrainingService {
)
];
const paiJobConfig : PAIJobConfig = new PAIJobConfig(
const paiJobConfig: PAIJobConfig = new PAIJobConfig(
// Job name
trialJobDetail.paiJobName,
// Docker image
......@@ -451,7 +438,7 @@ class PAITrainingService implements TrainingService {
await HDFSClientUtility.copyDirectoryToHdfs(trialLocalTempFolder, hdfsCodeDir, this.hdfsClient);
} catch (error) {
this.log.error(`PAI Training service: copy ${this.paiTrialConfig.codeDir} to HDFS ${hdfsCodeDir} failed, error is ${error}`);
trialJobDetail.status = 'FAILED';
trialJobDetail.status = 'FAILED'; // eslint-disable-line require-atomic-updates
deferred.resolve(true);
return deferred.promise;
......@@ -469,10 +456,9 @@ class PAITrainingService implements TrainingService {
Authorization: `Bearer ${this.paiToken}`
}
};
// tslint:disable:no-any no-unsafe-any
request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
const errorMessage : string = (error !== undefined && error !== null) ? error.message :
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${response.body.message}`;
trialJobDetail.status = 'FAILED';
deferred.resolve(true);
......@@ -527,7 +513,7 @@ class PAITrainingService implements TrainingService {
* Update pai token by the interval time or initialize the pai token
*/
private async updatePaiToken(): Promise<void> {
const deferred : Deferred<void> = new Deferred<void>();
const deferred: Deferred<void> = new Deferred<void>();
const currentTime: number = new Date().getTime();
//If pai token initialized and not reach the interval time, do not update
......@@ -603,7 +589,7 @@ class PAITrainingService implements TrainingService {
}
private postParameterFileMeta(parameterFileMeta: ParameterFileMeta): Promise<void> {
const deferred : Deferred<void> = new Deferred<void>();
const deferred: Deferred<void> = new Deferred<void>();
const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
const req: request.Options = {
uri: `${restServer.endPoint}${restServer.apiRootUrl}/parameter-file-meta`,
......
......@@ -15,7 +15,7 @@ export class PAITrialConfig extends TrialConfig {
public readonly dataDir: string;
public readonly outputDir: string;
constructor(command : string, codeDir : string, gpuNum : number, cpuNum: number, memoryMB: number,
constructor(command: string, codeDir: string, gpuNum: number, cpuNum: number, memoryMB: number,
image: string, dataDir: string, outputDir: string) {
super(command, codeDir, gpuNum);
this.cpuNum = cpuNum;
......
......@@ -5,7 +5,6 @@
import * as assert from 'assert';
import { getLogger, Logger } from '../../common/log';
import { TrialJobDetail } from '../../common/trainingService';
import { randomSelect } from '../../common/utils';
import { GPUInfo } from '../common/gpuData';
import {
......@@ -19,7 +18,7 @@ type SCHEDULE_POLICY_NAME = 'random' | 'round-robin';
*/
export class GPUScheduler {
private readonly machineSSHClientMap : Map<RemoteMachineMeta, SSHClientManager>;
private readonly machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>;
private readonly log: Logger = getLogger();
private readonly policyName: SCHEDULE_POLICY_NAME = 'round-robin';
private roundRobinIndex: number = 0;
......@@ -29,7 +28,7 @@ export class GPUScheduler {
* Constructor
* @param machineSSHClientMap map from remote machine to sshClient
*/
constructor(machineSSHClientMap : Map<RemoteMachineMeta, SSHClientManager>) {
constructor(machineSSHClientMap: Map<RemoteMachineMeta, SSHClientManager>) {
assert(machineSSHClientMap.size > 0);
this.machineSSHClientMap = machineSSHClientMap;
this.configuredRMs = Array.from(machineSSHClientMap.keys());
......@@ -39,7 +38,7 @@ export class GPUScheduler {
* Schedule a machine according to the constraints (requiredGPUNum)
* @param requiredGPUNum required GPU number
*/
public scheduleMachine(requiredGPUNum: number | undefined, trialJobDetail : RemoteMachineTrialJobDetail) : RemoteMachineScheduleResult {
public scheduleMachine(requiredGPUNum: number | undefined, trialJobDetail: RemoteMachineTrialJobDetail): RemoteMachineScheduleResult {
if(requiredGPUNum === undefined) {
requiredGPUNum = 0;
}
......@@ -48,7 +47,7 @@ export class GPUScheduler {
assert(allRMs.length > 0);
// Step 1: Check if required GPU number not exceeds the total GPU number in all machines
const eligibleRM: RemoteMachineMeta[] = allRMs.filter((rmMeta : RemoteMachineMeta) =>
const eligibleRM: RemoteMachineMeta[] = allRMs.filter((rmMeta: RemoteMachineMeta) =>
rmMeta.gpuSummary === undefined || requiredGPUNum === 0 || (requiredGPUNum !== undefined && rmMeta.gpuSummary.gpuCount >= requiredGPUNum));
if (eligibleRM.length === 0) {
// If the required gpu number exceeds the upper limit of all machine's GPU number
......@@ -134,8 +133,8 @@ export class GPUScheduler {
* @param availableGPUMap available GPU resource filled by this detection
* @returns Available GPU number on this remote machine
*/
private gpuResourceDetection() : Map<RemoteMachineMeta, GPUInfo[]> {
const totalResourceMap : Map<RemoteMachineMeta, GPUInfo[]> = new Map<RemoteMachineMeta, GPUInfo[]>();
private gpuResourceDetection(): Map<RemoteMachineMeta, GPUInfo[]> {
const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = new Map<RemoteMachineMeta, GPUInfo[]>();
this.machineSSHClientMap.forEach((sshClientManager: SSHClientManager, rmMeta: RemoteMachineMeta) => {
            // Assign total GPU count as the initial available GPU number
if (rmMeta.gpuSummary !== undefined) {
......@@ -149,7 +148,6 @@ export class GPUScheduler {
}
}
this.log.debug(`designated gpu indices: ${designatedGpuIndices}`);
// tslint:disable: strict-boolean-expressions
rmMeta.gpuSummary.gpuInfos.forEach((gpuInfo: GPUInfo) => {
// if the GPU has active process, OR be reserved by a job,
// or index not in gpuIndices configuration in machineList,
......@@ -175,7 +173,6 @@ export class GPUScheduler {
return totalResourceMap;
}
// tslint:enable: strict-boolean-expressions
private selectMachine(rmMetas: RemoteMachineMeta[]): RemoteMachineMeta {
assert(rmMetas !== undefined && rmMetas.length > 0);
......@@ -224,11 +221,11 @@ export class GPUScheduler {
resultType: ScheduleResultType.SUCCEED,
scheduleInfo: {
rmMeta: rmMeta,
cuda_visible_device: allocatedGPUs
.map((gpuInfo: GPUInfo) => {
return gpuInfo.index;
})
.join(',')
cudaVisibleDevice: allocatedGPUs
.map((gpuInfo: GPUInfo) => {
return gpuInfo.index;
})
.join(',')
}
};
}
......
......@@ -13,13 +13,13 @@ import { GPUInfo, GPUSummary } from '../common/gpuData';
 * Metadata of remote machine for configuration and status query
*/
export class RemoteMachineMeta {
public readonly ip : string = '';
public readonly port : number = 22;
public readonly username : string = '';
public readonly ip: string = '';
public readonly port: number = 22;
public readonly username: string = '';
public readonly passwd: string = '';
public readonly sshKeyPath?: string;
public readonly passphrase?: string;
public gpuSummary : GPUSummary | undefined;
public gpuSummary: GPUSummary | undefined;
public readonly gpuIndices?: string;
public readonly maxTrialNumPerGpu?: number;
    //TODO: initialize variable in constructor
......@@ -43,11 +43,11 @@ export function parseGpuIndices(gpuIndices?: string): Set<number> | undefined {
* The execution result for command executed on remote machine
*/
export class RemoteCommandResult {
public readonly stdout : string;
public readonly stderr : string;
public readonly exitCode : number;
public readonly stdout: string;
public readonly stderr: string;
public readonly exitCode: number;
constructor(stdout : string, stderr : string, exitCode : number) {
constructor(stdout: string, stderr: string, exitCode: number) {
this.stdout = stdout;
this.stderr = stderr;
this.exitCode = exitCode;
......@@ -186,7 +186,6 @@ export class SSHClientManager {
/**
* Create a new ssh connection client and initialize it
*/
// tslint:disable:non-literal-fs-path
private initNewSSHClient(): Promise<Client> {
const deferred: Deferred<Client> = new Deferred<Client>();
const conn: Client = new Client();
......@@ -225,9 +224,9 @@ export class SSHClientManager {
}
}
export type RemoteMachineScheduleResult = { scheduleInfo : RemoteMachineScheduleInfo | undefined; resultType : ScheduleResultType};
export type RemoteMachineScheduleResult = { scheduleInfo: RemoteMachineScheduleInfo | undefined; resultType: ScheduleResultType};
export type RemoteMachineScheduleInfo = { rmMeta : RemoteMachineMeta; cuda_visible_device : string};
export type RemoteMachineScheduleInfo = { rmMeta: RemoteMachineMeta; cudaVisibleDevice: string};
export enum ScheduleResultType {
// Schedule succeeded
......
......@@ -15,7 +15,7 @@ import { RemoteMachineTrainingService } from './remoteMachineTrainingService';
@component.Singleton
export class RemoteMachineJobRestServer extends ClusterJobRestServer {
@Inject
private readonly remoteMachineTrainingService : RemoteMachineTrainingService;
private readonly remoteMachineTrainingService: RemoteMachineTrainingService;
/**
* constructor to provide NNIRestServer's own rest property, e.g. port
......@@ -25,8 +25,7 @@ export class RemoteMachineJobRestServer extends ClusterJobRestServer {
this.remoteMachineTrainingService = component.get(RemoteMachineTrainingService);
}
// tslint:disable-next-line:no-any
protected handleTrialMetrics(jobId : string, metrics : any[]) : void {
protected handleTrialMetrics(jobId: string, metrics: any[]): void {
// Split metrics array into single metric, then emit
        // Warning: If not split metrics into single ones, the behavior will be UNKNOWN
for (const singleMetric of metrics) {
......
......@@ -4,12 +4,10 @@
'use strict';
import * as assert from 'assert';
import * as cpp from 'child-process-promise';
import { EventEmitter } from 'events';
import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
import { Client, ConnectConfig } from 'ssh2';
import { Client } from 'ssh2';
import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
import * as component from '../../common/component';
......@@ -29,12 +27,12 @@ import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { GPUSummary } from '../common/gpuData';
import { TrialConfig } from '../common/trialConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { execCopydir, execMkdir, execRemove, validateCodeDir, getGpuMetricsCollectorBashScriptContent } from '../common/util';
import { execCopydir, execMkdir, validateCodeDir, getGpuMetricsCollectorBashScriptContent } from '../common/util';
import { GPUScheduler } from './gpuScheduler';
import {
HOST_JOB_SHELL_FORMAT, RemoteCommandResult, REMOTEMACHINE_TRIAL_COMMAND_FORMAT, RemoteMachineMeta,
RemoteCommandResult, REMOTEMACHINE_TRIAL_COMMAND_FORMAT, RemoteMachineMeta,
RemoteMachineScheduleInfo, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail,
ScheduleResultType, SSHClient, SSHClientManager
ScheduleResultType, SSHClientManager
} from './remoteMachineData';
import { RemoteMachineJobRestServer } from './remoteMachineJobRestServer';
import { SSHClientUtility } from './sshClientUtility';
......@@ -93,7 +91,7 @@ class RemoteMachineTrainingService implements TrainingService {
while (this.jobQueue.length > 0) {
this.updateGpuReservation();
const trialJobId: string = this.jobQueue[0];
const prepareResult : boolean = await this.prepareTrialJob(trialJobId);
const prepareResult: boolean = await this.prepareTrialJob(trialJobId);
if (prepareResult) {
// Remove trial job with trialJobId from job queue
this.jobQueue.shift();
......@@ -208,7 +206,6 @@ class RemoteMachineTrainingService implements TrainingService {
* Submit trial job
* @param form trial job description form
*/
// tslint:disable-next-line:informative-docs
public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.trialConfig === undefined) {
throw new Error('trial config is not initialized');
......@@ -241,12 +238,7 @@ class RemoteMachineTrainingService implements TrainingService {
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
const rmMeta: RemoteMachineMeta | undefined = (<RemoteMachineTrialJobDetail>trialJobDetail).rmMeta;
if (rmMeta !== undefined) {
await this.writeParameterFile(trialJobId, form.hyperParameters, rmMeta);
} else {
throw new Error(`updateTrialJob failed: ${trialJobId} rmMeta not found`);
}
await this.writeParameterFile(trialJobId, form.hyperParameters);
return trialJobDetail;
}
......@@ -262,7 +254,6 @@ class RemoteMachineTrainingService implements TrainingService {
* Cancel trial job
* @param trialJobId ID of trial job
*/
// tslint:disable:informative-docs no-unsafe-any
public async cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
const trialJob: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
......@@ -272,7 +263,7 @@ class RemoteMachineTrainingService implements TrainingService {
}
// Remove the job with trialJobId from job queue
const index : number = this.jobQueue.indexOf(trialJobId);
const index: number = this.jobQueue.indexOf(trialJobId);
if (index >= 0) {
this.jobQueue.splice(index, 1);
}
......@@ -319,14 +310,13 @@ class RemoteMachineTrainingService implements TrainingService {
await this.setupConnections(value);
this.gpuScheduler = new GPUScheduler(this.machineSSHClientMap);
break;
case TrialConfigMetadataKey.TRIAL_CONFIG:
case TrialConfigMetadataKey.TRIAL_CONFIG: {
const remoteMachineTrailConfig: TrialConfig = <TrialConfig>JSON.parse(value);
// Parse trial config failed, throw Error
if (remoteMachineTrailConfig === undefined) {
throw new Error('trial config parsed failed');
}
// codeDir is not a valid directory, throw Error
// tslint:disable-next-line:non-literal-fs-path
if (!fs.lstatSync(remoteMachineTrailConfig.codeDir)
.isDirectory()) {
throw new Error(`codeDir ${remoteMachineTrailConfig.codeDir} is not a directory`);
......@@ -343,6 +333,7 @@ class RemoteMachineTrainingService implements TrainingService {
this.trialConfig = remoteMachineTrailConfig;
break;
}
case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True');
break;
......@@ -444,7 +435,6 @@ class RemoteMachineTrainingService implements TrainingService {
await SSHClientUtility.remoteExeCommand(`chmod 777 ${nniRootDir} ${nniRootDir}/* ${nniRootDir}/scripts/*`, conn);
//Begin to execute gpu_metrics_collection scripts
// tslint:disable-next-line: no-floating-promises
const script = getGpuMetricsCollectorBashScriptContent(remoteGpuScriptCollectorDir);
SSHClientUtility.remoteExeCommand(`bash -c '${script}'`, conn);
......@@ -464,7 +454,7 @@ class RemoteMachineTrainingService implements TrainingService {
}
private async prepareTrialJob(trialJobId: string): Promise<boolean> {
const deferred : Deferred<boolean> = new Deferred<boolean>();
const deferred: Deferred<boolean> = new Deferred<boolean>();
if (this.trialConfig === undefined) {
throw new Error('trial config is not initialized');
......@@ -485,13 +475,13 @@ class RemoteMachineTrainingService implements TrainingService {
// get an ssh client from scheduler
const rmScheduleResult: RemoteMachineScheduleResult = this.gpuScheduler.scheduleMachine(this.trialConfig.gpuNum, trialJobDetail);
if (rmScheduleResult.resultType === ScheduleResultType.REQUIRE_EXCEED_TOTAL) {
const errorMessage : string = `Required GPU number ${this.trialConfig.gpuNum} is too large, no machine can meet`;
const errorMessage: string = `Required GPU number ${this.trialConfig.gpuNum} is too large, no machine can meet`;
this.log.error(errorMessage);
deferred.reject();
throw new NNIError(NNIErrorNames.RESOURCE_NOT_AVAILABLE, errorMessage);
} else if (rmScheduleResult.resultType === ScheduleResultType.SUCCEED
&& rmScheduleResult.scheduleInfo !== undefined) {
const rmScheduleInfo : RemoteMachineScheduleInfo = rmScheduleResult.scheduleInfo;
const rmScheduleInfo: RemoteMachineScheduleInfo = rmScheduleResult.scheduleInfo;
const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId);
trialJobDetail.rmMeta = rmScheduleInfo.rmMeta;
......@@ -521,7 +511,7 @@ class RemoteMachineTrainingService implements TrainingService {
if (this.trialConfig === undefined) {
throw new Error('trial config is not initialized');
}
const cuda_visible_device: string = rmScheduleInfo.cuda_visible_device;
const cudaVisibleDevice: string = rmScheduleInfo.cudaVisibleDevice;
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJobId);
if (sshClient === undefined) {
assert(false, 'sshClient is undefined.');
......@@ -543,19 +533,18 @@ class RemoteMachineTrainingService implements TrainingService {
// See definition in remoteMachineData.ts
let command: string;
// Set CUDA_VISIBLE_DEVICES environment variable based on cuda_visible_device
// If no valid cuda_visible_device is defined, set CUDA_VISIBLE_DEVICES to empty string to hide GPU device
// Set CUDA_VISIBLE_DEVICES environment variable based on cudaVisibleDevice
// If no valid cudaVisibleDevice is defined, set CUDA_VISIBLE_DEVICES to empty string to hide GPU device
// If gpuNum is undefined, will not set CUDA_VISIBLE_DEVICES in script
if (this.trialConfig.gpuNum === undefined) {
command = this.trialConfig.command;
} else {
if (typeof cuda_visible_device === 'string' && cuda_visible_device.length > 0) {
command = `CUDA_VISIBLE_DEVICES=${cuda_visible_device} ${this.trialConfig.command}`;
if (typeof cudaVisibleDevice === 'string' && cudaVisibleDevice.length > 0) {
command = `CUDA_VISIBLE_DEVICES=${cudaVisibleDevice} ${this.trialConfig.command}`;
} else {
command = `CUDA_VISIBLE_DEVICES=" " ${this.trialConfig.command}`;
}
}
// tslint:disable-next-line: strict-boolean-expressions
const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
if (this.remoteRestServerPort === undefined) {
const restServer: RemoteMachineJobRestServer = component.get(RemoteMachineJobRestServer);
......@@ -584,16 +573,15 @@ class RemoteMachineTrainingService implements TrainingService {
//create tmp trial working folder locally.
await execCopydir(this.trialConfig.codeDir, trialLocalTempFolder);
const installScriptContent : string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
const installScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
// Write NNI installation file to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), installScriptContent, { encoding: 'utf8' });
// Write file content ( run.sh and parameter.cfg ) to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run.sh'), runScriptTrialContent, { encoding: 'utf8' });
await this.writeParameterFile(trialJobId, form.hyperParameters, rmScheduleInfo.rmMeta);
await this.writeParameterFile(trialJobId, form.hyperParameters);
// Copy files in codeDir to remote working directory
await SSHClientUtility.copyDirectoryToRemote(trialLocalTempFolder, trialWorkingFolder, sshClient, this.remoteOS);
// Execute command in remote machine
// tslint:disable-next-line: no-floating-promises
SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(trialWorkingFolder, 'run.sh')}`, sshClient);
}
......@@ -610,6 +598,7 @@ class RemoteMachineTrainingService implements TrainingService {
const deferred: Deferred<TrialJobDetail> = new Deferred<TrialJobDetail>();
const jobpidPath: string = this.getJobPidPath(trialJob.id);
const trialReturnCodeFilePath: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJob.id, '.nni', 'code');
/* eslint-disable require-atomic-updates */
try {
const killResult: number = (await SSHClientUtility.remoteExeCommand(`kill -0 \`cat ${jobpidPath}\``, sshClient)).exitCode;
// if the process of jobpid is not alive any more
......@@ -646,7 +635,7 @@ class RemoteMachineTrainingService implements TrainingService {
deferred.resolve(trialJob);
}
}
/* eslint-enable require-atomic-updates */
return deferred.promise;
}
......@@ -662,7 +651,7 @@ class RemoteMachineTrainingService implements TrainingService {
return unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni', 'experiments', getExperimentId());
}
public get MetricsEmitter() : EventEmitter {
public get MetricsEmitter(): EventEmitter {
return this.metricsEmitter;
}
......@@ -672,13 +661,10 @@ class RemoteMachineTrainingService implements TrainingService {
throw new NNIError(NNIErrorNames.INVALID_JOB_DETAIL, `Invalid job detail information for trial job ${jobId}`);
}
let jobpidPath: string;
jobpidPath = unixPathJoin(trialJobDetail.workingDirectory, '.nni', 'jobpid');
return jobpidPath;
return unixPathJoin(trialJobDetail.workingDirectory, '.nni', 'jobpid');
}
private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters, rmMeta: RemoteMachineMeta): Promise<void> {
private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> {
const sshClient: Client | undefined = this.trialSSHClientMap.get(trialJobId);
if (sshClient === undefined) {
throw new Error('sshClient is undefined.');
......
......@@ -4,7 +4,6 @@
'use strict';
import * as assert from 'assert';
import * as cpp from 'child-process-promise';
import * as os from 'os';
import * as path from 'path';
import { Client, ClientChannel, SFTPWrapper } from 'ssh2';
......@@ -22,44 +21,18 @@ import { RemoteCommandResult } from './remoteMachineData';
*
*/
export namespace SSHClientUtility {
/**
 * Copy files and directories in a local directory recursively to a remote directory.
 * The directory is packed into a single tarball locally, transferred once over SFTP,
 * and unpacked remotely, which avoids a round trip per file.
 * @param localDirectory local directory to copy from
 * @param remoteDirectory remote directory to copy into (must already exist)
 * @param sshClient SSH client connected to the remote machine
 * @param remoteOS remote operating system name, used to locate its temp directory
 */
export async function copyDirectoryToRemote(localDirectory: string, remoteDirectory: string, sshClient: Client, remoteOS: string): Promise<void> {
    const tmpTarName: string = `${uniqueString(10)}.tar.gz`;
    const localTarPath: string = path.join(os.tmpdir(), tmpTarName);
    const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), tmpTarName);
    // Compress files in the local directory into a temporary tarball
    await tarAdd(localTarPath, localDirectory);
    // Copy the tarball to the remote temp directory, then delete the local copy
    await copyFileToRemote(localTarPath, remoteTarPath, sshClient);
    await execRemove(localTarPath);
    // Decompress the remote tarball into the target directory, then delete it
    await remoteExeCommand(`tar -oxzf ${remoteTarPath} -C ${remoteDirectory}`, sshClient);
    await remoteExeCommand(`rm ${remoteTarPath}`, sshClient);
    // No explicit Deferred needed: an async function already returns a Promise
    // that resolves when the body completes (the original wrapped the result in
    // a ts-deferred Deferred for no effect).
}
/**
* Copy local file to remote path
* @param localFilePath the path of local file
* @param remoteFilePath the target path in remote machine
* @param sshClient SSH Client
*/
export function copyFileToRemote(localFilePath : string, remoteFilePath : string, sshClient : Client) : Promise<boolean> {
export function copyFileToRemote(localFilePath: string, remoteFilePath: string, sshClient: Client): Promise<boolean> {
const log: Logger = getLogger();
log.debug(`copyFileToRemote: localFilePath: ${localFilePath}, remoteFilePath: ${remoteFilePath}`);
assert(sshClient !== undefined);
const deferred: Deferred<boolean> = new Deferred<boolean>();
sshClient.sftp((err : Error, sftp : SFTPWrapper) => {
sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
if (err !== undefined && err !== null) {
log.error(`copyFileToRemote: ${err.message}, ${localFilePath}, ${remoteFilePath}`);
deferred.reject(err);
......@@ -67,7 +40,7 @@ export namespace SSHClientUtility {
return;
}
assert(sftp !== undefined);
sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr : Error) => {
sftp.fastPut(localFilePath, remoteFilePath, (fastPutErr: Error) => {
sftp.end();
if (fastPutErr !== undefined && fastPutErr !== null) {
deferred.reject(fastPutErr);
......@@ -85,16 +58,15 @@ export namespace SSHClientUtility {
* @param command the command to execute remotely
* @param client SSH Client
*/
// tslint:disable:no-unsafe-any no-any
export function remoteExeCommand(command : string, client : Client): Promise<RemoteCommandResult> {
export function remoteExeCommand(command: string, client: Client): Promise<RemoteCommandResult> {
const log: Logger = getLogger();
log.debug(`remoteExeCommand: command: [${command}]`);
const deferred : Deferred<RemoteCommandResult> = new Deferred<RemoteCommandResult>();
const deferred: Deferred<RemoteCommandResult> = new Deferred<RemoteCommandResult>();
let stdout: string = '';
let stderr: string = '';
let exitCode : number;
let exitCode: number;
client.exec(command, (err : Error, channel : ClientChannel) => {
client.exec(command, (err: Error, channel: ClientChannel) => {
if (err !== undefined && err !== null) {
log.error(`remoteExeCommand: ${err.message}`);
deferred.reject(err);
......@@ -102,14 +74,14 @@ export namespace SSHClientUtility {
return;
}
channel.on('data', (data : any, dataStderr : any) => {
channel.on('data', (data: any, dataStderr: any) => {
if (dataStderr !== undefined && dataStderr !== null) {
stderr += data.toString();
} else {
stdout += data.toString();
}
})
.on('exit', (code : any, signal : any) => {
.on('exit', (code: any, signal: any) => {
exitCode = <number>code;
deferred.resolve({
stdout : stdout,
......@@ -122,9 +94,34 @@ export namespace SSHClientUtility {
return deferred.promise;
}
/**
 * Copy files and directories in a local directory recursively to a remote directory.
 * The directory is packed into a single tarball locally, transferred once over SFTP,
 * and unpacked remotely, which avoids a round trip per file.
 * @param localDirectory local directory to copy from
 * @param remoteDirectory remote directory to copy into (must already exist)
 * @param sshClient SSH client connected to the remote machine
 * @param remoteOS remote operating system name, used to locate its temp directory
 */
export async function copyDirectoryToRemote(localDirectory: string, remoteDirectory: string, sshClient: Client, remoteOS: string): Promise<void> {
    const tmpTarName: string = `${uniqueString(10)}.tar.gz`;
    const localTarPath: string = path.join(os.tmpdir(), tmpTarName);
    const remoteTarPath: string = unixPathJoin(getRemoteTmpDir(remoteOS), tmpTarName);
    // Compress files in the local directory into a temporary tarball
    await tarAdd(localTarPath, localDirectory);
    // Copy the tarball to the remote temp directory, then delete the local copy
    await copyFileToRemote(localTarPath, remoteTarPath, sshClient);
    await execRemove(localTarPath);
    // Decompress the remote tarball into the target directory, then delete it
    await remoteExeCommand(`tar -oxzf ${remoteTarPath} -C ${remoteDirectory}`, sshClient);
    await remoteExeCommand(`rm ${remoteTarPath}`, sshClient);
    // No explicit Deferred needed: an async function already returns a Promise
    // that resolves when the body completes (the original wrapped the result in
    // a ts-deferred Deferred for no effect).
}
export function getRemoteFileContent(filePath: string, sshClient: Client): Promise<string> {
const deferred: Deferred<string> = new Deferred<string>();
sshClient.sftp((err: Error, sftp : SFTPWrapper) => {
sshClient.sftp((err: Error, sftp: SFTPWrapper) => {
if (err !== undefined && err !== null) {
getLogger()
.error(`getRemoteFileContent: ${err.message}`);
......@@ -133,10 +130,10 @@ export namespace SSHClientUtility {
return;
}
try {
const sftpStream : stream.Readable = sftp.createReadStream(filePath);
const sftpStream: stream.Readable = sftp.createReadStream(filePath);
let dataBuffer: string = '';
sftpStream.on('data', (data : Buffer | string) => {
sftpStream.on('data', (data: Buffer | string) => {
dataBuffer += data;
})
.on('error', (streamErr: Error) => {
......@@ -158,5 +155,4 @@ export namespace SSHClientUtility {
return deferred.promise;
}
// tslint:enable:no-unsafe-any no-any
}
......@@ -703,7 +703,7 @@ buffer-stream-reader@^0.1.1:
version "0.1.1"
resolved "https://registry.yarnpkg.com/buffer-stream-reader/-/buffer-stream-reader-0.1.1.tgz#ca8bf93631deedd8b8f8c3bb44991cc30951e259"
builtin-modules@^1.0.0, builtin-modules@^1.1.1:
builtin-modules@^1.0.0:
version "1.1.1"
resolved "https://registry.yarnpkg.com/builtin-modules/-/builtin-modules-1.1.1.tgz#270f076c5a72c02f5b65a47df94c5fe3a278892f"
......@@ -841,7 +841,7 @@ chalk@^1.0.0:
strip-ansi "^3.0.0"
supports-color "^2.0.0"
chalk@^2.0.0, chalk@^2.3.0:
chalk@^2.0.0:
version "2.4.1"
resolved "https://registry.yarnpkg.com/chalk/-/chalk-2.4.1.tgz#18c49ab16a037b6eb0152cc83e3471338215b66e"
dependencies:
......@@ -971,10 +971,6 @@ commander@2.15.1:
version "2.15.1"
resolved "https://registry.yarnpkg.com/commander/-/commander-2.15.1.tgz#df46e867d0fc2aec66a34662b406a9ccafff5b0f"
commander@^2.12.1:
version "2.16.0"
resolved "https://registry.yarnpkg.com/commander/-/commander-2.16.0.tgz#f16390593996ceb4f3eeb020b31d78528f7f8a50"
commander@~2.17.1:
version "2.17.1"
resolved "https://registry.yarnpkg.com/commander/-/commander-2.17.1.tgz#bd77ab7de6de94205ceacc72f1716d29f20a77bf"
......@@ -1134,7 +1130,7 @@ debug@^4.0.1, debug@^4.1.0, debug@^4.1.1:
dependencies:
ms "^2.1.1"
debuglog@*, debuglog@^1.0.1:
debuglog@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/debuglog/-/debuglog-1.0.1.tgz#aa24ffb9ac3df9a2351837cfb2d279360cd78492"
......@@ -1217,7 +1213,7 @@ dezalgo@^1.0.0, dezalgo@~1.0.3:
asap "^2.0.0"
wrappy "1"
diff@3.5.0, diff@^3.1.0, diff@^3.2.0:
diff@3.5.0, diff@^3.1.0:
version "3.5.0"
resolved "https://registry.yarnpkg.com/diff/-/diff-3.5.0.tgz#800c0dd1e0a8bfbc95835c202ad220fe317e5a12"
......@@ -2080,7 +2076,7 @@ import-lazy@^2.1.0:
version "2.1.0"
resolved "https://registry.yarnpkg.com/import-lazy/-/import-lazy-2.1.0.tgz#05698e3d45c88e8d7e9d92cb0584e77f096f3e43"
imurmurhash@*, imurmurhash@^0.1.4:
imurmurhash@^0.1.4:
version "0.1.4"
resolved "https://registry.yarnpkg.com/imurmurhash/-/imurmurhash-0.1.4.tgz#9218b9b2b928a238b13dc4fb6b6d576f231453ea"
......@@ -2519,10 +2515,6 @@ lockfile@~1.0.3:
dependencies:
signal-exit "^3.0.2"
lodash._baseindexof@*:
version "3.1.0"
resolved "https://registry.yarnpkg.com/lodash._baseindexof/-/lodash._baseindexof-3.1.0.tgz#fe52b53a1c6761e42618d654e4a25789ed61822c"
lodash._baseuniq@~4.6.0:
version "4.6.0"
resolved "https://registry.yarnpkg.com/lodash._baseuniq/-/lodash._baseuniq-4.6.0.tgz#0ebb44e456814af7905c6212fa2c9b2d51b841e8"
......@@ -2530,28 +2522,10 @@ lodash._baseuniq@~4.6.0:
lodash._createset "~4.0.0"
lodash._root "~3.0.0"
lodash._bindcallback@*:
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._bindcallback/-/lodash._bindcallback-3.0.1.tgz#e531c27644cf8b57a99e17ed95b35c748789392e"
lodash._cacheindexof@*:
version "3.0.2"
resolved "https://registry.yarnpkg.com/lodash._cacheindexof/-/lodash._cacheindexof-3.0.2.tgz#3dc69ac82498d2ee5e3ce56091bafd2adc7bde92"
lodash._createcache@*:
version "3.1.2"
resolved "https://registry.yarnpkg.com/lodash._createcache/-/lodash._createcache-3.1.2.tgz#56d6a064017625e79ebca6b8018e17440bdcf093"
dependencies:
lodash._getnative "^3.0.0"
lodash._createset@~4.0.0:
version "4.0.3"
resolved "https://registry.yarnpkg.com/lodash._createset/-/lodash._createset-4.0.3.tgz#0f4659fbb09d75194fa9e2b88a6644d363c9fe26"
lodash._getnative@*, lodash._getnative@^3.0.0:
version "3.9.1"
resolved "https://registry.yarnpkg.com/lodash._getnative/-/lodash._getnative-3.9.1.tgz#570bc7dede46d61cdcde687d65d3eecbaa3aaff5"
lodash._root@~3.0.0:
version "3.0.1"
resolved "https://registry.yarnpkg.com/lodash._root/-/lodash._root-3.0.1.tgz#fba1c4524c19ee9a5f8136b4609f017cf4ded692"
......@@ -2600,10 +2574,6 @@ lodash.pick@^4.4.0:
version "4.4.0"
resolved "https://registry.yarnpkg.com/lodash.pick/-/lodash.pick-4.4.0.tgz#52f05610fff9ded422611441ed1fc123a03001b3"
lodash.restparam@*:
version "3.6.1"
resolved "https://registry.yarnpkg.com/lodash.restparam/-/lodash.restparam-3.6.1.tgz#936a4e309ef330a7645ed4145986c85ae5b20805"
lodash.unescape@4.0.1:
version "4.0.1"
resolved "https://registry.yarnpkg.com/lodash.unescape/-/lodash.unescape-4.0.1.tgz#bf2249886ce514cda112fae9218cdc065211fc9c"
......@@ -3519,10 +3489,6 @@ path-key@^2.0.0, path-key@^2.0.1:
version "2.0.1"
resolved "https://registry.yarnpkg.com/path-key/-/path-key-2.0.1.tgz#411cadb574c5a140d3a4b1910d40d80cc9f40b40"
path-parse@^1.0.5:
version "1.0.5"
resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.5.tgz#3c1adf871ea9cd6c9431b6ea2bd74a0ff055c4c1"
path-parse@^1.0.6:
version "1.0.6"
resolved "https://registry.yarnpkg.com/path-parse/-/path-parse-1.0.6.tgz#d62dbb5679405d72c4737ec58600e9ddcf06d24c"
......@@ -3834,7 +3800,7 @@ readable-stream@~2.0.0:
string_decoder "~0.10.x"
util-deprecate "~1.0.1"
readdir-scoped-modules@*, readdir-scoped-modules@^1.0.0:
readdir-scoped-modules@^1.0.0:
version "1.1.0"
resolved "https://registry.yarnpkg.com/readdir-scoped-modules/-/readdir-scoped-modules-1.1.0.tgz#8d45407b4f870a0dcaebc0e28670d18e74514309"
dependencies:
......@@ -3977,12 +3943,6 @@ resolve@^1.10.0:
dependencies:
path-parse "^1.0.6"
resolve@^1.3.2:
version "1.8.1"
resolved "https://registry.yarnpkg.com/resolve/-/resolve-1.8.1.tgz#82f1ec19a423ac1fbd080b0bab06ba36e84a7a26"
dependencies:
path-parse "^1.0.5"
responselike@1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/responselike/-/responselike-1.0.2.tgz#918720ef3b631c5642be068f15ade5a46f4ba1e7"
......@@ -4599,7 +4559,7 @@ ts-node@^7.0.0:
source-map-support "^0.5.6"
yn "^2.0.0"
tslib@^1.8.0, tslib@^1.8.1:
tslib@^1.8.1:
version "1.9.3"
resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.9.3.tgz#d7e4dd79245d85428c4d7e4822a79917954ca286"
......@@ -4607,42 +4567,6 @@ tslib@^1.9.0:
version "1.10.0"
resolved "https://registry.yarnpkg.com/tslib/-/tslib-1.10.0.tgz#c3c19f95973fb0a62973fb09d90d961ee43e5c8a"
tslint-microsoft-contrib@^6.0.0:
version "6.2.0"
resolved "https://registry.yarnpkg.com/tslint-microsoft-contrib/-/tslint-microsoft-contrib-6.2.0.tgz#8aa0f40584d066d05e6a5e7988da5163b85f2ad4"
dependencies:
tsutils "^2.27.2 <2.29.0"
tslint@^5.12.0:
version "5.18.0"
resolved "https://registry.yarnpkg.com/tslint/-/tslint-5.18.0.tgz#f61a6ddcf372344ac5e41708095bbf043a147ac6"
dependencies:
"@babel/code-frame" "^7.0.0"
builtin-modules "^1.1.1"
chalk "^2.3.0"
commander "^2.12.1"
diff "^3.2.0"
glob "^7.1.1"
js-yaml "^3.13.1"
minimatch "^3.0.4"
mkdirp "^0.5.1"
resolve "^1.3.2"
semver "^5.3.0"
tslib "^1.8.0"
tsutils "^2.29.0"
"tsutils@^2.27.2 <2.29.0":
version "2.28.0"
resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-2.28.0.tgz#6bd71e160828f9d019b6f4e844742228f85169a1"
dependencies:
tslib "^1.8.1"
tsutils@^2.29.0:
version "2.29.0"
resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-2.29.0.tgz#32b488501467acbedd4b85498673a0812aca0b99"
dependencies:
tslib "^1.8.1"
tsutils@^3.17.1:
version "3.17.1"
resolved "https://registry.yarnpkg.com/tsutils/-/tsutils-3.17.1.tgz#ed719917f11ca0dee586272b2ac49e015a2dd759"
......@@ -4818,7 +4742,7 @@ v8-compile-cache@^2.0.3:
version "2.1.0"
resolved "https://registry.yarnpkg.com/v8-compile-cache/-/v8-compile-cache-2.1.0.tgz#e14de37b31a6d194f5690d67efc4e7f6fc6ab30e"
validate-npm-package-license@*, validate-npm-package-license@^3.0.1:
validate-npm-package-license@^3.0.1:
version "3.0.4"
resolved "https://registry.yarnpkg.com/validate-npm-package-license/-/validate-npm-package-license-3.0.4.tgz#fc91f6b9c7ba15c857f4cb2c5defeec39d4f410a"
dependencies:
......
......@@ -24,13 +24,15 @@ class BatchTuner(Tuner):
Examples
--------
The search space only be accepted like:
```
{
'combine_params': { '_type': 'choice',
'_value': '[{...}, {...}, {...}]',
}
}
```
::
{'combine_params':
{ '_type': 'choice',
'_value': '[{...}, {...}, {...}]',
}
}
"""
def __init__(self):
......
......@@ -5,7 +5,7 @@ import logging
import torch
from .compressor import Pruner
__all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner', 'L1FilterPruner', 'SlimPruner']
__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']
logger = logging.getLogger('torch pruner')
......@@ -166,119 +166,132 @@ class AGP_Pruner(Pruner):
self.if_init_list[k] = True
class FPGMPruner(Pruner):
class SlimPruner(Pruner):
"""
A filter pruner via geometric median.
"Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
https://arxiv.org/pdf/1811.00250.pdf
A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
"Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
https://arxiv.org/pdf/1708.06519.pdf
"""
def __init__(self, model, config_list):
"""
Parameters
----------
model : pytorch model
the model user wants to compress
config_list: list
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
super().__init__(model, config_list)
self.mask_dict = {}
self.epoch_pruned_layers = set()
self.mask_calculated_ops = set()
weight_list = []
if len(config_list) > 1:
logger.warning('Slim pruner only supports 1 configuration')
config = config_list[0]
for (layer, config) in self.detect_modules_to_compress():
assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
def calc_mask(self, layer, config):
"""
Supports Conv1d, Conv2d
filter dimensions for Conv1d:
OUT: number of output channel
IN: number of input channel
LEN: filter length
filter dimensions for Conv2d:
OUT: number of output channel
IN: number of input channel
H: filter height
W: filter width
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Parameters
----------
layer : LayerInfo
calculate mask for `layer`'s weight
the layer to instrument the compression operation
config : dict
the configuration for generating the mask
layer's pruning config
Returns
-------
torch.Tensor
mask of the layer's weight
"""
weight = layer.module.weight.data
assert 0 <= config.get('sparsity') < 1
assert layer.type in ['Conv1d', 'Conv2d']
assert layer.type in config['op_types']
op_name = layer.name
op_type = layer.type
assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask = torch.ones(weight.size()).type_as(weight)
try:
w_abs = weight.abs()
mask = torch.gt(w_abs, self.global_threshold).type_as(weight)
finally:
self.mask_dict.update({layer.name: mask})
self.mask_calculated_ops.add(layer.name)
if layer.name in self.epoch_pruned_layers:
assert layer.name in self.mask_dict
return self.mask_dict.get(layer.name)
return mask
masks = torch.ones(weight.size()).type_as(weight)
try:
num_filters = weight.size(0)
num_prune = int(num_filters * config.get('sparsity'))
if num_filters < 2 or num_prune < 1:
return masks
min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
for idx in min_gm_idx:
masks[idx] = 0.
finally:
self.mask_dict.update({layer.name: masks})
self.epoch_pruned_layers.add(layer.name)
class RankFilterPruner(Pruner):
"""
A structured pruning base class that prunes the filters with the smallest
importance criterion in convolution layers to achieve a preset level of network sparsity.
"""
return masks
def __init__(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
def _get_min_gm_kernel_idx(self, weight, n):
assert len(weight.size()) in [3, 4]
super().__init__(model, config_list)
self.mask_calculated_ops = set()
dist_list = []
for out_i in range(weight.size(0)):
dist_sum = self._get_distance_sum(weight, out_i)
dist_list.append((dist_sum, out_i))
min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
return [x[1] for x in min_gm_kernels]
def _get_mask(self, base_mask, weight, num_prune):
return torch.ones(weight.size()).type_as(weight)
def _get_distance_sum(self, weight, out_idx):
def calc_mask(self, layer, config):
"""
Calculate the total distance between a specified filter (by out_idex and in_idx) and
all other filters.
Optimized verision of following naive implementation:
def _get_distance_sum(self, weight, in_idx, out_idx):
w = weight.view(-1, weight.size(-2), weight.size(-1))
dist_sum = 0.
for k in w:
dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2)
return dist_sum
Calculate the mask of given layer.
Filters with the smallest importance criterion of the kernel weights are masked.
Parameters
----------
weight: Tensor
convolutional filter weight
out_idx: int
output channel index of specified filter, this method calculates the total distance
between this specified filter and all other filters.
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
Returns
-------
float32
The total distance
torch.Tensor
mask of the layer's weight
"""
logger.debug('weight size: %s', weight.size())
assert len(weight.size()) in [3, 4], 'unsupported weight shape'
w = weight.view(weight.size(0), -1)
anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
x = w - anchor_w
x = (x * x).sum(-1)
x = torch.sqrt(x)
return x.sum()
def update_epoch(self, epoch):
self.epoch_pruned_layers = set()
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert 0 <= config.get('sparsity') < 1
assert op_type in ['Conv1d', 'Conv2d']
assert op_type in config.get('op_types')
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask = torch.ones(weight.size()).type_as(weight)
try:
filters = weight.size(0)
num_prune = int(filters * config.get('sparsity'))
if filters < 2 or num_prune < 1:
return mask
mask = self._get_mask(mask, weight, num_prune)
finally:
self.mask_dict.update({op_name: mask})
self.mask_calculated_ops.add(op_name)
return mask.detach()
class L1FilterPruner(Pruner):
class L1FilterPruner(RankFilterPruner):
"""
A structured pruning algorithm that prunes the filters of smallest magnitude
weights sum in the convolution layers to achieve a preset level of network sparsity.
......@@ -299,107 +312,162 @@ class L1FilterPruner(Pruner):
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
def calc_mask(self, layer, config):
def _get_mask(self, base_mask, weight, num_prune):
"""
Calculate the mask of given layer.
Filters with the smallest sum of its absolute kernel weights are masked.
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
base_mask : torch.Tensor
The basic mask with the same shape of weight, all item in the basic mask is 1.
weight : torch.Tensor
Layer's weight
num_prune : int
Num of filters to prune
Returns
-------
torch.Tensor
mask of the layer's weight
Mask of the layer's weight
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert op_type == 'Conv2d', 'L1FilterPruner only supports 2d convolution layer pruning'
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask = torch.ones(weight.size()).type_as(weight)
try:
filters = weight.shape[0]
w_abs = weight.abs()
k = int(filters * config['sparsity'])
if k == 0:
return torch.ones(weight.shape).type_as(weight)
w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
threshold = torch.topk(w_abs_structured.view(-1), k, largest=False)[0].max()
mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
finally:
self.mask_dict.update({layer.name: mask})
self.mask_calculated_ops.add(layer.name)
filters = weight.shape[0]
w_abs = weight.abs()
w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
threshold = torch.topk(w_abs_structured.view(-1), num_prune, largest=False)[0].max()
mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
return mask
class SlimPruner(Pruner):
class L2FilterPruner(RankFilterPruner):
"""
A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
"Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
https://arxiv.org/pdf/1708.06519.pdf
A structured pruning algorithm that prunes the filters with the
smallest L2 norm of the absolute kernel weights are masked.
"""
def __init__(self, model, config_list):
"""
Parameters
----------
model : torch.nn.module
Model to be pruned
config_list : list
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
"""
super().__init__(model, config_list)
self.mask_calculated_ops = set()
weight_list = []
if len(config_list) > 1:
logger.warning('Slim pruner only supports 1 configuration')
config = config_list[0]
for (layer, config) in self.detect_modules_to_compress():
assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
weight_list.append(layer.module.weight.data.abs().clone())
all_bn_weights = torch.cat(weight_list)
k = int(all_bn_weights.shape[0] * config['sparsity'])
self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
def calc_mask(self, layer, config):
def _get_mask(self, base_mask, weight, num_prune):
"""
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Filters with the smallest L2 norm of the absolute kernel weights are masked.
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
base_mask : torch.Tensor
The basic mask with the same shape of weight, all item in the basic mask is 1.
weight : torch.Tensor
Layer's weight
num_prune : int
Num of filters to prune
Returns
-------
torch.Tensor
mask of the layer's weight
Mask of the layer's weight
"""
weight = layer.module.weight.data
op_name = layer.name
op_type = layer.type
assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
if op_name in self.mask_calculated_ops:
assert op_name in self.mask_dict
return self.mask_dict.get(op_name)
mask = torch.ones(weight.size()).type_as(weight)
try:
w_abs = weight.abs()
mask = torch.gt(w_abs, self.global_threshold).type_as(weight)
finally:
self.mask_dict.update({layer.name: mask})
self.mask_calculated_ops.add(layer.name)
filters = weight.shape[0]
w = weight.view(filters, -1)
w_l2_norm = torch.sqrt((w ** 2).sum(dim=1))
threshold = torch.topk(w_l2_norm.view(-1), num_prune, largest=False)[0].max()
mask = torch.gt(w_l2_norm, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
return mask
class FPGMPruner(RankFilterPruner):
    """
    A filter pruner via geometric median.
    "Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration",
    https://arxiv.org/pdf/1811.00250.pdf
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        model : pytorch model
            the model user wants to compress
        config_list : list
            support key for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        """
        super().__init__(model, config_list)

    def _get_mask(self, base_mask, weight, num_prune):
        """
        Calculate the mask of given layer.
        Filters closest to the geometric median of all filters (i.e. with the
        smallest total distance to the other filters) are masked, following FPGM.

        Parameters
        ----------
        base_mask : torch.Tensor
            The basic mask with the same shape of weight, all item in the basic mask is 1.
        weight : torch.Tensor
            Layer's weight
        num_prune : int
            Num of filters to prune

        Returns
        -------
        torch.Tensor
            Mask of the layer's weight
        """
        min_gm_idx = self._get_min_gm_kernel_idx(weight, num_prune)
        for idx in min_gm_idx:
            # zero the whole output-channel slice for the pruned filter
            base_mask[idx] = 0.
        return base_mask

    def _get_min_gm_kernel_idx(self, weight, n):
        """
        Return the output-channel indices of the `n` filters with the smallest
        total distance to all other filters (the ones nearest the geometric median).
        """
        assert len(weight.size()) in [3, 4]
        dist_list = []
        for out_i in range(weight.size(0)):
            dist_sum = self._get_distance_sum(weight, out_i)
            dist_list.append((dist_sum, out_i))
        min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
        return [x[1] for x in min_gm_kernels]

    def _get_distance_sum(self, weight, out_idx):
        """
        Calculate the total distance between a specified filter (by out_idx) and
        all other filters.
        Optimized version of the following naive implementation:
        def _get_distance_sum(self, weight, in_idx, out_idx):
            w = weight.view(-1, weight.size(-2), weight.size(-1))
            dist_sum = 0.
            for k in w:
                dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2)
            return dist_sum

        Parameters
        ----------
        weight: Tensor
            convolutional filter weight
        out_idx: int
            output channel index of specified filter, this method calculates the total distance
            between this specified filter and all other filters.

        Returns
        -------
        float32
            The total distance
        """
        logger.debug('weight size: %s', weight.size())
        assert len(weight.size()) in [3, 4], 'unsupported weight shape'
        # Flatten each filter to a row vector and compute pairwise L2 distances
        # to the anchor filter in one vectorized pass.
        w = weight.view(weight.size(0), -1)
        anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
        x = w - anchor_w
        x = (x * x).sum(-1)
        x = torch.sqrt(x)
        return x.sum()

    def update_epoch(self, epoch):
        # Reset the per-epoch cache so masks are recomputed next epoch.
        self.mask_calculated_ops = set()
......@@ -163,19 +163,23 @@ class MsgDispatcherBase(Recoverable):
raise NotImplementedError('handle_initialize not implemented')
def handle_request_trial_jobs(self, data):
"""The message dispatcher is demanded to generate `data` trial jobs.
These trial jobs should be sent via `send(CommandType.NewTrialJob, json_tricks.dumps(parameter))`,
where `parameter` will be received by NNI Manager and eventually accessible to trial jobs as "next parameter".
Semantically, message dispatcher should do this `send` exactly `data` times.
"""The message dispatcher is demanded to generate ``data`` trial jobs.
These trial jobs should be sent via ``send(CommandType.NewTrialJob, json_tricks.dumps(parameter))``,
where ``parameter`` will be received by NNI Manager and eventually accessible to trial jobs as "next parameter".
Semantically, message dispatcher should do this ``send`` exactly ``data`` times.
The JSON sent by this method should follow the format of
{
"parameter_id": 42
"parameters": {
// this will be received by trial
},
"parameter_source": "algorithm" // optional
}
::
{
"parameter_id": 42
"parameters": {
// this will be received by trial
},
"parameter_source": "algorithm" // optional
}
Parameters
----------
data: int
......@@ -211,6 +215,7 @@ class MsgDispatcherBase(Recoverable):
def handle_report_metric_data(self, data):
"""Called when metric data is reported or new parameters are requested (for multiphase).
When new parameters are requested, this method should send a new parameter.
Parameters
----------
data: dict
......@@ -219,6 +224,7 @@ class MsgDispatcherBase(Recoverable):
`REQUEST_PARAMETER` is used to request new parameters for multiphase trial job. In this case,
the dict will contain additional keys: `trial_job_id`, `parameter_index`. Refer to `msg_dispatcher.py`
as an example.
Raises
------
ValueError
......@@ -228,6 +234,7 @@ class MsgDispatcherBase(Recoverable):
def handle_trial_end(self, data):
"""Called when the state of one of the trials is changed
Parameters
----------
data: dict
......@@ -235,5 +242,6 @@ class MsgDispatcherBase(Recoverable):
trial_job_id: the id generated by training service.
event: the job’s state.
hyper_params: the string that is sent by message dispatcher during the creation of trials.
"""
raise NotImplementedError('handle_trial_end not implemented')
......@@ -58,8 +58,9 @@ def tf2(func):
return test_tf2_func
# for fpgm filter pruner test
w = np.array([[[[i+1]*3]*3]*5 for i in range(10)])
w = np.array([[[[i + 1] * 3] * 3] * 5 for i in range(10)])
class CompressorTestCase(TestCase):
......@@ -69,19 +70,19 @@ class CompressorTestCase(TestCase):
config_list = [{
'quant_types': ['weight'],
'quant_bits': 8,
'op_types':['Conv2d', 'Linear']
'op_types': ['Conv2d', 'Linear']
}, {
'quant_types': ['output'],
'quant_bits': 8,
'quant_start_step': 0,
'op_types':['ReLU']
'op_types': ['ReLU']
}]
model.relu = torch.nn.ReLU()
quantizer = torch_compressor.QAT_Quantizer(model, config_list)
quantizer.compress()
modules_to_compress = quantizer.get_modules_to_compress()
modules_to_compress_name = [ t[0].name for t in modules_to_compress]
modules_to_compress_name = [t[0].name for t in modules_to_compress]
assert "conv1" in modules_to_compress_name
assert "conv2" in modules_to_compress_name
assert "fc1" in modules_to_compress_name
......@@ -179,7 +180,8 @@ class CompressorTestCase(TestCase):
w = np.array([np.zeros((3, 3, 3)), np.ones((3, 3, 3)), np.ones((3, 3, 3)) * 2,
np.ones((3, 3, 3)) * 3, np.ones((3, 3, 3)) * 4])
model = TorchModel()
config_list = [{'sparsity': 0.2, 'op_names': ['conv1']}, {'sparsity': 0.6, 'op_names': ['conv2']}]
config_list = [{'sparsity': 0.2, 'op_types': ['Conv2d'], 'op_names': ['conv1']},
{'sparsity': 0.6, 'op_types': ['Conv2d'], 'op_names': ['conv2']}]
pruner = torch_compressor.L1FilterPruner(model, config_list)
model.conv1.weight.data = torch.tensor(w).float()
......@@ -236,12 +238,12 @@ class CompressorTestCase(TestCase):
config_list = [{
'quant_types': ['weight'],
'quant_bits': 8,
'op_types':['Conv2d', 'Linear']
'op_types': ['Conv2d', 'Linear']
}, {
'quant_types': ['output'],
'quant_bits': 8,
'quant_start_step': 0,
'op_types':['ReLU']
'op_types': ['ReLU']
}]
model.relu = torch.nn.ReLU()
quantizer = torch_compressor.QAT_Quantizer(model, config_list)
......@@ -253,7 +255,7 @@ class CompressorTestCase(TestCase):
quantize_weight = quantizer.quantize_weight(weight, config_list[0], model.conv2)
assert math.isclose(model.conv2.scale, 5 / 255, abs_tol=eps)
assert model.conv2.zero_point == 0
# range including 0
# range including 0
weight = torch.tensor([[-1, 2], [3, 5]]).float()
quantize_weight = quantizer.quantize_weight(weight, config_list[0], model.conv2)
assert math.isclose(model.conv2.scale, 6 / 255, abs_tol=eps)
......@@ -271,5 +273,6 @@ class CompressorTestCase(TestCase):
assert math.isclose(model.relu.tracked_min_biased, 0.002, abs_tol=eps)
assert math.isclose(model.relu.tracked_max_biased, 0.00998, abs_tol=eps)
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment