Commit ea665155 authored by quzha's avatar quzha
Browse files

Merge branch 'master' of github.com:Microsoft/nni into dev-nas-refactor

parents 73b2221b ae36373c
...@@ -155,11 +155,7 @@ class Logger { ...@@ -155,11 +155,7 @@ class Logger {
} }
} }
function getLogger(fileName?: string): Logger { function getLogger(): Logger {
component.Container.bind(Logger).provider({
get: (): Logger => new Logger(fileName)
});
return component.get(Logger); return component.get(Logger);
} }
......
...@@ -105,7 +105,7 @@ abstract class Manager { ...@@ -105,7 +105,7 @@ abstract class Manager {
public abstract importData(data: string): Promise<void>; public abstract importData(data: string): Promise<void>;
public abstract exportData(): Promise<string>; public abstract exportData(): Promise<string>;
public abstract addCustomizedTrialJob(hyperParams: string): Promise<void>; public abstract addCustomizedTrialJob(hyperParams: string): Promise<number>;
public abstract cancelTrialJobByUser(trialJobId: string): Promise<void>; public abstract cancelTrialJobByUser(trialJobId: string): Promise<void>;
public abstract listTrialJobs(status?: TrialJobStatus): Promise<TrialJobInfo[]>; public abstract listTrialJobs(status?: TrialJobStatus): Promise<TrialJobInfo[]>;
......
...@@ -58,11 +58,6 @@ interface TrialJobDetail { ...@@ -58,11 +58,6 @@ interface TrialJobDetail {
isEarlyStopped?: boolean; isEarlyStopped?: boolean;
} }
interface HostJobDetail {
readonly id: string;
readonly status: string;
}
/** /**
* define TrialJobMetric * define TrialJobMetric
*/ */
......
...@@ -50,13 +50,12 @@ class NNIManager implements Manager { ...@@ -50,13 +50,12 @@ class NNIManager implements Manager {
private dispatcher: IpcInterface | undefined; private dispatcher: IpcInterface | undefined;
private currSubmittedTrialNum: number; // need to be recovered private currSubmittedTrialNum: number; // need to be recovered
private trialConcurrencyChange: number; // >0: increase, <0: decrease private trialConcurrencyChange: number; // >0: increase, <0: decrease
private customizedTrials: string[]; // need to be recovered
private log: Logger; private log: Logger;
private dataStore: DataStore; private dataStore: DataStore;
private experimentProfile: ExperimentProfile; private experimentProfile: ExperimentProfile;
private dispatcherPid: number; private dispatcherPid: number;
private status: NNIManagerStatus; private status: NNIManagerStatus;
private waitingTrials: string[]; private waitingTrials: TrialJobApplicationForm[];
private trialJobs: Map<string, TrialJobDetail>; private trialJobs: Map<string, TrialJobDetail>;
private trialDataForTuner: string; private trialDataForTuner: string;
private readonly: boolean; private readonly: boolean;
...@@ -66,7 +65,6 @@ class NNIManager implements Manager { ...@@ -66,7 +65,6 @@ class NNIManager implements Manager {
constructor() { constructor() {
this.currSubmittedTrialNum = 0; this.currSubmittedTrialNum = 0;
this.trialConcurrencyChange = 0; this.trialConcurrencyChange = 0;
this.customizedTrials = [];
this.trainingService = component.get(TrainingService); this.trainingService = component.get(TrainingService);
assert(this.trainingService); assert(this.trainingService);
this.dispatcherPid = 0; this.dispatcherPid = 0;
...@@ -131,19 +129,34 @@ class NNIManager implements Manager { ...@@ -131,19 +129,34 @@ class NNIManager implements Manager {
return this.dataStore.exportTrialHpConfigs(); return this.dataStore.exportTrialHpConfigs();
} }
public addCustomizedTrialJob(hyperParams: string): Promise<void> { public addCustomizedTrialJob(hyperParams: string): Promise<number> {
if (this.readonly) { if (this.readonly) {
return Promise.reject(new Error('Error: can not add customized trial job in readonly mode!')); return Promise.reject(new Error('Error: can not add customized trial job in readonly mode!'));
} }
if (this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) { if (this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) {
return Promise.reject( return Promise.reject(new Error('reach maxTrialNum'));
new Error('reach maxTrialNum')
);
} }
this.customizedTrials.push(hyperParams);
// TODO: NNI manager should not peek tuner's internal protocol, let's refactor this later
const packedParameter = {
parameter_id: null,
parameter_source: 'customized',
parameters: JSON.parse(hyperParams)
}
const form: TrialJobApplicationForm = {
sequenceId: this.experimentProfile.nextSequenceId++,
hyperParameters: {
value: JSON.stringify(packedParameter),
index: 0
}
};
this.waitingTrials.push(form);
// trial id has not been generated yet, thus use '' instead // trial id has not been generated yet, thus use '' instead
return this.dataStore.storeTrialJobEvent('ADD_CUSTOMIZED', '', hyperParams); this.dataStore.storeTrialJobEvent('ADD_CUSTOMIZED', '', hyperParams);
return Promise.resolve(form.sequenceId);
} }
public async cancelTrialJobByUser(trialJobId: string): Promise<void> { public async cancelTrialJobByUser(trialJobId: string): Promise<void> {
...@@ -560,18 +573,7 @@ class NNIManager implements Manager { ...@@ -560,18 +573,7 @@ class NNIManager implements Manager {
this.trialConcurrencyChange = requestTrialNum; this.trialConcurrencyChange = requestTrialNum;
} }
const requestCustomTrialNum: number = Math.min(requestTrialNum, this.customizedTrials.length); this.requestTrialJobs(requestTrialNum);
for (let i: number = 0; i < requestCustomTrialNum; i++) {
// ask tuner for more trials
if (this.customizedTrials.length > 0) {
const hyperParams: string | undefined = this.customizedTrials.shift();
this.dispatcher.sendCommand(ADD_CUSTOMIZED_TRIAL_JOB, hyperParams);
}
}
if (requestTrialNum - requestCustomTrialNum > 0) {
this.requestTrialJobs(requestTrialNum - requestCustomTrialNum);
}
// check maxtrialnum and maxduration here // check maxtrialnum and maxduration here
// NO_MORE_TRIAL is more like a subset of RUNNING, because during RUNNING tuner // NO_MORE_TRIAL is more like a subset of RUNNING, because during RUNNING tuner
...@@ -609,26 +611,16 @@ class NNIManager implements Manager { ...@@ -609,26 +611,16 @@ class NNIManager implements Manager {
this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) { this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) {
break; break;
} }
const hyperParams: string | undefined = this.waitingTrials.shift(); const form = this.waitingTrials.shift() as TrialJobApplicationForm;
if (hyperParams === undefined) {
throw new Error(`Error: invalid hyper-parameters for job submission: ${hyperParams}`);
}
this.currSubmittedTrialNum++; this.currSubmittedTrialNum++;
const trialJobAppForm: TrialJobApplicationForm = { this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);
sequenceId: this.experimentProfile.nextSequenceId++, const trialJobDetail: TrialJobDetail = await this.trainingService.submitTrialJob(form);
hyperParameters: {
value: hyperParams,
index: 0
}
};
this.log.info(`submitTrialJob: form: ${JSON.stringify(trialJobAppForm)}`);
const trialJobDetail: TrialJobDetail = await this.trainingService.submitTrialJob(trialJobAppForm);
await this.storeExperimentProfile(); await this.storeExperimentProfile();
this.trialJobs.set(trialJobDetail.id, Object.assign({}, trialJobDetail)); this.trialJobs.set(trialJobDetail.id, Object.assign({}, trialJobDetail));
const trialJobDetailSnapshot: TrialJobDetail | undefined = this.trialJobs.get(trialJobDetail.id); const trialJobDetailSnapshot: TrialJobDetail | undefined = this.trialJobs.get(trialJobDetail.id);
if (trialJobDetailSnapshot != undefined) { if (trialJobDetailSnapshot != undefined) {
await this.dataStore.storeTrialJobEvent( await this.dataStore.storeTrialJobEvent(
trialJobDetailSnapshot.status, trialJobDetailSnapshot.id, hyperParams, trialJobDetailSnapshot); trialJobDetailSnapshot.status, trialJobDetailSnapshot.id, form.hyperParameters.value, trialJobDetailSnapshot);
} else { } else {
assert(false, `undefined trialJobDetail in trialJobs: ${trialJobDetail.id}`); assert(false, `undefined trialJobDetail in trialJobs: ${trialJobDetail.id}`);
} }
...@@ -734,7 +726,14 @@ class NNIManager implements Manager { ...@@ -734,7 +726,14 @@ class NNIManager implements Manager {
this.log.warning('It is not supposed to receive more trials after NO_MORE_TRIAL is set'); this.log.warning('It is not supposed to receive more trials after NO_MORE_TRIAL is set');
this.setStatus('RUNNING'); this.setStatus('RUNNING');
} }
this.waitingTrials.push(content); const form: TrialJobApplicationForm = {
sequenceId: this.experimentProfile.nextSequenceId++,
hyperParameters: {
value: content,
index: 0
}
};
this.waitingTrials.push(form);
break; break;
case SEND_TRIAL_JOB_PARAMETER: case SEND_TRIAL_JOB_PARAMETER:
const tunerCommand: any = JSON.parse(content); const tunerCommand: any = JSON.parse(content);
......
...@@ -121,7 +121,7 @@ describe('Unit test for nnimanager', function () { ...@@ -121,7 +121,7 @@ describe('Unit test for nnimanager', function () {
it('test addCustomizedTrialJob', () => { it('test addCustomizedTrialJob', () => {
return nniManager.addCustomizedTrialJob('hyperParams').then(() => { return nniManager.addCustomizedTrialJob('"hyperParams"').then(() => {
}).catch((error) => { }).catch((error) => {
assert.fail(error); assert.fail(error);
...@@ -273,7 +273,7 @@ describe('Unit test for nnimanager', function () { ...@@ -273,7 +273,7 @@ describe('Unit test for nnimanager', function () {
it('test addCustomizedTrialJob reach maxTrialNum', () => { it('test addCustomizedTrialJob reach maxTrialNum', () => {
// test currSubmittedTrialNum reach maxTrialNum // test currSubmittedTrialNum reach maxTrialNum
return nniManager.addCustomizedTrialJob('hyperParam').then(() => { return nniManager.addCustomizedTrialJob('"hyperParam"').then(() => {
nniManager.getTrialJobStatistics().then(function (trialJobStatistics) { nniManager.getTrialJobStatistics().then(function (trialJobStatistics) {
if (trialJobStatistics[0].trialJobStatus === 'WAITING') if (trialJobStatistics[0].trialJobStatus === 'WAITING')
expect(trialJobStatistics[0].trialJobNumber).to.be.equal(2); expect(trialJobStatistics[0].trialJobNumber).to.be.equal(2);
......
...@@ -49,7 +49,7 @@ function initStartupInfo( ...@@ -49,7 +49,7 @@ function initStartupInfo(
setExperimentStartupInfo(createNew, expId, basePort, logDirectory, experimentLogLevel, readonly); setExperimentStartupInfo(createNew, expId, basePort, logDirectory, experimentLogLevel, readonly);
} }
async function initContainer(platformMode: string): Promise<void> { async function initContainer(platformMode: string, logFileName?: string): Promise<void> {
if (platformMode === 'local') { if (platformMode === 'local') {
Container.bind(TrainingService) Container.bind(TrainingService)
.to(LocalTrainingService) .to(LocalTrainingService)
...@@ -82,6 +82,9 @@ async function initContainer(platformMode: string): Promise<void> { ...@@ -82,6 +82,9 @@ async function initContainer(platformMode: string): Promise<void> {
Container.bind(DataStore) Container.bind(DataStore)
.to(NNIDataStore) .to(NNIDataStore)
.scope(Scope.Singleton); .scope(Scope.Singleton);
Container.bind(Logger).provider({
get: (): Logger => new Logger(logFileName)
});
const ds: DataStore = component.get(DataStore); const ds: DataStore = component.get(DataStore);
await ds.init(); await ds.init();
...@@ -145,13 +148,14 @@ initStartupInfo(startMode, experimentId, port, logDir, logLevel, readonly); ...@@ -145,13 +148,14 @@ initStartupInfo(startMode, experimentId, port, logDir, logLevel, readonly);
mkDirP(getLogDir()) mkDirP(getLogDir())
.then(async () => { .then(async () => {
const log: Logger = getLogger();
try { try {
await initContainer(mode); await initContainer(mode);
const restServer: NNIRestServer = component.get(NNIRestServer); const restServer: NNIRestServer = component.get(NNIRestServer);
await restServer.start(); await restServer.start();
const log: Logger = getLogger();
log.info(`Rest server listening on: ${restServer.endPoint}`); log.info(`Rest server listening on: ${restServer.endPoint}`);
} catch (err) { } catch (err) {
const log: Logger = getLogger();
log.error(`${err.stack}`); log.error(`${err.stack}`);
throw err; throw err;
} }
......
...@@ -236,8 +236,8 @@ class NNIRestHandler { ...@@ -236,8 +236,8 @@ class NNIRestHandler {
private addTrialJob(router: Router): void { private addTrialJob(router: Router): void {
router.post('/trial-jobs', async (req: Request, res: Response) => { router.post('/trial-jobs', async (req: Request, res: Response) => {
this.nniManager.addCustomizedTrialJob(JSON.stringify(req.body)).then(() => { this.nniManager.addCustomizedTrialJob(JSON.stringify(req.body)).then((sequenceId: number) => {
res.send(); res.send({sequenceId});
}).catch((err: Error) => { }).catch((err: Error) => {
this.handle_error(err, res); this.handle_error(err, res);
}); });
......
...@@ -65,8 +65,8 @@ export class MockedNNIManager extends Manager { ...@@ -65,8 +65,8 @@ export class MockedNNIManager extends Manager {
return deferred.promise; return deferred.promise;
} }
public addCustomizedTrialJob(hyperParams: string): Promise<void> { public addCustomizedTrialJob(hyperParams: string): Promise<number> {
return Promise.resolve(); return Promise.resolve(99);
} }
public resumeExperiment(): Promise<void> { public resumeExperiment(): Promise<void> {
......
...@@ -59,14 +59,6 @@ export class GPUSummary { ...@@ -59,14 +59,6 @@ export class GPUSummary {
} }
} }
export const GPU_INFO_COLLECTOR_FORMAT_LINUX: string =
`
#!/bin/bash
export METRIC_OUTPUT_DIR={0}
echo $$ >{1}
python3 -m nni_gpu_tool.gpu_metrics_collector
`;
export const GPU_INFO_COLLECTOR_FORMAT_WINDOWS: string = export const GPU_INFO_COLLECTOR_FORMAT_WINDOWS: string =
` `
$env:METRIC_OUTPUT_DIR="{0}" $env:METRIC_OUTPUT_DIR="{0}"
......
...@@ -27,7 +27,7 @@ import * as path from 'path'; ...@@ -27,7 +27,7 @@ import * as path from 'path';
import { String } from 'typescript-string-operations'; import { String } from 'typescript-string-operations';
import { countFilesRecursively, getNewLine, validateFileNameRecursively } from '../../common/utils'; import { countFilesRecursively, getNewLine, validateFileNameRecursively } from '../../common/utils';
import { file } from '../../node_modules/@types/tmp'; import { file } from '../../node_modules/@types/tmp';
import { GPU_INFO_COLLECTOR_FORMAT_LINUX, GPU_INFO_COLLECTOR_FORMAT_WINDOWS } from './gpuData'; import { GPU_INFO_COLLECTOR_FORMAT_WINDOWS } from './gpuData';
/** /**
* Validate codeDir, calculate file count recursively under codeDir, and throw error if any rule is broken * Validate codeDir, calculate file count recursively under codeDir, and throw error if any rule is broken
...@@ -89,7 +89,7 @@ export async function execCopydir(source: string, destination: string): Promise< ...@@ -89,7 +89,7 @@ export async function execCopydir(source: string, destination: string): Promise<
if (process.platform === 'win32') { if (process.platform === 'win32') {
await cpp.exec(`powershell.exe Copy-Item "${source}" -Destination "${destination}" -Recurse`); await cpp.exec(`powershell.exe Copy-Item "${source}" -Destination "${destination}" -Recurse`);
} else { } else {
await cpp.exec(`cp -r '${source}' '${destination}'`); await cpp.exec(`cp -r '${source}/.' '${destination}'`);
} }
return Promise.resolve(); return Promise.resolve();
...@@ -219,22 +219,16 @@ export function getScriptName(fileNamePrefix: string): string { ...@@ -219,22 +219,16 @@ export function getScriptName(fileNamePrefix: string): string {
} }
} }
/** export function getGpuMetricsCollectorBashScriptContent(scriptFolder: string): string {
* generate script file return `echo $$ > ${scriptFolder}/pid ; METRIC_OUTPUT_DIR=${scriptFolder} python3 -m nni_gpu_tool.gpu_metrics_collector`;
* @param gpuMetricCollectorScriptFolder }
*/
export function getgpuMetricsCollectorScriptContent(gpuMetricCollectorScriptFolder: string): string { export function runGpuMetricsCollector(scriptFolder: string): void {
if (process.platform === 'win32') { if (process.platform === 'win32') {
return String.Format( const scriptPath = path.join(scriptFolder, 'gpu_metrics_collector.ps1');
GPU_INFO_COLLECTOR_FORMAT_WINDOWS, const content = String.Format(GPU_INFO_COLLECTOR_FORMAT_WINDOWS, scriptFolder, path.join(scriptFolder, 'pid'));
gpuMetricCollectorScriptFolder, fs.writeFile(scriptPath, content, { encoding: 'utf8' }, () => { runScript(scriptPath); });
path.join(gpuMetricCollectorScriptFolder, 'pid')
);
} else { } else {
return String.Format( cp.exec(getGpuMetricsCollectorBashScriptContent(scriptFolder), { shell: '/bin/bash' });
GPU_INFO_COLLECTOR_FORMAT_LINUX,
gpuMetricCollectorScriptFolder,
path.join(gpuMetricCollectorScriptFolder, 'pid')
);
} }
} }
...@@ -28,7 +28,7 @@ import { String } from 'typescript-string-operations'; ...@@ -28,7 +28,7 @@ import { String } from 'typescript-string-operations';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { delay } from '../../common/utils'; import { delay } from '../../common/utils';
import { GPUInfo, GPUSummary } from '../common/gpuData'; import { GPUInfo, GPUSummary } from '../common/gpuData';
import { execKill, execMkdir, execRemove, execTail, getgpuMetricsCollectorScriptContent, getScriptName, runScript } from '../common/util'; import { execKill, execMkdir, execRemove, execTail, runGpuMetricsCollector } from '../common/util';
/** /**
* GPUScheduler for local training service * GPUScheduler for local training service
...@@ -43,7 +43,7 @@ class GPUScheduler { ...@@ -43,7 +43,7 @@ class GPUScheduler {
constructor() { constructor() {
this.stopping = false; this.stopping = false;
this.log = getLogger(); this.log = getLogger();
this.gpuMetricCollectorScriptFolder = `${os.tmpdir()}/nni/script`; this.gpuMetricCollectorScriptFolder = `${os.tmpdir()}/${os.userInfo().username}/nni/script`;
} }
public async run(): Promise<void> { public async run(): Promise<void> {
...@@ -101,12 +101,7 @@ class GPUScheduler { ...@@ -101,12 +101,7 @@ class GPUScheduler {
*/ */
private async runGpuMetricsCollectorScript(): Promise<void> { private async runGpuMetricsCollectorScript(): Promise<void> {
await execMkdir(this.gpuMetricCollectorScriptFolder, true); await execMkdir(this.gpuMetricCollectorScriptFolder, true);
//generate gpu_metrics_collector script runGpuMetricsCollector(this.gpuMetricCollectorScriptFolder);
const gpuMetricsCollectorScriptPath: string =
path.join(this.gpuMetricCollectorScriptFolder, getScriptName('gpu_metrics_collector'));
const gpuMetricsCollectorScriptContent: string = getgpuMetricsCollectorScriptContent(this.gpuMetricCollectorScriptFolder);
await fs.promises.writeFile(gpuMetricsCollectorScriptPath, gpuMetricsCollectorScriptContent, { encoding: 'utf8' });
runScript(gpuMetricsCollectorScriptPath);
} }
// tslint:disable:non-literal-fs-path // tslint:disable:non-literal-fs-path
......
...@@ -42,10 +42,10 @@ import { ...@@ -42,10 +42,10 @@ import {
getVersion, uniqueString, unixPathJoin getVersion, uniqueString, unixPathJoin
} from '../../common/utils'; } from '../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData'; import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { GPU_INFO_COLLECTOR_FORMAT_LINUX, GPUSummary } from '../common/gpuData'; import { GPUSummary } from '../common/gpuData';
import { TrialConfig } from '../common/trialConfig'; import { TrialConfig } from '../common/trialConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { execCopydir, execMkdir, execRemove, validateCodeDir } from '../common/util'; import { execCopydir, execMkdir, execRemove, validateCodeDir, getGpuMetricsCollectorBashScriptContent } from '../common/util';
import { GPUScheduler } from './gpuScheduler'; import { GPUScheduler } from './gpuScheduler';
import { import {
HOST_JOB_SHELL_FORMAT, RemoteCommandResult, REMOTEMACHINE_TRIAL_COMMAND_FORMAT, RemoteMachineMeta, HOST_JOB_SHELL_FORMAT, RemoteCommandResult, REMOTEMACHINE_TRIAL_COMMAND_FORMAT, RemoteMachineMeta,
...@@ -67,7 +67,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -67,7 +67,7 @@ class RemoteMachineTrainingService implements TrainingService {
private readonly expRootDir: string; private readonly expRootDir: string;
private readonly remoteExpRootDir: string; private readonly remoteExpRootDir: string;
private trialConfig: TrialConfig | undefined; private trialConfig: TrialConfig | undefined;
private readonly gpuScheduler: GPUScheduler; private gpuScheduler?: GPUScheduler;
private readonly jobQueue: string[]; private readonly jobQueue: string[];
private readonly timer: ObservableTimer; private readonly timer: ObservableTimer;
private stopping: boolean = false; private stopping: boolean = false;
...@@ -87,7 +87,6 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -87,7 +87,6 @@ class RemoteMachineTrainingService implements TrainingService {
this.trialJobsMap = new Map<string, RemoteMachineTrialJobDetail>(); this.trialJobsMap = new Map<string, RemoteMachineTrialJobDetail>();
this.trialSSHClientMap = new Map<string, Client>(); this.trialSSHClientMap = new Map<string, Client>();
this.machineSSHClientMap = new Map<RemoteMachineMeta, SSHClientManager>(); this.machineSSHClientMap = new Map<RemoteMachineMeta, SSHClientManager>();
this.gpuScheduler = new GPUScheduler(this.machineSSHClientMap);
this.jobQueue = []; this.jobQueue = [];
this.expRootDir = getExperimentRootDir(); this.expRootDir = getExperimentRootDir();
this.remoteExpRootDir = this.getRemoteExperimentRootDir(); this.remoteExpRootDir = this.getRemoteExperimentRootDir();
...@@ -334,8 +333,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -334,8 +333,7 @@ class RemoteMachineTrainingService implements TrainingService {
break; break;
case TrialConfigMetadataKey.MACHINE_LIST: case TrialConfigMetadataKey.MACHINE_LIST:
await this.setupConnections(value); await this.setupConnections(value);
//remove local temp files this.gpuScheduler = new GPUScheduler(this.machineSSHClientMap);
await execRemove(this.getLocalGpuMetricCollectorDir());
break; break;
case TrialConfigMetadataKey.TRIAL_CONFIG: case TrialConfigMetadataKey.TRIAL_CONFIG:
const remoteMachineTrailConfig: TrialConfig = <TrialConfig>JSON.parse(value); const remoteMachineTrailConfig: TrialConfig = <TrialConfig>JSON.parse(value);
...@@ -399,9 +397,11 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -399,9 +397,11 @@ class RemoteMachineTrainingService implements TrainingService {
* remove gpu reversion when job is not running * remove gpu reversion when job is not running
*/ */
private updateGpuReservation(): void { private updateGpuReservation(): void {
for (const [key, value] of this.trialJobsMap) { if (this.gpuScheduler) {
if (!['WAITING', 'RUNNING'].includes(value.status)) { for (const [key, value] of this.trialJobsMap) {
this.gpuScheduler.removeGpuReservation(key, this.trialJobsMap); if (!['WAITING', 'RUNNING'].includes(value.status)) {
this.gpuScheduler.removeGpuReservation(key, this.trialJobsMap);
}
} }
} }
} }
...@@ -428,34 +428,6 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -428,34 +428,6 @@ class RemoteMachineTrainingService implements TrainingService {
return Promise.resolve(); return Promise.resolve();
} }
/**
* Generate gpu metric collector directory to store temp gpu metric collector script files
*/
private getLocalGpuMetricCollectorDir(): string {
const userName: string = path.basename(os.homedir()); //get current user name of os
return path.join(os.tmpdir(), userName, 'nni', 'scripts');
}
/**
* Generate gpu metric collector shell script in local machine,
* used to run in remote machine, and will be deleted after uploaded from local.
*/
private async generateGpuMetricsCollectorScript(userName: string): Promise<void> {
const gpuMetricCollectorScriptFolder : string = this.getLocalGpuMetricCollectorDir();
await execMkdir(path.join(gpuMetricCollectorScriptFolder, userName));
//generate gpu_metrics_collector.sh
const gpuMetricsCollectorScriptPath: string = path.join(gpuMetricCollectorScriptFolder, userName, 'gpu_metrics_collector.sh');
// This directory is used to store gpu_metrics and pid created by script
const remoteGPUScriptsDir: string = this.getRemoteScriptsPath(userName);
const gpuMetricsCollectorScriptContent: string = String.Format(
GPU_INFO_COLLECTOR_FORMAT_LINUX,
remoteGPUScriptsDir,
unixPathJoin(remoteGPUScriptsDir, 'pid')
);
await fs.promises.writeFile(gpuMetricsCollectorScriptPath, gpuMetricsCollectorScriptContent, { encoding: 'utf8' });
}
private async setupConnections(machineList: string): Promise<void> { private async setupConnections(machineList: string): Promise<void> {
this.log.debug(`Connecting to remote machines: ${machineList}`); this.log.debug(`Connecting to remote machines: ${machineList}`);
const deferred: Deferred<void> = new Deferred<void>(); const deferred: Deferred<void> = new Deferred<void>();
...@@ -479,24 +451,18 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -479,24 +451,18 @@ class RemoteMachineTrainingService implements TrainingService {
private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, conn: Client): Promise<void> { private async initRemoteMachineOnConnected(rmMeta: RemoteMachineMeta, conn: Client): Promise<void> {
// Create root working directory after ssh connection is ready // Create root working directory after ssh connection is ready
// generate gpu script in local machine first, will copy to remote machine later
await this.generateGpuMetricsCollectorScript(rmMeta.username);
const nniRootDir: string = unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni'); const nniRootDir: string = unixPathJoin(getRemoteTmpDir(this.remoteOS), 'nni');
await SSHClientUtility.remoteExeCommand(`mkdir -p ${this.remoteExpRootDir}`, conn); await SSHClientUtility.remoteExeCommand(`mkdir -p ${this.remoteExpRootDir}`, conn);
// Copy NNI scripts to remote expeirment working directory
const localGpuScriptCollectorDir: string = this.getLocalGpuMetricCollectorDir();
// the directory to store temp scripts in remote machine // the directory to store temp scripts in remote machine
const remoteGpuScriptCollectorDir: string = this.getRemoteScriptsPath(rmMeta.username); const remoteGpuScriptCollectorDir: string = this.getRemoteScriptsPath(rmMeta.username);
await SSHClientUtility.remoteExeCommand(`mkdir -p ${remoteGpuScriptCollectorDir}`, conn); await SSHClientUtility.remoteExeCommand(`(umask 0 ; mkdir -p ${remoteGpuScriptCollectorDir})`, conn);
await SSHClientUtility.remoteExeCommand(`chmod 777 ${nniRootDir} ${nniRootDir}/* ${nniRootDir}/scripts/*`, conn); await SSHClientUtility.remoteExeCommand(`chmod 777 ${nniRootDir} ${nniRootDir}/* ${nniRootDir}/scripts/*`, conn);
//copy gpu_metrics_collector.sh to remote
await SSHClientUtility.copyFileToRemote(path.join(localGpuScriptCollectorDir, rmMeta.username, 'gpu_metrics_collector.sh'),
unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics_collector.sh'), conn);
//Begin to execute gpu_metrics_collection scripts //Begin to execute gpu_metrics_collection scripts
// tslint:disable-next-line: no-floating-promises // tslint:disable-next-line: no-floating-promises
SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(remoteGpuScriptCollectorDir, 'gpu_metrics_collector.sh')}`, conn); const script = getGpuMetricsCollectorBashScriptContent(remoteGpuScriptCollectorDir);
SSHClientUtility.remoteExeCommand(`bash -c '${script}'`, conn);
const disposable: Rx.IDisposable = this.timer.subscribe( const disposable: Rx.IDisposable = this.timer.subscribe(
async (tick: number) => { async (tick: number) => {
...@@ -519,6 +485,9 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -519,6 +485,9 @@ class RemoteMachineTrainingService implements TrainingService {
if (this.trialConfig === undefined) { if (this.trialConfig === undefined) {
throw new Error('trial config is not initialized'); throw new Error('trial config is not initialized');
} }
if (this.gpuScheduler === undefined) {
throw new Error('gpuScheduler is not initialized');
}
const trialJobDetail: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId); const trialJobDetail: RemoteMachineTrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) { if (trialJobDetail === undefined) {
throw new NNIError(NNIErrorNames.INVALID_JOB_DETAIL, `Invalid job detail information for trial job ${trialJobId}`); throw new NNIError(NNIErrorNames.INVALID_JOB_DETAIL, `Invalid job detail information for trial job ${trialJobId}`);
...@@ -630,7 +599,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -630,7 +599,7 @@ class RemoteMachineTrainingService implements TrainingService {
await execMkdir(path.join(trialLocalTempFolder, '.nni')); await execMkdir(path.join(trialLocalTempFolder, '.nni'));
//create tmp trial working folder locally. //create tmp trial working folder locally.
await execCopydir(path.join(this.trialConfig.codeDir, '*'), trialLocalTempFolder); await execCopydir(this.trialConfig.codeDir, trialLocalTempFolder);
const installScriptContent : string = CONTAINER_INSTALL_NNI_SHELL_FORMAT; const installScriptContent : string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
// Write NNI installation file to local tmp files // Write NNI installation file to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), installScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), installScriptContent, { encoding: 'utf8' });
......
...@@ -68,6 +68,27 @@ def init_logger(logger_file_path, log_level_name='info'): ...@@ -68,6 +68,27 @@ def init_logger(logger_file_path, log_level_name='info'):
sys.stdout = _LoggerFileWrapper(logger_file) sys.stdout = _LoggerFileWrapper(logger_file)
def init_standalone_logger():
"""
Initialize root logger for standalone mode.
This will set NNI's log level to INFO and print its log to stdout.
"""
fmt = '[%(asctime)s] %(levelname)s (%(name)s) %(message)s'
formatter = logging.Formatter(fmt, _time_format)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
nni_logger = logging.getLogger('nni')
nni_logger.addHandler(handler)
nni_logger.setLevel(logging.INFO)
nni_logger.propagate = False
# Following line does not affect NNI loggers, but without this user's logger won't be able to
# print log even it's level is set to INFO, so we do it for user's convenience.
# If this causes any issue in future, remove it and use `logging.info` instead of
# `logging.getLogger('xxx')` in all examples.
logging.basicConfig()
_multi_thread = False _multi_thread = False
_multi_phase = False _multi_phase = False
......
...@@ -34,7 +34,6 @@ class LevelPruner(Pruner): ...@@ -34,7 +34,6 @@ class LevelPruner(Pruner):
class AGP_Pruner(Pruner): class AGP_Pruner(Pruner):
"""An automated gradual pruning algorithm that prunes the smallest magnitude """An automated gradual pruning algorithm that prunes the smallest magnitude
weights to achieve a preset level of network sparsity. weights to achieve a preset level of network sparsity.
Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
efficacy of pruning for model compression", 2017 NIPS Workshop on Machine efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
Learning of Phones and other Consumer Devices, Learning of Phones and other Consumer Devices,
...@@ -178,17 +177,13 @@ class FPGMPruner(Pruner): ...@@ -178,17 +177,13 @@ class FPGMPruner(Pruner):
assert len(weight.shape) >= 3 assert len(weight.shape) >= 3
assert weight.shape[0] * weight.shape[1] > 2 assert weight.shape[0] * weight.shape[1] > 2
dist_list, idx_list = [], [] dist_list = []
for in_i in range(weight.shape[0]): for in_i in range(weight.shape[0]):
for out_i in range(weight.shape[1]): for out_i in range(weight.shape[1]):
dist_sum = self._get_distance_sum(weight, in_i, out_i) dist_sum = self._get_distance_sum(weight, in_i, out_i)
dist_list.append(dist_sum) dist_list.append((dist_sum, (in_i, out_i)))
idx_list.append([in_i, out_i]) min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:n]
dist_tensor = tf.convert_to_tensor(dist_list) return [x[1] for x in min_gm_kernels]
idx_tensor = tf.constant(idx_list)
_, idx = tf.math.top_k(dist_tensor, k=n)
return tf.gather(idx_tensor, idx)
def _get_distance_sum(self, weight, in_idx, out_idx): def _get_distance_sum(self, weight, in_idx, out_idx):
w = tf.reshape(weight, (-1, weight.shape[-2], weight.shape[-1])) w = tf.reshape(weight, (-1, weight.shape[-2], weight.shape[-1]))
......
from .compressor import LayerInfo, Compressor, Pruner, Quantizer from .compressor import LayerInfo, Compressor, Pruner, Quantizer
from .builtin_pruners import * from .builtin_pruners import *
from .builtin_quantizers import * from .builtin_quantizers import *
from .lottery_ticket import LotteryTicketPruner
...@@ -2,24 +2,44 @@ import logging ...@@ -2,24 +2,44 @@ import logging
import torch import torch
from .compressor import Pruner from .compressor import Pruner
__all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner'] __all__ = ['LevelPruner', 'AGP_Pruner', 'FPGMPruner', 'L1FilterPruner', 'SlimPruner']
logger = logging.getLogger('torch pruner') logger = logging.getLogger('torch pruner')
class LevelPruner(Pruner): class LevelPruner(Pruner):
"""Prune to an exact pruning level specification """
Prune to an exact pruning level specification
""" """
def __init__(self, model, config_list): def __init__(self, model, config_list):
""" """
config_list: supported keys: Parameters
- sparsity ----------
model : torch.nn.module
Model to be pruned
config_list : list
List on pruning configs
""" """
super().__init__(model, config_list) super().__init__(model, config_list)
self.if_init_list = {} self.if_init_list = {}
def calc_mask(self, layer, config): def calc_mask(self, layer, config):
"""
Calculate the mask of given layer
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
Returns
-------
torch.Tensor
mask of the layer's weight
"""
weight = layer.module.weight.data weight = layer.module.weight.data
op_name = layer.name op_name = layer.name
if self.if_init_list.get(op_name, True): if self.if_init_list.get(op_name, True):
...@@ -37,9 +57,9 @@ class LevelPruner(Pruner): ...@@ -37,9 +57,9 @@ class LevelPruner(Pruner):
class AGP_Pruner(Pruner): class AGP_Pruner(Pruner):
"""An automated gradual pruning algorithm that prunes the smallest magnitude """
An automated gradual pruning algorithm that prunes the smallest magnitude
weights to achieve a preset level of network sparsity. weights to achieve a preset level of network sparsity.
Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
efficacy of pruning for model compression", 2017 NIPS Workshop on Machine efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
Learning of Phones and other Consumer Devices, Learning of Phones and other Consumer Devices,
...@@ -48,24 +68,39 @@ class AGP_Pruner(Pruner): ...@@ -48,24 +68,39 @@ class AGP_Pruner(Pruner):
def __init__(self, model, config_list): def __init__(self, model, config_list):
""" """
config_list: supported keys: Parameters
- initial_sparsity ----------
- final_sparsity: you should make sure initial_sparsity <= final_sparsity model : torch.nn.module
- start_epoch: start epoch number begin update mask Model to be pruned
- end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch config_list : list
- frequency: if you want update every 2 epoch, you can set it 2 List on pruning configs
""" """
super().__init__(model, config_list) super().__init__(model, config_list)
self.now_epoch = 0 self.now_epoch = 0
self.if_init_list = {} self.if_init_list = {}
def calc_mask(self, layer, config): def calc_mask(self, layer, config):
"""
Calculate the mask of given layer
Parameters
----------
layer : LayerInfo
the layer to instrument the compression operation
config : dict
layer's pruning config
Returns
-------
torch.Tensor
mask of the layer's weight
"""
weight = layer.module.weight.data weight = layer.module.weight.data
op_name = layer.name op_name = layer.name
start_epoch = config.get('start_epoch', 0) start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1) freq = config.get('frequency', 1)
if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) and ( if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
self.now_epoch - start_epoch) % freq == 0: and (self.now_epoch - start_epoch) % freq == 0:
mask = self.mask_dict.get(op_name, torch.ones(weight.shape).type_as(weight)) mask = self.mask_dict.get(op_name, torch.ones(weight.shape).type_as(weight))
target_sparsity = self.compute_target_sparsity(config) target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity) k = int(weight.numel() * target_sparsity)
...@@ -82,6 +117,18 @@ class AGP_Pruner(Pruner): ...@@ -82,6 +117,18 @@ class AGP_Pruner(Pruner):
return new_mask return new_mask
def compute_target_sparsity(self, config): def compute_target_sparsity(self, config):
"""
Calculate the sparsity for pruning
Parameters
----------
config : dict
Layer's pruning config
Returns
-------
float
Target sparsity to be pruned
"""
end_epoch = config.get('end_epoch', 1) end_epoch = config.get('end_epoch', 1)
start_epoch = config.get('start_epoch', 0) start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1) freq = config.get('frequency', 1)
...@@ -102,11 +149,20 @@ class AGP_Pruner(Pruner): ...@@ -102,11 +149,20 @@ class AGP_Pruner(Pruner):
return target_sparsity return target_sparsity
def update_epoch(self, epoch): def update_epoch(self, epoch):
"""
Update epoch
Parameters
----------
epoch : int
current training epoch
"""
if epoch > 0: if epoch > 0:
self.now_epoch = epoch self.now_epoch = epoch
for k in self.if_init_list: for k in self.if_init_list.keys():
self.if_init_list[k] = True self.if_init_list[k] = True
class FPGMPruner(Pruner): class FPGMPruner(Pruner):
""" """
A filter pruner via geometric median. A filter pruner via geometric median.
...@@ -135,13 +191,11 @@ class FPGMPruner(Pruner): ...@@ -135,13 +191,11 @@ class FPGMPruner(Pruner):
OUT: number of output channel OUT: number of output channel
IN: number of input channel IN: number of input channel
LEN: filter length LEN: filter length
filter dimensions for Conv2d: filter dimensions for Conv2d:
OUT: number of output channel OUT: number of output channel
IN: number of input channel IN: number of input channel
H: filter height H: filter height
W: filter width W: filter width
Parameters Parameters
---------- ----------
layer : LayerInfo layer : LayerInfo
...@@ -196,7 +250,6 @@ class FPGMPruner(Pruner): ...@@ -196,7 +250,6 @@ class FPGMPruner(Pruner):
for k in w: for k in w:
dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2) dist_sum += torch.dist(k, weight[in_idx, out_idx], p=2)
return dist_sum return dist_sum
Parameters Parameters
---------- ----------
weight: Tensor weight: Tensor
...@@ -206,25 +259,151 @@ class FPGMPruner(Pruner): ...@@ -206,25 +259,151 @@ class FPGMPruner(Pruner):
between this specified filter and all other filters. between this specified filter and all other filters.
in_idx: int in_idx: int
input channel index of specified filter input channel index of specified filter
Returns Returns
------- -------
float32 float32
The total distance The total distance
""" """
logger.debug('weight size: %s', weight.size()) logger.debug('weight size: %s', weight.size())
if len(weight.size()) == 4: # Conv2d if len(weight.size()) == 4: # Conv2d
w = weight.view(-1, weight.size(-2), weight.size(-1)) w = weight.view(-1, weight.size(-2), weight.size(-1))
anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1), w.size(2)) anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1), w.size(2))
elif len(weight.size()) == 3: # Conv1d elif len(weight.size()) == 3: # Conv1d
w = weight.view(-1, weight.size(-1)) w = weight.view(-1, weight.size(-1))
anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1)) anchor_w = weight[out_idx, in_idx].unsqueeze(0).expand(w.size(0), w.size(1))
else: else:
raise RuntimeError('unsupported layer type') raise RuntimeError('unsupported layer type')
x = w - anchor_w x = w - anchor_w
x = (x*x).sum((-2, -1)) x = (x * x).sum((-2, -1))
x = torch.sqrt(x) x = torch.sqrt(x)
return x.sum() return x.sum()
def update_epoch(self, epoch): def update_epoch(self, epoch):
self.epoch_pruned_layers = set() self.epoch_pruned_layers = set()
class L1FilterPruner(Pruner):
    """
    A structured pruning algorithm that prunes the filters of smallest magnitude
    weights sum in the convolution layers to achieve a preset level of network sparsity.

    Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet and Hans Peter Graf,
    "PRUNING FILTERS FOR EFFICIENT CONVNETS", 2017 ICLR
    https://arxiv.org/abs/1608.08710
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be pruned
        config_list : list
            support key for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        """
        super().__init__(model, config_list)
        # ops whose mask has already been computed; the mask is frozen afterwards
        self.mask_calculated_ops = set()

    def calc_mask(self, layer, config):
        """
        Calculate the mask of given layer.
        Filters with the smallest sum of its absolute kernel weights are masked.

        Parameters
        ----------
        layer : LayerInfo
            the layer to instrument the compression operation
        config : dict
            layer's pruning config

        Returns
        -------
        torch.Tensor
            mask of the layer's weight
        """
        weight = layer.module.weight.data
        op_name = layer.name
        op_type = layer.type
        assert op_type == 'Conv2d', 'L1FilterPruner only supports 2d convolution layer pruning'
        if op_name in self.mask_calculated_ops:
            # mask is computed once per op and reused on every later call
            assert op_name in self.mask_dict
            return self.mask_dict.get(op_name)
        mask = torch.ones(weight.size()).type_as(weight)
        try:
            filters = weight.shape[0]
            w_abs = weight.abs()
            k = int(filters * config['sparsity'])
            if k > 0:
                # L1 norm (sum of absolute kernel weights) per filter; prune the k smallest
                w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
                threshold = torch.topk(w_abs_structured.view(-1), k, largest=False).values.max()
                mask = torch.gt(w_abs_structured, threshold)[:, None, None, None].expand_as(weight).type_as(weight)
        finally:
            # record the mask even on failure so the op is frozen rather than
            # retried in a half-computed state; the recorded tensor is the same
            # object that is returned (the original returned a distinct fresh
            # all-ones tensor on the k == 0 path)
            self.mask_dict.update({layer.name: mask})
            self.mask_calculated_ops.add(layer.name)
        return mask
class SlimPruner(Pruner):
    """
    A structured pruning algorithm that prunes channels by pruning the weights of BN layers.

    Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
    "Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
    https://arxiv.org/pdf/1708.06519.pdf
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be pruned
        config_list : list
            support key for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        """
        super().__init__(model, config_list)
        self.mask_calculated_ops = set()
        weight_list = []
        if len(config_list) > 1:
            logger.warning('Slim pruner only supports 1 configuration')
        config = config_list[0]
        # NOTE: the loop variable must not be named `config` — it previously
        # shadowed the binding above, so the sparsity used below came from the
        # last compressed layer's config instead of config_list[0]
        for (layer, _) in self.detect_modules_to_compress():
            assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
            weight_list.append(layer.module.weight.data.clone())
        # one global threshold over the scale factors of every BN layer
        all_bn_weights = torch.cat(weight_list)
        k = int(all_bn_weights.shape[0] * config['sparsity'])
        if k == 0:
            # nothing to prune; torch.topk(..., 0).values.max() would raise on
            # an empty tensor, so pick a threshold below any absolute value
            self.global_threshold = -1.0
        else:
            self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False).values.max()

    def calc_mask(self, layer, config):
        """
        Calculate the mask of given layer.
        Scale factors with the smallest absolute value in the BN layer are masked.

        Parameters
        ----------
        layer : LayerInfo
            the layer to instrument the compression operation
        config : dict
            layer's pruning config

        Returns
        -------
        torch.Tensor
            mask of the layer's weight
        """
        weight = layer.module.weight.data
        op_name = layer.name
        op_type = layer.type
        assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
        if op_name in self.mask_calculated_ops:
            # mask is computed once per op and reused on every later call
            assert op_name in self.mask_dict
            return self.mask_dict.get(op_name)
        mask = torch.ones(weight.size()).type_as(weight)
        try:
            w_abs = weight.abs()
            mask = torch.gt(w_abs, self.global_threshold).type_as(weight)
        finally:
            # record the mask even on failure so the op is frozen rather than retried
            self.mask_dict.update({layer.name: mask})
            self.mask_calculated_ops.add(layer.name)
        return mask
...@@ -13,7 +13,6 @@ class LayerInfo: ...@@ -13,7 +13,6 @@ class LayerInfo:
self._forward = None self._forward = None
class Compressor: class Compressor:
""" """
Abstract base PyTorch compressor Abstract base PyTorch compressor
...@@ -37,7 +36,6 @@ class Compressor: ...@@ -37,7 +36,6 @@ class Compressor:
def detect_modules_to_compress(self): def detect_modules_to_compress(self):
""" """
detect all modules should be compressed, and save the result in `self.modules_to_compress`. detect all modules should be compressed, and save the result in `self.modules_to_compress`.
The model will be instrumented and user should never edit it after calling this method. The model will be instrumented and user should never edit it after calling this method.
""" """
if self.modules_to_compress is None: if self.modules_to_compress is None:
...@@ -49,7 +47,6 @@ class Compressor: ...@@ -49,7 +47,6 @@ class Compressor:
self.modules_to_compress.append((layer, config)) self.modules_to_compress.append((layer, config))
return self.modules_to_compress return self.modules_to_compress
def compress(self): def compress(self):
""" """
Compress the model with algorithm implemented by subclass. Compress the model with algorithm implemented by subclass.
...@@ -218,6 +215,8 @@ class Pruner(Compressor): ...@@ -218,6 +215,8 @@ class Pruner(Compressor):
input_shape : list or tuple input_shape : list or tuple
input shape to onnx model input shape to onnx model
""" """
if self.detect_modules_to_compress() and not self.mask_dict:
_logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
assert model_path is not None, 'model_path must be specified' assert model_path is not None, 'model_path must be specified'
for name, m in self.bound_model.named_modules(): for name, m in self.bound_model.named_modules():
if name == "": if name == "":
...@@ -227,25 +226,20 @@ class Pruner(Compressor): ...@@ -227,25 +226,20 @@ class Pruner(Compressor):
mask_sum = mask.sum().item() mask_sum = mask.sum().item()
mask_num = mask.numel() mask_num = mask.numel()
_logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num) _logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num)
print('Layer: %s Sparsity: %.2f' % (name, 1 - mask_sum / mask_num))
m.weight.data = m.weight.data.mul(mask) m.weight.data = m.weight.data.mul(mask)
else: else:
_logger.info('Layer: %s NOT compressed', name) _logger.info('Layer: %s NOT compressed', name)
print('Layer: %s NOT compressed' % name)
torch.save(self.bound_model.state_dict(), model_path) torch.save(self.bound_model.state_dict(), model_path)
_logger.info('Model state_dict saved to %s', model_path) _logger.info('Model state_dict saved to %s', model_path)
print('Model state_dict saved to %s' % model_path)
if mask_path is not None: if mask_path is not None:
torch.save(self.mask_dict, mask_path) torch.save(self.mask_dict, mask_path)
_logger.info('Mask dict saved to %s', mask_path) _logger.info('Mask dict saved to %s', mask_path)
print('Mask dict saved to %s' % mask_path)
if onnx_path is not None: if onnx_path is not None:
assert input_shape is not None, 'input_shape must be specified to export onnx model' assert input_shape is not None, 'input_shape must be specified to export onnx model'
# input info needed # input info needed
input_data = torch.Tensor(*input_shape) input_data = torch.Tensor(*input_shape)
torch.onnx.export(self.bound_model, input_data, onnx_path) torch.onnx.export(self.bound_model, input_data, onnx_path)
_logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path) _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
print('Model in onnx with input shape %s saved to %s' % (input_data.shape, onnx_path))
class Quantizer(Compressor): class Quantizer(Compressor):
......
import copy
import logging
import torch
from .compressor import Pruner
_logger = logging.getLogger(__name__)
class LotteryTicketPruner(Pruner):
    """
    This is a Pytorch implementation of the paper "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks",
    following NNI model compression interface.

    1. Randomly initialize a neural network f(x;theta_0) (where theta_0 follows D_{theta}).
    2. Train the network for j iterations, arriving at parameters theta_j.
    3. Prune p% of the parameters in theta_j, creating a mask m.
    4. Reset the remaining parameters to their values in theta_0, creating the winning ticket f(x;m*theta_0).
    5. Repeat step 2, 3, and 4.
    """

    def __init__(self, model, config_list, optimizer, lr_scheduler=None, reset_weights=True):
        """
        Parameters
        ----------
        model : pytorch model
            The model to be pruned
        config_list : list
            Supported keys:
                - prune_iterations : The number of rounds for the iterative pruning.
                - sparsity : The final sparsity when the compression is done.
        optimizer : pytorch optimizer
            The optimizer for the model
        lr_scheduler : pytorch lr scheduler
            The lr scheduler for the model if used
        reset_weights : bool
            Whether reset weights and optimizer at the beginning of each round.
        """
        super().__init__(model, config_list)
        self.curr_prune_iteration = None
        self.prune_iterations = self._validate_config(config_list)

        # snapshot the initial weights / optimizer / scheduler state so each
        # round can be rewound to theta_0 (step 4 of the algorithm)
        self.reset_weights = reset_weights
        if self.reset_weights:
            self._model = model
            self._optimizer = optimizer
            self._model_state = copy.deepcopy(model.state_dict())
            self._optimizer_state = copy.deepcopy(optimizer.state_dict())
            self._lr_scheduler = lr_scheduler
            if lr_scheduler is not None:
                self._scheduler_state = copy.deepcopy(lr_scheduler.state_dict())

    def _validate_config(self, config_list):
        """
        Check that every config entry carries `prune_iterations` and `sparsity`
        and that all entries agree on `prune_iterations`.

        Returns
        -------
        int
            The shared number of pruning rounds
        """
        prune_iterations = None
        for config in config_list:
            assert 'prune_iterations' in config, 'prune_iterations must exist in your config'
            assert 'sparsity' in config, 'sparsity must exist in your config'
            if prune_iterations is not None:
                assert prune_iterations == config['prune_iterations'], 'The values of prune_iterations must be equal in your config'
            prune_iterations = config['prune_iterations']
        return prune_iterations

    def _print_masks(self, print_mask=False):
        """Print per-op sparsity (and optionally the full mask) for debugging."""
        torch.set_printoptions(threshold=1000)
        for op_name in self.mask_dict.keys():
            mask = self.mask_dict[op_name]
            print('op name: ', op_name)
            if print_mask:
                print('mask: ', mask)
            # current sparsity = fraction of zeroed entries in the mask
            mask_num = mask.sum().item()
            mask_size = mask.numel()
            print('sparsity: ', 1 - mask_num / mask_size)
        torch.set_printoptions(profile='default')

    def _calc_sparsity(self, sparsity):
        """
        Sparsity to reach at the current round, derived from the final target:
        keep ratio decays geometrically so the final round lands on `sparsity`.
        """
        keep_ratio_once = (1 - sparsity) ** (1 / self.prune_iterations)
        curr_keep_ratio = keep_ratio_once ** self.curr_prune_iteration
        return max(1 - curr_keep_ratio, 0)

    def _calc_mask(self, weight, sparsity, op_name):
        """
        Compute the mask for one weight tensor at the current prune round:
        all-ones in round 0, otherwise mask the k smallest |weight| values
        among the entries still alive in the previous round's mask.
        """
        if self.curr_prune_iteration == 0:
            mask = torch.ones(weight.shape).type_as(weight)
        else:
            curr_sparsity = self._calc_sparsity(sparsity)
            assert self.mask_dict.get(op_name) is not None
            curr_mask = self.mask_dict.get(op_name)
            w_abs = weight.abs() * curr_mask
            k = int(w_abs.numel() * curr_sparsity)
            if k == 0:
                # nothing more to prune this round; keep the current mask
                # (torch.topk(..., 0).values.max() would raise on an empty tensor)
                mask = curr_mask
            else:
                threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
                mask = torch.gt(w_abs, threshold).type_as(weight)
        return mask

    def calc_mask(self, layer, config):
        """
        Generate mask for the given ``weight``.

        Parameters
        ----------
        layer : LayerInfo
            The layer to be pruned
        config : dict
            Pruning configurations for this weight

        Returns
        -------
        tensor
            The mask for this weight
        """
        assert self.mask_dict.get(layer.name) is not None, 'Please call iteration_start before training'
        mask = self.mask_dict[layer.name]
        return mask

    def get_prune_iterations(self):
        """
        Return the range for iterations.
        In the first prune iteration, masks are all one, thus, add one more iteration

        Returns
        -------
        list
            A list for pruning iterations
        """
        return range(self.prune_iterations + 1)

    def prune_iteration_start(self):
        """
        Control the pruning procedure on updated epoch number.
        Should be called at the beginning of the epoch.
        """
        if self.curr_prune_iteration is None:
            self.curr_prune_iteration = 0
        else:
            self.curr_prune_iteration += 1
        assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceed the configured prune_iterations'

        modules_to_compress = self.detect_modules_to_compress()
        for layer, config in modules_to_compress:
            sparsity = config.get('sparsity')
            mask = self._calc_mask(layer.module.weight.data, sparsity, layer.name)
            self.mask_dict.update({layer.name: mask})
        self._print_masks()

        # reinit weights back to original after new masks are generated
        if self.reset_weights:
            self._model.load_state_dict(self._model_state)
            self._optimizer.load_state_dict(self._optimizer_state)
            if self._lr_scheduler is not None:
                self._lr_scheduler.load_state_dict(self._scheduler_state)
...@@ -136,7 +136,6 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -136,7 +136,6 @@ class MsgDispatcher(MsgDispatcherBase):
# data: parameters # data: parameters
id_ = _create_parameter_id() id_ = _create_parameter_id()
_customized_parameter_ids.add(id_) _customized_parameter_ids.add(id_)
send(CommandType.NewTrialJob, _pack_parameter(id_, data, customized=True))
def handle_report_metric_data(self, data): def handle_report_metric_data(self, data):
""" """
...@@ -185,7 +184,7 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -185,7 +184,7 @@ class MsgDispatcher(MsgDispatcherBase):
""" """
id_ = data['parameter_id'] id_ = data['parameter_id']
value = data['value'] value = data['value']
if id_ in _customized_parameter_ids: if not id_ or id_ in _customized_parameter_ids:
if not hasattr(self.tuner, '_accept_customized'): if not hasattr(self.tuner, '_accept_customized'):
self.tuner._accept_customized = False self.tuner._accept_customized = False
if not self.tuner._accept_customized: if not self.tuner._accept_customized:
...@@ -194,8 +193,8 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -194,8 +193,8 @@ class MsgDispatcher(MsgDispatcherBase):
customized = True customized = True
else: else:
customized = False customized = False
self.tuner.receive_trial_result(id_, _trial_params[id_], value, customized=customized, self.tuner.receive_trial_result(id_, _trial_params[id_], value, customized=customized,
trial_job_id=data.get('trial_job_id')) trial_job_id=data.get('trial_job_id'))
def _handle_intermediate_metric_data(self, data): def _handle_intermediate_metric_data(self, data):
"""Call assessor to process intermediate results """Call assessor to process intermediate results
......
...@@ -19,11 +19,29 @@ ...@@ -19,11 +19,29 @@
# ================================================================================================== # ==================================================================================================
import logging
import json_tricks import json_tricks
from ..common import init_standalone_logger
__all__ = [
'get_next_parameter',
'get_experiment_id',
'get_trial_id',
'get_sequence_id',
'send_metric',
]
init_standalone_logger()
_logger = logging.getLogger('nni')
def get_next_parameter(): def get_next_parameter():
pass _logger.warning('Requesting parameter without NNI framework, returning empty dict')
return {
'parameter_id': None,
'parameters': {}
}
def get_experiment_id(): def get_experiment_id():
pass pass
...@@ -37,6 +55,8 @@ def get_sequence_id(): ...@@ -37,6 +55,8 @@ def get_sequence_id():
def send_metric(string): def send_metric(string):
metric = json_tricks.loads(string) metric = json_tricks.loads(string)
if metric['type'] == 'FINAL': if metric['type'] == 'FINAL':
print('Final result:', metric['value']) _logger.info('Final result: %s', metric['value'])
elif metric['type'] == 'PERIODICAL': elif metric['type'] == 'PERIODICAL':
print('Intermediate result:', metric['value']) _logger.info('Intermediate result: %s (Index %s)', metric['value'], metric['sequence'])
else:
_logger.error('Unexpected metric: %s', string)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment