Unverified Commit c785655e authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #207 from microsoft/master

merge master
parents 9fae194a d6b61e2f
......@@ -25,9 +25,9 @@ import * as path from 'path';
import * as component from '../common/component';
import { DataStore, MetricDataRecord, TrialJobInfo } from '../common/datastore';
import { NNIError, NNIErrorNames } from '../common/errors';
import { isNewExperiment } from '../common/experimentStartupInfo';
import { isNewExperiment, isReadonly } from '../common/experimentStartupInfo';
import { getLogger, Logger } from '../common/log';
import { ExperimentProfile, Manager, TrialJobStatistics} from '../common/manager';
import { ExperimentProfile, Manager, TrialJobStatistics, ExperimentStartUpMode } from '../common/manager';
import { ValidationSchemas } from './restValidationSchemas';
import { NNIRestServer } from './nniRestServer';
import { getVersion } from '../common/utils';
......@@ -72,6 +72,8 @@ class NNIRestHandler {
this.addTrialJob(router);
this.cancelTrialJob(router);
this.getMetricData(router);
this.getMetricDataByRange(router);
this.getLatestMetricData(router);
this.exportData(router);
// Express-joi-validator configuration
......@@ -86,11 +88,11 @@ class NNIRestHandler {
return router;
}
private handle_error(err: Error, res: Response, isFatal: boolean = false): void {
private handle_error(err: Error, res: Response, isFatal: boolean = false, errorCode: number = 500): void {
if (err instanceof NNIError && err.name === NNIErrorNames.NOT_FOUND) {
res.status(404);
} else {
res.status(500);
res.status(errorCode);
}
res.send({
error: err.message
......@@ -169,7 +171,7 @@ class NNIRestHandler {
this.handle_error(err, res);
});
} else {
this.nniManager.resumeExperiment().then(() => {
this.nniManager.resumeExperiment(isReadonly()).then(() => {
res.send();
}).catch((err: Error) => {
// Resume experiment is a step of initialization, so any exception thrown is a fatal
......@@ -262,6 +264,28 @@ class NNIRestHandler {
});
}
private getMetricDataByRange(router: Router): void {
router.get('/metric-data-range/:min_seq_id/:max_seq_id', async (req: Request, res: Response) => {
const minSeqId = Number(req.params.min_seq_id);
const maxSeqId = Number(req.params.max_seq_id);
this.nniManager.getMetricDataByRange(minSeqId, maxSeqId).then((metricsData: MetricDataRecord[]) => {
res.send(metricsData);
}).catch((err: Error) => {
this.handle_error(err, res);
});
});
}
private getLatestMetricData(router: Router): void {
router.get('/metric-data-latest/', async (req: Request, res: Response) => {
this.nniManager.getLatestMetricData().then((metricsData: MetricDataRecord[]) => {
res.send(metricsData);
}).catch((err: Error) => {
this.handle_error(err, res);
});
});
}
private exportData(router: Router): void {
router.get('/export-data', (req: Request, res: Response) => {
this.nniManager.exportData().then((exportedData: string) => {
......
......@@ -170,18 +170,18 @@ export namespace ValidationSchemas {
classFileName: joi.string(),
className: joi.string(),
classArgs: joi.any(),
gpuNum: joi.number().min(0),
checkpointDir: joi.string().allow('')
checkpointDir: joi.string().allow(''),
gpuIndices: joi.string()
}),
tuner: joi.object({
builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner'),
builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner', 'PPOTuner'),
codeDir: joi.string(),
classFileName: joi.string(),
className: joi.string(),
classArgs: joi.any(),
gpuNum: joi.number().min(0),
checkpointDir: joi.string().allow(''),
includeIntermediateResults: joi.boolean()
includeIntermediateResults: joi.boolean(),
gpuIndices: joi.string()
}),
assessor: joi.object({
builtinAssessorName: joi.string().valid('Medianstop', 'Curvefitting'),
......@@ -189,7 +189,6 @@ export namespace ValidationSchemas {
classFileName: joi.string(),
className: joi.string(),
classArgs: joi.any(),
gpuNum: joi.number().min(0),
checkpointDir: joi.string().allow('')
}),
clusterMetaData: joi.array().items(joi.object({
......@@ -210,7 +209,7 @@ export namespace ValidationSchemas {
startTime: joi.number(),
endTime: joi.number(),
logDir: joi.string(),
maxSequenceId: joi.number()
nextSequenceId: joi.number()
}
};
}
......@@ -85,9 +85,9 @@ export class MockedNNIManager extends Manager {
// tslint:disable-next-line:no-http-string
url: 'http://test',
workingDirectory: '/tmp/mocked',
sequenceId: 0,
form: {
jobType: 'TRIAL'
sequenceId: 0,
hyperParameters: { value: '', index: 0 }
}
};
deferred.resolve(jobDetail);
......@@ -129,6 +129,12 @@ export class MockedNNIManager extends Manager {
public getMetricData(trialJobId: string, metricType: MetricType): Promise<MetricDataRecord[]> {
throw new MethodNotImplementedError();
}
public getMetricDataByRange(minSeqId: number, maxSeqId: number): Promise<MetricDataRecord[]> {
throw new MethodNotImplementedError();
}
public getLatestMetricData(): Promise<MetricDataRecord[]> {
throw new MethodNotImplementedError();
}
public getExperimentProfile(): Promise<ExperimentProfile> {
const profile: ExperimentProfile = {
params: {
......@@ -148,7 +154,7 @@ export class MockedNNIManager extends Manager {
execDuration: 0,
startTime: Date.now(),
endTime: Date.now(),
maxSequenceId: 0,
nextSequenceId: 0,
revision: 0
};
......
......@@ -156,7 +156,7 @@ export async function execRemove(directory: string): Promise<void> {
*/
export async function execKill(pid: string): Promise<void> {
if (process.platform === 'win32') {
await cpp.exec(`cmd /c taskkill /PID ${pid} /T /F`);
await cpp.exec(`cmd.exe /c taskkill /PID ${pid} /T /F`);
} else {
await cpp.exec(`pkill -P ${pid}`);
}
......
......@@ -25,7 +25,7 @@ import * as path from 'path';
import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import {
JobApplicationForm, NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
} from '../../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, uniqueString } from '../../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../../common/containerJobData';
......@@ -55,7 +55,6 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
super();
this.fcJobInfoCollector = new FrameworkControllerJobInfoCollector(this.trialJobsMap);
this.experimentId = getExperimentId();
this.nextTrialSequenceId = -1;
}
public async run(): Promise<void> {
......@@ -77,7 +76,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
}
}
public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.fcClusterConfig === undefined) {
throw new Error('frameworkcontrollerClusterConfig is not initialized');
}
......@@ -91,14 +90,13 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
}
const trialJobId: string = uniqueString(5);
const curTrialSequenceId: number = this.generateSequenceId();
// Set trial's NFS working folder
const trialWorkingFolder: string = path.join(this.CONTAINER_MOUNT_PATH, 'nni', getExperimentId(), trialJobId);
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
const frameworkcontrollerJobName: string = `nniexp${this.experimentId}trial${trialJobId}`.toLowerCase();
//Generate the port used for taskRole
this.generateContainerPort();
await this.prepareRunScript(trialLocalTempFolder, curTrialSequenceId, trialJobId, trialWorkingFolder, form);
await this.prepareRunScript(trialLocalTempFolder, trialJobId, trialWorkingFolder, form);
//upload code files
const trialJobOutputUrl: string = await this.uploadCodeFiles(trialJobId, trialLocalTempFolder);
......@@ -113,7 +111,6 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
trialWorkingFolder,
form,
frameworkcontrollerJobName,
curTrialSequenceId,
trialJobOutputUrl
);
......@@ -248,8 +245,8 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
return `${portScript} . /mnt/frameworkbarrier/injector.sh && ${command}`;
}
private async prepareRunScript(trialLocalTempFolder: string, curTrialSequenceId: number, trialJobId: string,
trialWorkingFolder: string, form: JobApplicationForm): Promise<void> {
private async prepareRunScript(trialLocalTempFolder: string, trialJobId: string,
trialWorkingFolder: string, form: TrialJobApplicationForm): Promise<void> {
if (this.fcTrialConfig === undefined) {
throw new Error('frameworkcontroller trial config is not initialized');
}
......@@ -264,16 +261,16 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
for (const taskRole of this.fcTrialConfig.taskRoles) {
const runScriptContent: string =
await this.generateRunScript('frameworkcontroller', trialJobId, trialWorkingFolder,
this.generateCommandScript(taskRole.command), curTrialSequenceId.toString(),
this.generateCommandScript(taskRole.command), form.sequenceId.toString(),
taskRole.name, taskRole.gpuNum);
await fs.promises.writeFile(path.join(trialLocalTempFolder, `run_${taskRole.name}.sh`), runScriptContent, { encoding: 'utf8' });
}
// Write file content ( parameter.cfg ) to local tmp folders
const trialForm : TrialJobApplicationForm = (<TrialJobApplicationForm>form);
if (trialForm !== undefined && trialForm.hyperParameters !== undefined) {
await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(trialForm.hyperParameters)),
trialForm.hyperParameters.value, { encoding: 'utf8' });
if (form !== undefined) {
await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(form.hyperParameters)),
form.hyperParameters.value, { encoding: 'utf8' });
}
}
......
......@@ -27,7 +27,7 @@ import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import {
JobApplicationForm, NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
} from '../../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, uniqueString } from '../../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../../common/containerJobData';
......@@ -59,7 +59,6 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
super();
this.kubeflowJobInfoCollector = new KubeflowJobInfoCollector(this.trialJobsMap);
this.experimentId = getExperimentId();
this.nextTrialSequenceId = -1;
this.log.info('Construct Kubeflow training service.');
}
......@@ -84,7 +83,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
this.log.info('Kubeflow training service exit.');
}
public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.kubernetesCRDClient === undefined) {
throw new Error('Kubeflow job operator client is undefined');
}
......@@ -96,10 +95,9 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
const trialJobId: string = uniqueString(5);
const trialWorkingFolder: string = path.join(this.CONTAINER_MOUNT_PATH, 'nni', getExperimentId(), trialJobId);
const kubeflowJobName: string = `nni-exp-${this.experimentId}-trial-${trialJobId}`.toLowerCase();
const curTrialSequenceId: number = this.generateSequenceId();
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
//prepare the runscript
await this.prepareRunScript(trialLocalTempFolder, trialJobId, trialWorkingFolder, curTrialSequenceId, form);
await this.prepareRunScript(trialLocalTempFolder, trialJobId, trialWorkingFolder, form);
//upload files to sotrage
const trialJobOutputUrl: string = await this.uploadCodeFiles(trialJobId, trialLocalTempFolder);
let initStatus: TrialJobStatus = 'WAITING';
......@@ -113,7 +111,6 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
trialWorkingFolder,
form,
kubeflowJobName,
curTrialSequenceId,
trialJobOutputUrl
);
......@@ -236,8 +233,8 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
return Promise.resolve(trialJobOutputUrl);
}
private async prepareRunScript(trialLocalTempFolder: string, trialJobId: string, trialWorkingFolder: string, curTrialSequenceId: number,
form: JobApplicationForm): Promise<void> {
private async prepareRunScript(trialLocalTempFolder: string, trialJobId: string, trialWorkingFolder: string,
form: TrialJobApplicationForm): Promise<void> {
if (this.kubeflowClusterConfig === undefined) {
throw new Error('Kubeflow Cluster config is not initialized');
}
......@@ -262,7 +259,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
if (kubeflowTrialConfig.worker !== undefined) {
const workerRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder,
kubeflowTrialConfig.worker.command,
curTrialSequenceId.toString(), 'worker',
form.sequenceId.toString(), 'worker',
kubeflowTrialConfig.worker.gpuNum);
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_worker.sh'), workerRunScriptContent, { encoding: 'utf8' });
}
......@@ -272,7 +269,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
if (tensorflowTrialConfig.ps !== undefined) {
const psRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder,
tensorflowTrialConfig.ps.command,
curTrialSequenceId.toString(),
form.sequenceId.toString(),
'ps', tensorflowTrialConfig.ps.gpuNum);
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_ps.sh'), psRunScriptContent, { encoding: 'utf8' });
}
......@@ -281,16 +278,15 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
if (pytorchTrialConfig.master !== undefined) {
const masterRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder,
pytorchTrialConfig.master.command,
curTrialSequenceId.toString(), 'master',
form.sequenceId.toString(), 'master',
pytorchTrialConfig.master.gpuNum);
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_master.sh'), masterRunScriptContent, { encoding: 'utf8' });
}
}
// Write file content ( parameter.cfg ) to local tmp folders
const trialForm : TrialJobApplicationForm = (<TrialJobApplicationForm>form);
if (trialForm !== undefined && trialForm.hyperParameters !== undefined) {
await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(trialForm.hyperParameters)),
trialForm.hyperParameters.value, { encoding: 'utf8' });
if (form !== undefined) {
await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(form.hyperParameters)),
form.hyperParameters.value, { encoding: 'utf8' });
}
}
......
......@@ -19,7 +19,7 @@
'use strict';
import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
/**
* KubeflowTrialJobDetail
......@@ -33,21 +33,19 @@ export class KubernetesTrialJobDetail implements TrialJobDetail {
public tags?: string[];
public url?: string;
public workingDirectory: string;
public form: JobApplicationForm;
public form: TrialJobApplicationForm;
public kubernetesJobName: string;
public sequenceId: number;
public queryJobFailedCount: number;
constructor(id: string, status: TrialJobStatus, submitTime: number,
workingDirectory: string, form: JobApplicationForm,
kubernetesJobName: string, sequenceId: number, url: string) {
workingDirectory: string, form: TrialJobApplicationForm,
kubernetesJobName: string, url: string) {
this.id = id;
this.status = status;
this.submitTime = submitTime;
this.workingDirectory = workingDirectory;
this.form = form;
this.kubernetesJobName = kubernetesJobName;
this.sequenceId = sequenceId;
this.tags = [];
this.queryJobFailedCount = 0;
this.url = url;
......
......@@ -26,7 +26,7 @@ import * as azureStorage from 'azure-storage';
import { EventEmitter } from 'events';
import { Base64 } from 'js-base64';
import { String } from 'typescript-string-operations';
import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo';
import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import {
NNIManagerIpConfig, TrialJobDetail, TrialJobMetric
......@@ -53,7 +53,6 @@ abstract class KubernetesTrainingService {
protected readonly trialLocalNFSTempFolder: string;
protected stopping: boolean = false;
protected experimentId! : string;
protected nextTrialSequenceId: number;
protected kubernetesRestServerPort?: number;
protected readonly CONTAINER_MOUNT_PATH: string;
protected azureStorageClient?: azureStorage.FileService;
......@@ -74,7 +73,6 @@ abstract class KubernetesTrainingService {
this.trialJobsMap = new Map<string, KubernetesTrialJobDetail>();
this.trialLocalNFSTempFolder = path.join(getExperimentRootDir(), 'trials-nfs-tmp');
this.experimentId = getExperimentId();
this.nextTrialSequenceId = -1;
this.CONTAINER_MOUNT_PATH = '/tmp/mount';
this.genericK8sClient = new GeneralK8sClient();
this.logCollection = 'none';
......@@ -93,10 +91,8 @@ abstract class KubernetesTrainingService {
const jobs: TrialJobDetail[] = [];
for (const [key, value] of this.trialJobsMap) {
if (value.form.jobType === 'TRIAL') {
jobs.push(await this.getTrialJob(key));
}
}
return Promise.resolve(jobs);
}
......@@ -222,14 +218,6 @@ abstract class KubernetesTrainingService {
return Promise.resolve();
}
protected generateSequenceId(): number {
if (this.nextTrialSequenceId === -1) {
this.nextTrialSequenceId = getInitTrialSequenceId();
}
return this.nextTrialSequenceId++;
}
// tslint:disable: no-unsafe-any no-any
protected async createAzureStorage(vaultName: string, valutKeyName: string, accountName: string, azureShare: string): Promise<void> {
try {
......
......@@ -26,10 +26,10 @@ import * as path from 'path';
import * as ts from 'tail-stream';
import * as tkill from 'tree-kill';
import { NNIError, NNIErrorNames } from '../../common/errors';
import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo';
import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import {
HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm,
HyperParameters, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../../common/trainingService';
import {
......@@ -76,21 +76,19 @@ class LocalTrialJobDetail implements TrialJobDetail {
public tags?: string[];
public url?: string;
public workingDirectory: string;
public form: JobApplicationForm;
public sequenceId: number;
public form: TrialJobApplicationForm;
public pid?: number;
public gpuIndices?: number[];
constructor(
id: string, status: TrialJobStatus, submitTime: number,
workingDirectory: string, form: JobApplicationForm, sequenceId: number) {
workingDirectory: string, form: TrialJobApplicationForm) {
this.id = id;
this.status = status;
this.submitTime = submitTime;
this.workingDirectory = workingDirectory;
this.form = form;
this.url = `file://localhost:${workingDirectory}`;
this.sequenceId = sequenceId;
this.gpuIndices = [];
}
}
......@@ -125,7 +123,6 @@ class LocalTrainingService implements TrainingService {
private initialized: boolean;
private stopping: boolean;
private rootDir!: string;
private trialSequenceId: number;
private readonly experimentId! : string;
private gpuScheduler!: GPUScheduler;
private readonly occupiedGpuIndexNumMap: Map<number, number>;
......@@ -145,7 +142,6 @@ class LocalTrainingService implements TrainingService {
this.initialized = false;
this.stopping = false;
this.log = getLogger();
this.trialSequenceId = -1;
this.experimentId = getExperimentId();
this.jobStreamMap = new Map<string, ts.Stream>();
this.log.info('Construct local machine training service.');
......@@ -169,10 +165,8 @@ class LocalTrainingService implements TrainingService {
const jobs: TrialJobDetail[] = [];
for (const key of this.jobMap.keys()) {
const trialJob: TrialJobDetail = await this.getTrialJob(key);
if (trialJob.form.jobType === 'TRIAL') {
jobs.push(trialJob);
}
}
return jobs;
}
......@@ -182,9 +176,6 @@ class LocalTrainingService implements TrainingService {
if (trialJob === undefined) {
throw new NNIError(NNIErrorNames.NOT_FOUND, 'Trial job not found');
}
if (trialJob.form.jobType === 'HOST') {
return this.getHostJob(trialJobId);
}
if (trialJob.status === 'RUNNING') {
const alive: boolean = await isAlive(trialJob.pid);
if (!alive) {
......@@ -219,18 +210,14 @@ class LocalTrainingService implements TrainingService {
this.eventEmitter.off('metric', listener);
}
public submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
if (form.jobType === 'HOST') {
return this.runHostJob(<HostJobApplicationForm>form);
} else if (form.jobType === 'TRIAL') {
public submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobId: string = uniqueString(5);
const trialJobDetail: LocalTrialJobDetail = new LocalTrialJobDetail(
trialJobId,
'WAITING',
Date.now(),
path.join(this.rootDir, 'trials', trialJobId),
form,
this.generateSequenceId()
form
);
this.jobQueue.push(trialJobId);
this.jobMap.set(trialJobId, trialJobDetail);
......@@ -238,9 +225,6 @@ class LocalTrainingService implements TrainingService {
this.log.debug(`submitTrialJob: return: ${JSON.stringify(trialJobDetail)} `);
return Promise.resolve(trialJobDetail);
} else {
return Promise.reject(new Error(`Job form not supported: ${JSON.stringify(form)}`));
}
}
/**
......@@ -248,16 +232,12 @@ class LocalTrainingService implements TrainingService {
* @param trialJobId trial job id
* @param form job application form
*/
public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.jobMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
if (form.jobType === 'TRIAL') {
await this.writeParameterFile(trialJobDetail.workingDirectory, (<TrialJobApplicationForm>form).hyperParameters);
} else {
throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
}
await this.writeParameterFile(trialJobDetail.workingDirectory, form.hyperParameters);
return trialJobDetail;
}
......@@ -279,13 +259,7 @@ class LocalTrainingService implements TrainingService {
return Promise.resolve();
}
if (trialJob.form.jobType === 'TRIAL') {
tkill(trialJob.pid, 'SIGKILL');
} else if (trialJob.form.jobType === 'HOST') {
await cpp.exec(`pkill -9 -P ${trialJob.pid}`);
} else {
throw new Error(`Job type not supported: ${trialJob.form.jobType}`);
}
this.setTrialJobStatus(trialJob, getJobCancelStatus(isEarlyStopped));
return Promise.resolve();
......@@ -409,7 +383,7 @@ class LocalTrainingService implements TrainingService {
{ key: 'NNI_SYS_DIR', value: trialJobDetail.workingDirectory },
{ key: 'NNI_TRIAL_JOB_ID', value: trialJobDetail.id },
{ key: 'NNI_OUTPUT_DIR', value: trialJobDetail.workingDirectory },
{ key: 'NNI_TRIAL_SEQ_ID', value: trialJobDetail.sequenceId.toString() },
{ key: 'NNI_TRIAL_SEQ_ID', value: trialJobDetail.form.sequenceId.toString() },
{ key: 'MULTI_PHASE', value: this.isMultiPhase.toString() }
];
if (gpuNum !== undefined) {
......@@ -516,7 +490,7 @@ class LocalTrainingService implements TrainingService {
const script: string[] = [];
if (process.platform === 'win32') {
script.push(
`cmd /c ${localTrialConfig.command} 2>${path.join(workingDirectory, 'stderr')}`,
`cmd.exe /c ${localTrialConfig.command} 2>${path.join(workingDirectory, 'stderr')}`,
`$NOW_DATE = [int64](([datetime]::UtcNow)-(get-date "1/1/1970")).TotalSeconds`,
`$NOW_DATE = "$NOW_DATE" + (Get-Date -Format fff).ToString()`,
`Write $LASTEXITCODE " " $NOW_DATE | Out-File ${path.join(workingDirectory, '.nni', 'state')} -NoNewline -encoding utf8`);
......@@ -562,7 +536,7 @@ class LocalTrainingService implements TrainingService {
const scriptName: string = getScriptName('run');
await fs.promises.writeFile(path.join(trialJobDetail.workingDirectory, scriptName),
runScriptContent.join(getNewLine()), { encoding: 'utf8', mode: 0o777 });
await this.writeParameterFile(trialJobDetail.workingDirectory, (<TrialJobApplicationForm>trialJobDetail.form).hyperParameters);
await this.writeParameterFile(trialJobDetail.workingDirectory, trialJobDetail.form.hyperParameters);
const trialJobProcess: cp.ChildProcess = runScript(path.join(trialJobDetail.workingDirectory, scriptName));
this.setTrialJobStatus(trialJobDetail, 'RUNNING');
trialJobDetail.startTime = Date.now();
......@@ -589,60 +563,10 @@ class LocalTrainingService implements TrainingService {
this.jobStreamMap.set(trialJobDetail.id, stream);
}
private async runHostJob(form: HostJobApplicationForm): Promise<TrialJobDetail> {
const jobId: string = uniqueString(5);
const workDir: string = path.join(this.rootDir, 'hostjobs', jobId);
await cpp.exec(`mkdir -p ${workDir}`);
const wrappedCmd: string = `cd ${workDir} && ${form.cmd}>stdout 2>stderr`;
this.log.debug(`runHostJob: command: ${wrappedCmd}`);
const process: cp.ChildProcess = cp.exec(wrappedCmd);
const jobDetail: LocalTrialJobDetail = {
id: jobId,
status: 'RUNNING',
submitTime: Date.now(),
workingDirectory: workDir,
form: form,
sequenceId: this.generateSequenceId(),
pid: process.pid
};
this.jobMap.set(jobId, jobDetail);
this.log.debug(`runHostJob: return: ${JSON.stringify(jobDetail)} `);
return jobDetail;
}
private async getHostJob(jobId: string): Promise<TrialJobDetail> {
const jobDetail: LocalTrialJobDetail | undefined = this.jobMap.get(jobId);
if (jobDetail === undefined) {
throw new NNIError(NNIErrorNames.NOT_FOUND, `Host Job not found: ${jobId}`);
}
try {
await cpp.exec(`kill -0 ${jobDetail.pid}`);
return jobDetail;
} catch (error) {
if (error instanceof Error) {
this.log.debug(`getHostJob: error: ${error.message}`);
this.jobMap.delete(jobId);
throw new NNIError(NNIErrorNames.NOT_FOUND, `Host Job not found: ${error.message}`);
} else {
throw error;
}
}
}
private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
const filepath: string = path.join(directory, generateParamFileName(hyperParameters));
await fs.promises.writeFile(filepath, hyperParameters.value, { encoding: 'utf8' });
}
private generateSequenceId(): number {
if (this.trialSequenceId === -1) {
this.trialSequenceId = getInitTrialSequenceId();
}
return this.trialSequenceId++;
}
}
export { LocalTrainingService };
......@@ -19,7 +19,7 @@
'use strict';
import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
/**
* PAI trial job detail
......@@ -34,20 +34,18 @@ export class PAITrialJobDetail implements TrialJobDetail {
public tags?: string[];
public url?: string;
public workingDirectory: string;
public form: JobApplicationForm;
public sequenceId: number;
public form: TrialJobApplicationForm;
public hdfsLogPath: string;
public isEarlyStopped?: boolean;
constructor(id: string, status: TrialJobStatus, paiJobName : string,
submitTime: number, workingDirectory: string, form: JobApplicationForm, sequenceId: number, hdfsLogPath: string) {
submitTime: number, workingDirectory: string, form: TrialJobApplicationForm, hdfsLogPath: string) {
this.id = id;
this.status = status;
this.paiJobName = paiJobName;
this.submitTime = submitTime;
this.workingDirectory = workingDirectory;
this.form = form;
this.sequenceId = sequenceId;
this.tags = [];
this.hdfsLogPath = hdfsLogPath;
}
......
......@@ -30,10 +30,10 @@ import { EventEmitter } from 'events';
import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
import { MethodNotImplementedError } from '../../common/errors';
import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo';
import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import {
HyperParameters, JobApplicationForm, NNIManagerIpConfig, TrainingService,
HyperParameters, NNIManagerIpConfig, TrainingService,
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService';
import { delay, generateParamFileName,
......@@ -70,7 +70,6 @@ class PAITrainingService implements TrainingService {
private readonly paiTokenUpdateInterval: number;
private readonly experimentId! : string;
private readonly paiJobCollector : PAIJobInfoCollector;
private nextTrialSequenceId: number;
private paiRestServerPort?: number;
private nniManagerIpConfig?: NNIManagerIpConfig;
private copyExpCodeDirPromise?: Promise<void>;
......@@ -90,7 +89,6 @@ class PAITrainingService implements TrainingService {
this.expRootDir = path.join('/nni', 'experiments', getExperimentId());
this.experimentId = getExperimentId();
this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
this.nextTrialSequenceId = -1;
this.paiTokenUpdateInterval = 7200000; //2hours
this.logCollection = 'none';
this.log.info('Construct OpenPAI training service.');
......@@ -112,10 +110,8 @@ class PAITrainingService implements TrainingService {
const jobs: TrialJobDetail[] = [];
for (const [key, value] of this.trialJobsMap) {
if (value.form.jobType === 'TRIAL') {
jobs.push(await this.getTrialJob(key));
}
}
return Promise.resolve(jobs);
}
......@@ -142,7 +138,7 @@ class PAITrainingService implements TrainingService {
this.metricsEmitter.off('metric', listener);
}
public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.paiClusterConfig === undefined) {
throw new Error(`paiClusterConfig not initialized!`);
}
......@@ -151,7 +147,6 @@ class PAITrainingService implements TrainingService {
this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);
const trialJobId: string = uniqueString(5);
const trialSequenceId: number = this.generateSequenceId();
//TODO: use HDFS working folder instead
const trialWorkingFolder: string = path.join(this.expRootDir, 'trials', trialJobId);
const paiJobName: string = `nni_exp_${this.experimentId}_trial_${trialJobId}`;
......@@ -171,7 +166,6 @@ class PAITrainingService implements TrainingService {
Date.now(),
trialWorkingFolder,
form,
trialSequenceId,
hdfsLogPath);
this.trialJobsMap.set(trialJobId, trialJobDetail);
......@@ -181,16 +175,12 @@ class PAITrainingService implements TrainingService {
return deferred.promise;
}
public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
if (form.jobType === 'TRIAL') {
await this.writeParameterFile(trialJobId, (<TrialJobApplicationForm>form).hyperParameters);
} else {
throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
}
await this.writeParameterFile(trialJobId, form.hyperParameters);
return trialJobDetail;
}
......@@ -397,11 +387,10 @@ class PAITrainingService implements TrainingService {
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
// Write file content ( parameter.cfg ) to local tmp folders
const trialForm : TrialJobApplicationForm = (<TrialJobApplicationForm>trialJobDetail.form);
if (trialForm !== undefined) {
if (trialJobDetail.form !== undefined) {
await fs.promises.writeFile(
path.join(trialLocalTempFolder, generateParamFileName(trialForm.hyperParameters)),
trialForm.hyperParameters.value, { encoding: 'utf8' }
path.join(trialLocalTempFolder, generateParamFileName(trialJobDetail.form.hyperParameters)),
trialJobDetail.form.hyperParameters.value, { encoding: 'utf8' }
);
}
const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
......@@ -416,7 +405,7 @@ class PAITrainingService implements TrainingService {
`$PWD/${trialJobId}/nnioutput`,
trialJobId,
this.experimentId,
trialJobDetail.sequenceId,
trialJobDetail.form.sequenceId,
this.isMultiPhase,
this.paiTrialConfig.command,
nniManagerIp,
......@@ -507,14 +496,6 @@ class PAITrainingService implements TrainingService {
return deferred.promise;
}
private generateSequenceId(): number {
if (this.nextTrialSequenceId === -1) {
this.nextTrialSequenceId = getInitTrialSequenceId();
}
return this.nextTrialSequenceId++;
}
private async statusCheckingLoop(): Promise<void> {
while (!this.stopping) {
try {
......
......@@ -22,7 +22,7 @@
import * as fs from 'fs';
import { Client, ConnectConfig } from 'ssh2';
import { Deferred } from 'ts-deferred';
import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { GPUInfo, GPUSummary } from '../common/gpuData';
/**
......@@ -82,20 +82,18 @@ export class RemoteMachineTrialJobDetail implements TrialJobDetail {
public tags?: string[];
public url?: string;
public workingDirectory: string;
public form: JobApplicationForm;
public sequenceId: number;
public form: TrialJobApplicationForm;
public rmMeta?: RemoteMachineMeta;
public isEarlyStopped?: boolean;
public gpuIndices: GPUInfo[];
constructor(id: string, status: TrialJobStatus, submitTime: number,
workingDirectory: string, form: JobApplicationForm, sequenceId: number) {
workingDirectory: string, form: TrialJobApplicationForm) {
this.id = id;
this.status = status;
this.submitTime = submitTime;
this.workingDirectory = workingDirectory;
this.form = form;
this.sequenceId = sequenceId;
this.tags = [];
this.gpuIndices = [];
}
......
......@@ -30,11 +30,11 @@ import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
import * as component from '../../common/component';
import { NNIError, NNIErrorNames } from '../../common/errors';
import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo';
import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import { ObservableTimer } from '../../common/observableTimer';
import {
HostJobApplicationForm, HyperParameters, JobApplicationForm, NNIManagerIpConfig, TrainingService, TrialJobApplicationForm,
HyperParameters, NNIManagerIpConfig, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric
} from '../../common/trainingService';
import {
......@@ -172,10 +172,8 @@ class RemoteMachineTrainingService implements TrainingService {
const deferred: Deferred<TrialJobDetail[]> = new Deferred<TrialJobDetail[]>();
for (const [key, value] of this.trialJobsMap) {
if (value.form.jobType === 'TRIAL') {
jobs.push(await this.getTrialJob(key));
}
}
deferred.resolve(jobs);
return deferred.promise;
......@@ -228,14 +226,11 @@ class RemoteMachineTrainingService implements TrainingService {
* @param form trial job description form
*/
// tslint:disable-next-line:informative-docs
public async submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail> {
public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.trialConfig === undefined) {
throw new Error('trial config is not initialized');
}
if (form.jobType === 'HOST') {
return this.runHostJob(<HostJobApplicationForm>form);
} else if (form.jobType === 'TRIAL') {
// Generate trial job id(random)
const trialJobId: string = uniqueString(5);
const trialWorkingFolder: string = unixPathJoin(this.remoteExpRootDir, 'trials', trialJobId);
......@@ -245,16 +240,12 @@ class RemoteMachineTrainingService implements TrainingService {
'WAITING',
Date.now(),
trialWorkingFolder,
form,
this.generateSequenceId()
form
);
this.jobQueue.push(trialJobId);
this.trialJobsMap.set(trialJobId, trialJobDetail);
return Promise.resolve(trialJobDetail);
} else {
return Promise.reject(new Error(`Job form not supported: ${JSON.stringify(form)}, jobType should be HOST or TRIAL.`));
}
}
/**
......@@ -262,21 +253,17 @@ class RemoteMachineTrainingService implements TrainingService {
* @param trialJobId trial job id
* @param form job application form
*/
public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
if (form.jobType === 'TRIAL') {
const rmMeta: RemoteMachineMeta | undefined = (<RemoteMachineTrialJobDetail>trialJobDetail).rmMeta;
if (rmMeta !== undefined) {
await this.writeParameterFile(trialJobId, (<TrialJobApplicationForm>form).hyperParameters, rmMeta);
await this.writeParameterFile(trialJobId, form.hyperParameters, rmMeta);
} else {
throw new Error(`updateTrialJob failed: ${trialJobId} rmMeta not found`);
}
} else {
throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
}
return trialJobDetail;
}
......@@ -558,7 +545,7 @@ class RemoteMachineTrainingService implements TrainingService {
await this.allocateSSHClientForTrial(trialJobDetail);
await this.launchTrialOnScheduledMachine(
trialJobId, trialWorkingFolder, <TrialJobApplicationForm>trialJobDetail.form, rmScheduleInfo);
trialJobId, trialWorkingFolder, trialJobDetail.form, rmScheduleInfo);
trialJobDetail.status = 'RUNNING';
trialJobDetail.url = `file://${rmScheduleInfo.rmMeta.ip}:${trialWorkingFolder}`;
......@@ -628,7 +615,7 @@ class RemoteMachineTrainingService implements TrainingService {
trialWorkingFolder,
trialJobId,
getExperimentId(),
trialJobDetail.sequenceId.toString(),
trialJobDetail.form.sequenceId.toString(),
this.isMultiPhase,
unixPathJoin(trialWorkingFolder, '.nni', 'jobpid'),
command,
......@@ -657,38 +644,6 @@ class RemoteMachineTrainingService implements TrainingService {
SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(trialWorkingFolder, 'run.sh')}`, sshClient);
}
private async runHostJob(form: HostJobApplicationForm): Promise<TrialJobDetail> {
const rmMeta: RemoteMachineMeta = this.getRmMetaByHost(form.host);
const sshClientManager: SSHClientManager | undefined = this.machineSSHClientMap.get(rmMeta);
if (sshClientManager === undefined) {
throw new Error('sshClient not found.');
}
const sshClient: Client = sshClientManager.getFirstSSHClient();
const jobId: string = uniqueString(5);
const localDir: string = path.join(this.expRootDir, 'hostjobs-local', jobId);
const remoteDir: string = this.getHostJobRemoteDir(jobId);
await cpp.exec(`mkdir -p ${localDir}`);
await SSHClientUtility.remoteExeCommand(`mkdir -p ${remoteDir}`, sshClient);
const runScriptContent: string = String.Format(
HOST_JOB_SHELL_FORMAT, remoteDir, path.join(remoteDir, 'jobpid'), form.cmd, path.join(remoteDir, 'code')
);
await fs.promises.writeFile(path.join(localDir, 'run.sh'), runScriptContent, { encoding: 'utf8' });
await SSHClientUtility.copyFileToRemote(
path.join(localDir, 'run.sh'), unixPathJoin(remoteDir, 'run.sh'), sshClient);
// tslint:disable-next-line: no-floating-promises
SSHClientUtility.remoteExeCommand(`bash ${unixPathJoin(remoteDir, 'run.sh')}`, sshClient);
const jobDetail: RemoteMachineTrialJobDetail = new RemoteMachineTrialJobDetail(
jobId, 'RUNNING', Date.now(), remoteDir, form, this.generateSequenceId()
);
jobDetail.rmMeta = rmMeta;
jobDetail.startTime = Date.now();
this.trialJobsMap.set(jobId, jobDetail);
this.log.debug(`runHostJob: return: ${JSON.stringify(jobDetail)} `);
return jobDetail;
}
private getRmMetaByHost(host: string): RemoteMachineMeta {
for (const [rmMeta, client] of this.machineSSHClientMap.entries()) {
if (rmMeta.ip === host) {
......@@ -765,13 +720,7 @@ class RemoteMachineTrainingService implements TrainingService {
}
let jobpidPath: string;
if (trialJobDetail.form.jobType === 'TRIAL') {
jobpidPath = unixPathJoin(trialJobDetail.workingDirectory, '.nni', 'jobpid');
} else if (trialJobDetail.form.jobType === 'HOST') {
jobpidPath = unixPathJoin(this.getHostJobRemoteDir(jobId), 'jobpid');
} else {
throw new Error(`Job type not supported: ${trialJobDetail.form.jobType}`);
}
return jobpidPath;
}
......@@ -791,14 +740,6 @@ class RemoteMachineTrainingService implements TrainingService {
await SSHClientUtility.copyFileToRemote(localFilepath, unixPathJoin(trialWorkingFolder, fileName), sshClient);
}
private generateSequenceId(): number {
if (this.trialSequenceId === -1) {
this.trialSequenceId = getInitTrialSequenceId();
}
return this.trialSequenceId++;
}
}
export { RemoteMachineTrainingService };
......@@ -76,7 +76,7 @@ describe('Unit Test for LocalTrainingService', () => {
// submit job
const form: TrialJobApplicationForm = {
jobType: 'TRIAL',
sequenceId: 0,
hyperParameters: {
value: 'mock hyperparameters',
index: 0
......@@ -95,7 +95,7 @@ describe('Unit Test for LocalTrainingService', () => {
// submit job
const form: TrialJobApplicationForm = {
jobType: 'TRIAL',
sequenceId: 0,
hyperParameters: {
value: 'mock hyperparameters',
index: 0
......
......@@ -24,6 +24,7 @@ import * as chaiAsPromised from 'chai-as-promised';
import * as fs from 'fs';
import * as tmp from 'tmp';
import * as component from '../../common/component';
import { TrialJobApplicationForm } from '../../common/trainingService';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { PAITrainingService } from '../pai/paiTrainingService';
......@@ -84,8 +85,12 @@ describe('Unit Test for PAITrainingService', () => {
console.log(`paiCluster is ${paiCluster}`)
await paiTrainingService.setClusterMetadata(TrialConfigMetadataKey.PAI_CLUSTER_CONFIG, paiCluster);
await paiTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, paiTrialConfig);
const form: TrialJobApplicationForm = {
sequenceId: 0,
hyperParameters: { value: '', index: 0 }
};
try {
const trialDetail = await paiTrainingService.submitTrialJob({jobType : 'TRIAL'});
const trialDetail = await paiTrainingService.submitTrialJob(form);
chai.expect(trialDetail.status).to.be.equals('WAITING');
} catch(error) {
console.log('Submit job failed:' + error);
......
......@@ -99,7 +99,7 @@ describe('Unit Test for RemoteMachineTrainingService', () => {
await remoteMachineTrainingService.setClusterMetadata(
TrialConfigMetadataKey.TRIAL_CONFIG, `{"command":"sleep 1h && echo ","codeDir":"${localCodeDir}","gpuNum":1}`);
const form: TrialJobApplicationForm = {
jobType: 'TRIAL',
sequenceId: 0,
hyperParameters: {
value: 'mock hyperparameters',
index: 0
......@@ -137,7 +137,7 @@ describe('Unit Test for RemoteMachineTrainingService', () => {
// submit job
const form: TrialJobApplicationForm = {
jobType: 'TRIAL',
sequenceId: 0,
hyperParameters: {
value: 'mock hyperparameters',
index: 0
......
from .compressor import LayerInfo, Compressor, Pruner, Quantizer
from .builtin_pruners import *
from .builtin_quantizers import *
This diff is collapsed.
import logging
import tensorflow as tf
from .compressor import Quantizer
__all__ = [ 'NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer' ]
_logger = logging.getLogger(__name__)
class NaiveQuantizer(Quantizer):
"""quantize weight to 8 bits
"""
def __init__(self, config_list):
super().__init__(config_list)
self.layer_scale = { }
def quantize_weight(self, weight, config, op_name, **kwargs):
new_scale = tf.reduce_max(tf.abs(weight)) / 127
scale = tf.maximum(self.layer_scale.get(op_name, tf.constant(0.0)), new_scale)
self.layer_scale[op_name] = scale
orig_type = weight.dtype
return tf.cast(tf.cast(weight / scale, tf.int8), orig_type) * scale
class QAT_Quantizer(Quantizer):
"""Quantizer using the DoReFa scheme, as defined in:
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""
def __init__(self, config_list):
"""
config_list: supported keys:
- q_bits
"""
super().__init__(config_list)
def quantize_weight(self, weight, config, **kwargs):
a = tf.stop_gradient(tf.reduce_min(weight))
b = tf.stop_gradient(tf.reduce_max(weight))
n = tf.cast(2 ** config['q_bits'], tf.float32)
scale = b-a/(n-1)
# use gradient_override_map to change round to idetity for gradient
with tf.get_default_graph().gradient_override_map({'Round': 'Identity'}):
qw = tf.round((weight-a)/scale)*scale +a
return qw
class DoReFaQuantizer(Quantizer):
"""Quantizer using the DoReFa scheme, as defined in:
Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
(https://arxiv.org/abs/1606.06160)
"""
def __init__(self, config_list):
"""
config_list: supported keys:
- q_bits
"""
super().__init__(config_list)
def quantize_weight(self, weight, config, **kwargs):
a = tf.math.tanh(weight)
b = a/(2*tf.reduce_max(tf.abs(weight))) + 0.5
scale = pow(2, config['q_bits'] - 1)
# use gradient_override_map to change round to idetity for gradient
with tf.get_default_graph().gradient_override_map({'Round': 'Identity'}):
qw = tf.round(b*scale)/scale
r_qw = 2 * qw - 1
return r_qw
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment