Unverified Commit f9ee589c authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #222 from microsoft/master

merge master
parents 36e6e350 4f3ee9cb
......@@ -19,19 +19,17 @@ export interface ParameterFileMeta {
* PAI Training service Rest server, provides rest API to support pai job metrics update
*
*/
@component.Singleton
export class PAIJobRestServer extends ClusterJobRestServer {
private parameterFileMetaList: ParameterFileMeta[] = [];
protected parameterFileMetaList: ParameterFileMeta[] = [];
@Inject
private readonly paiTrainingService: PAITrainingService;
protected readonly paiTrainingService: PAITrainingService;
/**
* constructor to provide NNIRestServer's own rest property, e.g. port
*/
constructor() {
constructor (paiTrainingService: PAITrainingService) {
super();
this.paiTrainingService = component.get(PAITrainingService);
this.paiTrainingService = paiTrainingService;
}
protected handleTrialMetrics(jobId: string, metrics: any[]): void {
......
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import {TrialConfig} from '../../common/trialConfig';
/**
 * Trial configuration for NNI trials submitted to OpenPAI in Kubernetes mode.
 *
 * Adds the PAI-specific resource request, docker image, NFS mount locations and
 * storage plugin name on top of the common trial settings held by TrialConfig.
 */
export class NNIPAIK8STrialConfig extends TrialConfig {
    /** CPU count requested for the trial. */
    public readonly cpuNum: number;
    /** Memory requested for the trial, in MB. */
    public readonly memoryMB: number;
    /** Docker image the trial runs in. */
    public readonly image: string;
    /** Optional virtual cluster name. */
    public virtualCluster?: string;
    /** NFS mount path as seen from the NNI manager machine. */
    public readonly nniManagerNFSMountPath: string;
    /** NFS mount path as seen from inside the trial container. */
    public readonly containerNFSMountPath: string;
    /** Name of the PAI storage plugin. */
    public readonly paiStoragePlugin: string;

    /**
     * @param command trial command, forwarded to TrialConfig
     * @param codeDir trial code directory, forwarded to TrialConfig
     * @param gpuNum GPU count, forwarded to TrialConfig
     * @param cpuNum CPU count requested for the trial
     * @param memoryMB memory requested for the trial, in MB
     * @param image docker image the trial runs in
     * @param nniManagerNFSMountPath NFS mount path on the NNI manager machine
     * @param containerNFSMountPath NFS mount path inside the trial container
     * @param paiStoragePlugin name of the PAI storage plugin
     * @param virtualCluster optional virtual cluster name
     */
    constructor(
        command: string,
        codeDir: string,
        gpuNum: number,
        cpuNum: number,
        memoryMB: number,
        image: string,
        nniManagerNFSMountPath: string,
        containerNFSMountPath: string,
        paiStoragePlugin: string,
        virtualCluster?: string
    ) {
        super(command, codeDir, gpuNum);

        // Resource request and image
        this.cpuNum = cpuNum;
        this.memoryMB = memoryMB;
        this.image = image;
        this.virtualCluster = virtualCluster;

        // Storage settings
        this.nniManagerNFSMountPath = nniManagerNFSMountPath;
        this.containerNFSMountPath = containerNFSMountPath;
        this.paiStoragePlugin = paiStoragePlugin;
    }
}
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
/**
 * Shell script that installs the `nni` Python package for the current user
 * unless `python3 -c 'import nni'` already succeeds.
 *
 * The check is written as a negated condition so the script simply falls off
 * the end when nni is present. A top-level `return` is deliberately avoided:
 * bash only accepts `return` inside a function or a sourced script and
 * otherwise reports a non-zero status, which would make an executed script fail
 * and break any `sh install_nni.sh && ...` command chain.
 */
export const PAI_INSTALL_NNI_SHELL_FORMAT: string =
`#!/bin/bash
if ! python3 -c 'import nni' > /dev/null 2>&1; then
  # nni module is not installed yet, install it
  python3 -m pip install --user nni
fi`;
/**
 * Format string for the shell command line that starts an NNI trial.
 *
 * The command exports the NNI environment variables, lists and enters
 * NNI_SYS_DIR, runs install_nni.sh there, and finally launches
 * nni_trial_tool.trial_keeper.
 *
 * Positional placeholders, as they appear in the string:
 *   {0}  NNI_SYS_DIR
 *   {1}  NNI_OUTPUT_DIR
 *   {2}  NNI_TRIAL_JOB_ID
 *   {3}  NNI_EXP_ID
 *   {4}  NNI_TRIAL_SEQ_ID
 *   {5}  MULTI_PHASE
 *   {6}  --trial_command
 *   {7}  --nnimanager_ip
 *   {8}  --nnimanager_port
 *   {9}  --nni_manager_version
 *   {10} --log_collection
 *
 * Every source line of the literal ends with a backslash, which is a line
 * continuation inside a template literal, so the resulting string value
 * contains no line breaks.
 */
export const PAI_K8S_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} \
&& ls $NNI_SYS_DIR \
&& cd $NNI_SYS_DIR && sh install_nni.sh \
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{6}' --nnimanager_ip '{7}' --nnimanager_port '{8}' \
--nni_manager_version '{9}' --log_collection '{10}'`;
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import * as cpp from 'child-process-promise';
import * as fs from 'fs';
import * as path from 'path';
// tslint:disable-next-line:no-implicit-dependencies
import * as request from 'request';
import * as component from '../../../common/component';
import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
import {
HyperParameters, NNIManagerIpConfig, TrainingService,
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../../common/trainingService';
import { delay, generateParamFileName,
getExperimentRootDir, getIPV4Address, getVersion, uniqueString, unixPathJoin } from '../../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../../common/containerJobData';
import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
import { execMkdir, validateCodeDir, execCopydir } from '../../common/util';
import { PAI_K8S_TRIAL_COMMAND_FORMAT } from './paiK8SData';
import { NNIPAIK8STrialConfig } from './paiK8SConfig';
import { PAITrainingService } from '../paiTrainingService';
import { PAIClusterConfig, PAITrialJobDetail } from '../paiConfig';
import { PAIJobRestServer } from '../paiJobRestServer';
const yaml = require('js-yaml');
/**
 * Training Service implementation for OpenPAI (Open Platform for AI)
 * Refer https://github.com/Microsoft/pai for more info about OpenPAI
 *
 * This variant stages each trial's files in a folder below the configured
 * `nniManagerNFSMountPath` and submits the job to the PAI rest-server v2 API
 * as a YAML job description.
 */
@component.Singleton
class PAIK8STrainingService extends PAITrainingService {
    protected paiTrialConfig: NNIPAIK8STrialConfig | undefined;

    constructor() {
        super();
    }

    /**
     * Apply one piece of configuration to the training service.
     *
     * @param key one of the TrialConfigMetadataKey values handled below
     * @param value the setting; JSON text for the structured keys
     *
     * Unknown keys are logged as an error and otherwise ignored.
     */
    public async setClusterMetadata(key: string, value: string): Promise<void> {
        switch (key) {
            case TrialConfigMetadataKey.NNI_MANAGER_IP:
                // submitTrialJobToPAI prefers this address over the auto-detected IPv4 address.
                this.nniManagerIpConfig = <NNIManagerIpConfig>JSON.parse(value);
                break;

            case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG:
                this.paiJobRestServer = new PAIJobRestServer(component.get(PAIK8STrainingService));
                this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);

                if (this.paiClusterConfig.passWord) {
                    // Get PAI authentication token
                    await this.updatePaiToken();
                } else if (this.paiClusterConfig.token) {
                    this.paiToken = this.paiClusterConfig.token;
                }
                break;

            case TrialConfigMetadataKey.TRIAL_CONFIG:
                if (this.paiClusterConfig === undefined) {
                    this.log.error('pai cluster config is not initialized');
                    break;
                }
                this.paiTrialConfig = <NNIPAIK8STrialConfig>JSON.parse(value);
                // Validate to make sure codeDir doesn't have too many files
                await validateCodeDir(this.paiTrialConfig.codeDir);
                break;

            case TrialConfigMetadataKey.VERSION_CHECK:
                this.versionCheck = (value === 'true' || value === 'True');
                break;

            case TrialConfigMetadataKey.LOG_COLLECTION:
                this.logCollection = value;
                break;

            case TrialConfigMetadataKey.MULTI_PHASE:
                this.isMultiPhase = (value === 'true' || value === 'True');
                break;

            default:
                //Reject for unknown keys
                this.log.error(`Uknown key: ${key}`);
        }
    }

    /**
     * Look up an existing trial job. The form is currently not applied.
     *
     * @param trialJobId id of the trial job to update
     * @param form new application form (unused for now)
     * @returns the stored trial job detail
     * @throws Error when no job with that id is known
     */
    //TODO: update trial parameters
    public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
        if (trialJobDetail === undefined) {
            throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
        }

        return trialJobDetail;
    }

    /**
     * Register a new trial job in WAITING state and put it on the job queue.
     *
     * @param form application form of the trial
     * @returns the newly created trial job detail
     * @throws Error when the cluster config or the trial config is not set yet
     */
    public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        if (this.paiClusterConfig === undefined) {
            throw new Error(`paiClusterConfig not initialized!`);
        }
        if (this.paiTrialConfig === undefined) {
            throw new Error(`paiTrialConfig not initialized!`);
        }

        this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);

        const trialJobId: string = uniqueString(5);
        //TODO: use HDFS working folder instead
        const trialWorkingFolder: string = path.join(this.expRootDir, 'trials', trialJobId);
        const paiJobName: string = `nni_exp_${this.experimentId}_trial_${trialJobId}`;
        const logPath: string = path.join(this.paiTrialConfig.nniManagerNFSMountPath, this.experimentId, trialJobId);
        const trialJobDetail: PAITrialJobDetail = new PAITrialJobDetail(
            trialJobId,
            'WAITING',
            paiJobName,
            Date.now(),
            trialWorkingFolder,
            form,
            logPath);

        this.trialJobsMap.set(trialJobId, trialJobDetail);
        this.jobQueue.push(trialJobId);

        return trialJobDetail;
    }

    /**
     * Build the PAI job description for one trial and serialize it to YAML.
     *
     * @param trialJobId id of the trial, used to derive the job name
     * @param command shell command the single task role executes
     * @returns the job description dumped with yaml.safeDump
     * @throws Error when the trial config is not set yet
     */
    public generateJobConfigInYamlFormat(trialJobId: string, command: string) {
        if (this.paiTrialConfig === undefined) {
            throw new Error('trial config is not initialized');
        }
        const jobName = `nni_exp_${this.experimentId}_trial_${trialJobId}`;
        const paiJobConfig: any = {
            protocolVersion: 2,
            name: jobName,
            type: 'job',
            jobRetryCount: 0,
            prerequisites: [
                {
                    type: 'dockerimage',
                    uri: this.paiTrialConfig.image,
                    name: 'docker_image_0'
                }
            ],
            taskRoles: {
                taskrole: {
                    instances: 1,
                    completion: {
                        minFailedInstances: 1,
                        minSucceededInstances: -1
                    },
                    taskRetryCount: 0,
                    dockerImage: 'docker_image_0',
                    resourcePerInstance: {
                        gpu: this.paiTrialConfig.gpuNum,
                        cpu: this.paiTrialConfig.cpuNum,
                        memoryMB: this.paiTrialConfig.memoryMB
                    },
                    commands: [
                        command
                    ]
                }
            },
            extras: {
                'com.microsoft.pai.runtimeplugin': [
                    {
                        plugin: this.paiTrialConfig.paiStoragePlugin
                    }
                ],
                submitFrom: 'submit-job-v2'
            }
        };
        // Only emit a defaults section when a virtual cluster was configured.
        if (this.paiTrialConfig.virtualCluster) {
            paiJobConfig.defaults = {
                virtualCluster: this.paiTrialConfig.virtualCluster
            };
        }

        return yaml.safeDump(paiJobConfig);
    }

    /**
     * Stage the trial files and submit the trial to PAI.
     *
     * @param trialJobId id of a job previously created by submitTrialJob
     * @returns a promise that resolves to true once the REST call has completed;
     *          a failed submission is reported by setting the job status to 'FAILED'
     * @throws Error when the job is unknown or a required configuration item is missing
     */
    protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
        const deferred: Deferred<boolean> = new Deferred<boolean>();
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);

        if (trialJobDetail === undefined) {
            throw new Error(`Failed to find PAITrialJobDetail for job ${trialJobId}`);
        }
        if (this.paiClusterConfig === undefined) {
            throw new Error('PAI Cluster config is not initialized');
        }
        if (this.paiTrialConfig === undefined) {
            throw new Error('trial config is not initialized');
        }
        if (this.paiToken === undefined) {
            throw new Error('PAI token is not initialized');
        }
        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer is not initialized');
        }

        this.paiRestServerPort = this.paiJobRestServer.clusterRestServerPort;

        // Step 1. Prepare PAI job configuration
        const trialLocalFolder: string = path.join(this.paiTrialConfig.nniManagerNFSMountPath, this.experimentId, trialJobId);
        //create trial local working folder locally.
        await execMkdir(trialLocalFolder);

        const runScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
        // Write NNI installation file to local files
        await fs.promises.writeFile(path.join(trialLocalFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });

        // Write file content ( parameter.cfg ) to local working folders
        if (trialJobDetail.form !== undefined) {
            await fs.promises.writeFile(
                path.join(trialLocalFolder, generateParamFileName(trialJobDetail.form.hyperParameters)),
                trialJobDetail.form.hyperParameters.value, { encoding: 'utf8' }
            );
        }

        //Copy codeDir files to local working folder
        await execCopydir(this.paiTrialConfig.codeDir, trialLocalFolder);

        const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
        const version: string = this.versionCheck ? await getVersion() : '';
        // Same trial folder, addressed through the container's mount path.
        const containerWorkingDir: string = `${this.paiTrialConfig.containerNFSMountPath}/${this.experimentId}/${trialJobId}`;
        const nniPaiTrialCommand: string = String.Format(
            PAI_K8S_TRIAL_COMMAND_FORMAT,
            `${containerWorkingDir}`,
            `${containerWorkingDir}/nnioutput`,
            trialJobId,
            this.experimentId,
            trialJobDetail.form.sequenceId,
            this.isMultiPhase,
            this.paiTrialConfig.command,
            nniManagerIp,
            this.paiRestServerPort,
            version,
            this.logCollection
        )
        .replace(/\r\n|\n|\r/gm, '');

        this.log.info(`nniPAItrial command is ${nniPaiTrialCommand.trim()}`);

        // Step 2. Generate the job description
        const paiJobConfig = this.generateJobConfigInYamlFormat(trialJobId, nniPaiTrialCommand);

        // Step 3. Submit PAI job via Rest call
        // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
        const submitJobRequest: request.Options = {
            uri: `http://${this.paiClusterConfig.host}/rest-server/api/v2/jobs`,
            method: 'POST',
            body: paiJobConfig,
            headers: {
                'Content-Type': 'text/yaml',
                Authorization: `Bearer ${this.paiToken}`
            }
        };
        request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
            if ((error !== undefined && error !== null) || response.statusCode >= 400) {
                const errorMessage: string = (error !== undefined && error !== null) ? error.message :
                    `Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${body}`;
                this.log.error(errorMessage);
                trialJobDetail.status = 'FAILED';
            } else {
                trialJobDetail.submitTime = Date.now();
            }
            // The promise always resolves; failure is carried by the job status.
            deferred.resolve(true);
        });

        return deferred.promise;
    }
}
export { PAIK8STrainingService };
......@@ -4,9 +4,9 @@
import * as fs from 'fs';
import * as path from 'path';
import { Deferred } from 'ts-deferred';
import { getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger } from '../../common/log';
import { unixPathJoin } from '../../common/utils';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger } from '../../../common/log';
import { unixPathJoin } from '../../../common/utils';
/**
* HDFS client utility, including copy file/directory
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import {TrialConfig} from '../../common/trialConfig';
/**
 * Task role for PAI
 */
export class PAITaskRole {
    /**
     * Constructor. Every argument is stored on the instance under the same name.
     *
     * @param name Name for the task role
     * @param taskNumber Number of tasks for the task role, no less than 1
     * @param cpuNumber CPU number for one task in the task role, no less than 1
     * @param memoryMB Memory for one task in the task role, no less than 100
     * @param gpuNumber GPU number for one task in the task role, no less than 0
     * @param command Executable command for tasks in the task role, can not be empty
     * @param shmMB Shared memory for one task in the task role
     * @param portList portList to specify the port used in container
     */
    constructor(
        public readonly name: string,
        public readonly taskNumber: number,
        public readonly cpuNumber: number,
        public readonly memoryMB: number,
        public readonly gpuNumber: number,
        public readonly command: string,
        public readonly shmMB?: number,
        public portList?: PortListMetaData[]
    ) {
        // Fields are declared and assigned through the parameter properties above.
    }
}
/**
 * Trial job configuration submitted to PAI
 */
export class PAIJobConfig {
    /**
     * Constructor. Every argument is stored on the instance under the same name.
     *
     * @param jobName Name for the job, need to be unique
     * @param image URL pointing to the Docker image for all tasks in the job
     * @param codeDir Code directory on HDFS
     * @param taskRoles List of taskRole, one task role at least
     * @param virtualCluster The virtual cluster job runs on
     * @param authFile authentication file used for private Docker registry
     */
    constructor(
        public readonly jobName: string,
        public readonly image: string,
        public readonly codeDir: string,
        public taskRoles: PAITaskRole[],
        public readonly virtualCluster: string,
        public readonly authFile?: string
    ) {
        // Fields are declared and assigned through the parameter properties above.
    }
}
/**
 * portList data structure used in PAI taskRole
 *
 * A freshly constructed instance carries an empty label and zero for both
 * numeric fields.
 */
export class PortListMetaData {
    public readonly label: string;
    public readonly beginAt: number;
    public readonly portNumber: number;

    constructor() {
        this.label = '';
        this.beginAt = 0;
        this.portNumber = 0;
    }
}
/**
 * PAI trial configuration
 */
export class NNIPAITrialConfig extends TrialConfig {
    /**
     * @param command trial command, forwarded to TrialConfig
     * @param codeDir trial code directory, forwarded to TrialConfig
     * @param gpuNum GPU count, forwarded to TrialConfig
     * @param cpuNum CPU count for the trial
     * @param memoryMB memory for the trial, in MB
     * @param image docker image for the trial
     * @param virtualCluster The virtual cluster job runs on. If omitted, the job will run on default virtual cluster
     * @param shmMB Shared memory for one task in the task role
     * @param authFile authentication file used for private Docker registry
     * @param portList portList to specify the port used in container
     */
    constructor(
        command: string,
        codeDir: string,
        gpuNum: number,
        public readonly cpuNum: number,
        public readonly memoryMB: number,
        public readonly image: string,
        public virtualCluster?: string,
        public shmMB?: number,
        public authFile?: string,
        public portList?: PortListMetaData[]
    ) {
        // The PAI-specific fields are assigned by the parameter properties right after this call.
        super(command, codeDir, gpuNum);
    }
}
......@@ -3,37 +3,7 @@
'use strict';
import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
/**
 * PAI trial job detail
 *
 * Mutable record describing one trial job. The optional fields (start/end
 * time, url, early-stop flag) are left unset by the constructor.
 */
export class PAITrialJobDetail implements TrialJobDetail {
    public id: string;
    public status: TrialJobStatus;
    public paiJobName: string;
    public submitTime: number;
    public startTime?: number;
    public endTime?: number;
    public tags?: string[];
    public url?: string;
    public workingDirectory: string;
    public form: TrialJobApplicationForm;
    public hdfsLogPath: string;
    public isEarlyStopped?: boolean;

    /**
     * @param id trial job id
     * @param status initial job status
     * @param paiJobName name of the job on PAI
     * @param submitTime submission timestamp
     * @param workingDirectory working directory of the trial
     * @param form application form the trial was created from
     * @param hdfsLogPath log location recorded for the trial
     */
    constructor(
        id: string,
        status: TrialJobStatus,
        paiJobName: string,
        submitTime: number,
        workingDirectory: string,
        form: TrialJobApplicationForm,
        hdfsLogPath: string
    ) {
        // Identity and state
        this.id = id;
        this.status = status;
        this.paiJobName = paiJobName;
        this.submitTime = submitTime;

        // Where the trial runs and what it was started with
        this.workingDirectory = workingDirectory;
        this.form = form;

        // Tags always start out as an empty list
        this.tags = [];
        this.hdfsLogPath = hdfsLogPath;
    }
}
import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../../common/trainingService';
export const PAI_INSTALL_NNI_SHELL_FORMAT: string =
`#!/bin/bash
......@@ -46,7 +16,7 @@ else
fi`;
export const PAI_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} \
`export NNI_PLATFORM=paiYarn NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} \
&& cd $NNI_SYS_DIR && sh install_nni.sh \
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{6}' --nnimanager_ip '{7}' --nnimanager_port '{8}' \
--pai_hdfs_output_dir '{9}' --pai_hdfs_host '{10}' --pai_user_name {11} --nni_hdfs_exp_dir '{12}' --webhdfs_path '/webhdfs/api/v1' \
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as fs from 'fs';
import * as path from 'path';
import * as request from 'request';
import * as component from '../../../common/component';
import { EventEmitter } from 'events';
import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log';
import {
HyperParameters, NNIManagerIpConfig, TrainingService,
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../../common/trainingService';
import { delay, generateParamFileName,
getExperimentRootDir, getIPV4Address, getVersion, uniqueString, unixPathJoin } from '../../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../../common/containerJobData';
import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
import { execMkdir, validateCodeDir } from '../../common/util';
import { HDFSClientUtility } from './hdfsClientUtility';
import { NNIPAITrialConfig, PAIJobConfig, PAITaskRole } from './paiYarnConfig';
import { PAI_LOG_PATH_FORMAT, PAI_TRIAL_COMMAND_FORMAT } from './paiYarnData';
import { PAIJobInfoCollector } from '../paiJobInfoCollector';
import { PAITrainingService } from '../paiTrainingService';
import { PAIClusterConfig, PAITrialJobDetail } from '../paiConfig';
import * as WebHDFS from 'webhdfs';
import { PAIJobRestServer, ParameterFileMeta } from '../paiJobRestServer';
/**
* Training Service implementation for OpenPAI (Open Platform for AI)
* Refer https://github.com/Microsoft/pai for more info about OpenPAI
*/
@component.Singleton
class PAIYarnTrainingService extends PAITrainingService {
private hdfsClient: any;
private copyExpCodeDirPromise?: Promise<void>;
private copyAuthFilePromise?: Promise<void>;
private paiTrialConfig?: NNIPAITrialConfig;
constructor() {
super();
}
public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.paiClusterConfig === undefined) {
throw new Error(`paiBaseClusterConfig not initialized!`);
}
this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);
const trialJobId: string = uniqueString(5);
//TODO: use HDFS working folder instead
const trialWorkingFolder: string = path.join(this.expRootDir, 'trials', trialJobId);
const paiJobName: string = `nni_exp_${this.experimentId}_trial_${trialJobId}`;
const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
const hdfsOutputDir: string = unixPathJoin(hdfsCodeDir, 'nnioutput');
const hdfsLogPath: string = String.Format(
PAI_LOG_PATH_FORMAT,
this.paiClusterConfig.host,
hdfsOutputDir
);
const trialJobDetail: PAITrialJobDetail = new PAITrialJobDetail(
trialJobId,
'WAITING',
paiJobName,
Date.now(),
trialWorkingFolder,
form,
hdfsLogPath);
this.trialJobsMap.set(trialJobId, trialJobDetail);
this.jobQueue.push(trialJobId);
return trialJobDetail;
}
public async setClusterMetadata(key: string, value: string): Promise<void> {
switch (key) {
case TrialConfigMetadataKey.NNI_MANAGER_IP:
this.nniManagerIpConfig = <NNIManagerIpConfig>JSON.parse(value);
break;
case TrialConfigMetadataKey.PAI_YARN_CLUSTER_CONFIG:
this.paiJobRestServer = new PAIJobRestServer(component.get(PAIYarnTrainingService));
this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);
this.hdfsClient = WebHDFS.createClient({
user: this.paiClusterConfig.userName,
// Refer PAI document for Pylon mapping https://github.com/Microsoft/pai/tree/master/docs/pylon
port: 80,
path: '/webhdfs/api/v1',
host: this.paiClusterConfig.host
});
if(this.paiClusterConfig.passWord) {
// Get PAI authentication token
await this.updatePaiToken();
} else if(this.paiClusterConfig.token) {
this.paiToken = this.paiClusterConfig.token;
} else {
throw new Error('pai cluster config format error, please set password or token!');
}
break;
case TrialConfigMetadataKey.TRIAL_CONFIG:
if (this.paiClusterConfig === undefined) {
this.log.error('pai cluster config is not initialized');
break;
}
this.paiTrialConfig = <NNIPAITrialConfig>JSON.parse(value);
// Validate to make sure codeDir doesn't have too many files
await validateCodeDir(this.paiTrialConfig.codeDir);
// Copy experiment files from local folder to HDFS
this.copyExpCodeDirPromise = HDFSClientUtility.copyDirectoryToHdfs(
this.paiTrialConfig.codeDir,
HDFSClientUtility.getHdfsExpCodeDir(this.paiClusterConfig.userName),
this.hdfsClient
);
// Upload authFile to hdfs
if (this.paiTrialConfig.authFile) {
this.authFileHdfsPath = unixPathJoin(HDFSClientUtility.hdfsExpRootDir(this.paiClusterConfig.userName), 'authFile');
this.copyAuthFilePromise = HDFSClientUtility.copyFileToHdfs(this.paiTrialConfig.authFile, this.authFileHdfsPath, this.hdfsClient);
}
break;
case TrialConfigMetadataKey.VERSION_CHECK:
this.versionCheck = (value === 'true' || value === 'True');
break;
case TrialConfigMetadataKey.LOG_COLLECTION:
this.logCollection = value;
break;
case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True');
break;
default:
//Reject for unknown keys
throw new Error(`Uknown key: ${key}`);
}
}
protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
const deferred: Deferred<boolean> = new Deferred<boolean>();
const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`Failed to find PAITrialJobDetail for job ${trialJobId}`);
}
if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized');
}
if (this.paiTrialConfig === undefined) {
throw new Error('trial config is not initialized');
}
if (this.paiToken === undefined) {
throw new Error('PAI token is not initialized');
}
if (this.paiJobRestServer === undefined) {
throw new Error('paiJobRestServer is not initialized');
}
this.paiRestServerPort = this.paiJobRestServer.clusterRestServerPort;
// Make sure experiment code files is copied from local to HDFS
if (this.copyExpCodeDirPromise !== undefined) {
await this.copyExpCodeDirPromise;
}
//Make sure authFile is copied from local to HDFS
if (this.paiTrialConfig.authFile) {
await this.copyAuthFilePromise;
}
// Step 1. Prepare PAI job configuration
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
//create tmp trial working folder locally.
await execMkdir(trialLocalTempFolder);
const runScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
// Write NNI installation file to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
// Write file content ( parameter.cfg ) to local tmp folders
if (trialJobDetail.form !== undefined) {
await fs.promises.writeFile(
path.join(trialLocalTempFolder, generateParamFileName(trialJobDetail.form.hyperParameters)),
trialJobDetail.form.hyperParameters.value, { encoding: 'utf8' }
);
}
const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
const hdfsOutputDir: string = unixPathJoin(hdfsCodeDir, 'nnioutput');
const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
const version: string = this.versionCheck ? await getVersion() : '';
const nniPaiTrialCommand: string = String.Format(
PAI_TRIAL_COMMAND_FORMAT,
// PAI will copy job's codeDir into /root directory
`$PWD/${trialJobId}`,
`$PWD/${trialJobId}/nnioutput`,
trialJobId,
this.experimentId,
trialJobDetail.form.sequenceId,
this.isMultiPhase,
this.paiTrialConfig.command,
nniManagerIp,
this.paiRestServerPort,
hdfsOutputDir,
this.paiClusterConfig.host,
this.paiClusterConfig.userName,
HDFSClientUtility.getHdfsExpCodeDir(this.paiClusterConfig.userName),
version,
this.logCollection
)
.replace(/\r\n|\n|\r/gm, '');
this.log.info(`nniPAItrial command is ${nniPaiTrialCommand.trim()}`);
const paiTaskRoles: PAITaskRole[] = [
new PAITaskRole(
`nni_trail_${trialJobId}`,
// Task role number
1,
// Task CPU number
this.paiTrialConfig.cpuNum,
// Task memory
this.paiTrialConfig.memoryMB,
// Task GPU number
this.paiTrialConfig.gpuNum,
// Task command
nniPaiTrialCommand,
// Task shared memory
this.paiTrialConfig.shmMB,
// Task portList
this.paiTrialConfig.portList
)
];
const paiJobConfig: PAIJobConfig = new PAIJobConfig(
// Job name
trialJobDetail.paiJobName,
// Docker image
this.paiTrialConfig.image,
// codeDir
`$PAI_DEFAULT_FS_URI${hdfsCodeDir}`,
// PAI Task roles
paiTaskRoles,
// Add Virutal Cluster
this.paiTrialConfig.virtualCluster === undefined ? 'default' : this.paiTrialConfig.virtualCluster.toString(),
//Task auth File
this.authFileHdfsPath
);
// Step 2. Upload code files in codeDir onto HDFS
try {
await HDFSClientUtility.copyDirectoryToHdfs(trialLocalTempFolder, hdfsCodeDir, this.hdfsClient);
} catch (error) {
this.log.error(`PAI Training service: copy ${this.paiTrialConfig.codeDir} to HDFS ${hdfsCodeDir} failed, error is ${error}`);
trialJobDetail.status = 'FAILED'; // eslint-disable-line require-atomic-updates
return true;
}
// Step 3. Submit PAI job via Rest call
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const submitJobRequest: request.Options = {
uri: `http://${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}/jobs`,
method: 'POST',
json: true,
body: paiJobConfig,
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${this.paiToken}`
}
};
request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${response.body.message}`;
trialJobDetail.status = 'FAILED';
deferred.resolve(true);
} else {
trialJobDetail.submitTime = Date.now();
deferred.resolve(true);
}
});
return deferred.promise;
}
public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
await this.writeParameterFile(trialJobId, form.hyperParameters);
return trialJobDetail;
}
/**
 * Store a trial's hyper-parameters in a local file, copy that file to the trial's
 * HDFS working directory, and post the HDFS location to the PAI job rest server.
 *
 * @param trialJobId id of the trial job the parameters belong to
 * @param hyperParameters hyper-parameters whose `value` is written to the file
 * @throws Error when the PAI cluster config or the PAI trial config has not been set
 */
protected async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> {
    if (this.paiClusterConfig === undefined) {
        throw new Error('PAI Cluster config is not initialized');
    }
    if (this.paiTrialConfig === undefined) {
        throw new Error('PAI trial config is not initialized');
    }

    // Step 1: write the parameters into the trial's local temp folder.
    const localTrialDir: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
    const paramFileName: string = generateParamFileName(hyperParameters);
    const localParamFile: string = path.join(localTrialDir, paramFileName);
    await fs.promises.writeFile(localParamFile, hyperParameters.value, { encoding: 'utf8' });

    // Step 2: copy the file into the trial's working directory on HDFS.
    const hdfsParamFile: string = path.join(
        HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId),
        paramFileName
    );
    await HDFSClientUtility.copyFileToHdfs(localParamFile, hdfsParamFile, this.hdfsClient);

    // Step 3: post the file's HDFS location to the job rest server.
    const fileMeta: ParameterFileMeta = {
        experimentId: this.experimentId,
        trialId: trialJobId,
        filePath: hdfsParamFile
    };
    await this.postParameterFileMeta(fileMeta);
}
/**
 * POST the metadata of a parameter file to the PAI job rest server.
 *
 * The returned promise is rejected when the request itself fails or when the server
 * answers with an HTTP error status (>= 400), the same failure criterion used for
 * job submission in this class. An error response must not be reported as success.
 *
 * @param parameterFileMeta experiment id, trial id and file path of the parameter file
 * @returns a promise resolved once the rest server has accepted the metadata
 * @throws Error synchronously when `paiJobRestServer` has not been created
 */
protected postParameterFileMeta(parameterFileMeta: ParameterFileMeta): Promise<void> {
    const deferred: Deferred<void> = new Deferred<void>();
    if (this.paiJobRestServer === undefined) {
        throw new Error('paiJobRestServer not implemented!');
    }
    const req: request.Options = {
        uri: `${this.paiJobRestServer.endPoint}${this.paiJobRestServer.apiRootUrl}/parameter-file-meta`,
        method: 'POST',
        json: true,
        body: parameterFileMeta
    };
    request(req, (err: Error, res: request.Response) => {
        if (err !== undefined && err !== null) {
            // Transport-level failure (connection refused, timeout, ...).
            deferred.reject(err);
        } else if (res.statusCode >= 400) {
            // `request` only sets `err` for transport failures; HTTP error statuses arrive here.
            deferred.reject(new Error(
                `Post parameter file meta for trial ${parameterFileMeta.trialId} failed, http code:${res.statusCode}`
            ));
        } else {
            deferred.resolve();
        }
    });

    return deferred.promise;
}
}
export { PAIYarnTrainingService };
......@@ -3,12 +3,12 @@
'use strict';
import {TrialConfig} from '../common/trialConfig';
import {TrialConfig} from '../../common/trialConfig';
/**
* PAI configuration to run trials
*/
export class PAITrialConfig extends TrialConfig {
export class PAIYarnTrialConfig extends TrialConfig {
public readonly cpuNum: number;
public readonly memoryMB: number;
public readonly image: string;
......
......@@ -9,7 +9,7 @@ import * as os from 'os';
import * as path from 'path';
import * as tmp from 'tmp';
import { cleanupUnitTest, prepareUnitTest, uniqueString } from '../../common/utils';
import { HDFSClientUtility } from '../pai/hdfsClientUtility';
import { HDFSClientUtility } from '../pai/paiYarn/hdfsClientUtility';
var WebHDFS = require('webhdfs');
var rmdir = require('rmdir');
......
......@@ -11,14 +11,14 @@ import * as component from '../../common/component';
import { TrialJobApplicationForm } from '../../common/trainingService';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { PAITrainingService } from '../pai/paiTrainingService';
import { PAIYarnTrainingService } from '../pai/paiYarn/paiYarnTrainingService';
// TODO: copy mockedTrial.py to local folder
const localCodeDir: string = tmp.dirSync().name
const mockedTrialPath: string = './training_service/test/mockedTrial.py'
fs.copyFileSync(mockedTrialPath, localCodeDir + '/mockedTrial.py')
describe('Unit Test for PAITrainingService', () => {
describe('Unit Test for PAIYarnTrainingService', () => {
let skip: boolean = false;
let testPaiClusterInfo: any;
let paiCluster: any;
......@@ -33,7 +33,7 @@ describe('Unit Test for PAITrainingService', () => {
skip = true;
}
let paiTrainingService: PAITrainingService;
let paiYarnTrainingService: PAIYarnTrainingService;
console.log(tmp.dirSync().name);
......@@ -51,15 +51,15 @@ describe('Unit Test for PAITrainingService', () => {
if (skip) {
return;
}
paiTrainingService = component.get(PAITrainingService);
paiTrainingService.run();
paiYarnTrainingService = component.get(PAIYarnTrainingService);
paiYarnTrainingService.run();
});
afterEach(() => {
if (skip) {
return;
}
paiTrainingService.cleanUp();
paiYarnTrainingService.cleanUp();
});
it('Get PAI token', async () => {
......@@ -67,14 +67,14 @@ describe('Unit Test for PAITrainingService', () => {
return;
}
console.log(`paiCluster is ${paiCluster}`)
await paiTrainingService.setClusterMetadata(TrialConfigMetadataKey.PAI_CLUSTER_CONFIG, paiCluster);
await paiTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, paiTrialConfig);
await paiYarnTrainingService.setClusterMetadata(TrialConfigMetadataKey.PAI_YARN_CLUSTER_CONFIG, paiCluster);
await paiYarnTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, paiTrialConfig);
const form: TrialJobApplicationForm = {
sequenceId: 0,
hyperParameters: { value: '', index: 0 }
};
try {
const trialDetail = await paiTrainingService.submitTrialJob(form);
const trialDetail = await paiYarnTrainingService.submitTrialJob(form);
chai.expect(trialDetail.status).to.be.equals('WAITING');
} catch(error) {
console.log('Submit job failed:' + error);
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .trial import *
from .smartparam import *
from .nas_utils import training_update
from .env_vars import dispatcher_env_vars
if dispatcher_env_vars.SDK_PROCESS != 'dispatcher':
from .trial import *
from .smartparam import *
from .nas_utils import training_update
class NoMoreTrialError(Exception):
def __init__(self, ErrorInfo):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
'''
__main__.py
'''
import os
import sys
import argparse
import logging
import json
import importlib
import base64
from .common import enable_multi_thread, enable_multi_phase
from .constants import ModuleName, ClassName, ClassArgs, AdvisorModuleName, AdvisorClassName
......@@ -29,99 +27,67 @@ def augment_classargs(input_class_args, classname):
input_class_args[key] = value
return input_class_args
def create_builtin_class_instance(classname, jsonstr_args, is_advisor=False):
if is_advisor:
if classname not in AdvisorModuleName or \
importlib.util.find_spec(AdvisorModuleName[classname]) is None:
raise RuntimeError('Advisor module is not found: {}'.format(classname))
class_module = importlib.import_module(AdvisorModuleName[classname])
class_constructor = getattr(class_module, AdvisorClassName[classname])
else:
if classname not in ModuleName or \
importlib.util.find_spec(ModuleName[classname]) is None:
raise RuntimeError('Tuner module is not found: {}'.format(classname))
class_module = importlib.import_module(ModuleName[classname])
class_constructor = getattr(class_module, ClassName[classname])
if jsonstr_args:
class_args = augment_classargs(json.loads(jsonstr_args), classname)
else:
class_args = augment_classargs({}, classname)
if class_args:
instance = class_constructor(**class_args)
else:
instance = class_constructor()
def create_builtin_class_instance(class_name, class_args, builtin_module_dict, builtin_class_dict):
    """
    Instantiate a builtin class looked up by name.

    Parameters
    ----------
    class_name : str
        key used to look up both the module and the class
    class_args : dict or None
        keyword arguments for the constructor; None is treated as an empty dict
    builtin_module_dict : dict
        maps class_name to the module name to import
    builtin_class_dict : dict
        maps class_name to the attribute name of the class inside that module

    Returns
    -------
    object
        the constructed instance

    Raises
    ------
    RuntimeError
        if class_name is not registered or its module cannot be found
    """
    is_registered = class_name in builtin_module_dict
    if not (is_registered and importlib.util.find_spec(builtin_module_dict[class_name]) is not None):
        raise RuntimeError('Builtin module is not found: {}'.format(class_name))

    constructor = getattr(
        importlib.import_module(builtin_module_dict[class_name]),
        builtin_class_dict[class_name]
    )
    init_kwargs = augment_classargs({} if class_args is None else class_args, class_name)
    return constructor(**init_kwargs)
def create_customized_class_instance(class_dir, class_filename, classname, jsonstr_args):
if not os.path.isfile(os.path.join(class_dir, class_filename)):
def create_customized_class_instance(class_params):
code_dir = class_params.get('codeDir')
class_filename = class_params.get('classFileName')
class_name = class_params.get('className')
class_args = class_params.get('classArgs')
if not os.path.isfile(os.path.join(code_dir, class_filename)):
raise ValueError('Class file not found: {}'.format(
os.path.join(class_dir, class_filename)))
sys.path.append(class_dir)
os.path.join(code_dir, class_filename)))
sys.path.append(code_dir)
module_name = os.path.splitext(class_filename)[0]
class_module = importlib.import_module(module_name)
class_constructor = getattr(class_module, classname)
if jsonstr_args:
class_args = json.loads(jsonstr_args)
instance = class_constructor(**class_args)
else:
instance = class_constructor()
class_constructor = getattr(class_module, class_name)
if class_args is None:
class_args = {}
instance = class_constructor(**class_args)
return instance
def parse_args():
parser = argparse.ArgumentParser(description='parse command line parameters.')
parser.add_argument('--advisor_class_name', type=str, required=False,
help='Advisor class name, the class must be a subclass of nni.MsgDispatcherBase')
parser.add_argument('--advisor_class_filename', type=str, required=False,
help='Advisor class file path')
parser.add_argument('--advisor_args', type=str, required=False,
help='Parameters pass to advisor __init__ constructor')
parser.add_argument('--advisor_directory', type=str, required=False,
help='Advisor directory')
parser.add_argument('--tuner_class_name', type=str, required=False,
help='Tuner class name, the class must be a subclass of nni.Tuner')
parser.add_argument('--tuner_class_filename', type=str, required=False,
help='Tuner class file path')
parser.add_argument('--tuner_args', type=str, required=False,
help='Parameters pass to tuner __init__ constructor')
parser.add_argument('--tuner_directory', type=str, required=False,
help='Tuner directory')
parser.add_argument('--assessor_class_name', type=str, required=False,
help='Assessor class name, the class must be a subclass of nni.Assessor')
parser.add_argument('--assessor_args', type=str, required=False,
help='Parameters pass to assessor __init__ constructor')
parser.add_argument('--assessor_directory', type=str, required=False,
help='Assessor directory')
parser.add_argument('--assessor_class_filename', type=str, required=False,
help='Assessor class file path')
parser.add_argument('--multi_phase', action='store_true')
parser.add_argument('--multi_thread', action='store_true')
flags, _ = parser.parse_known_args()
return flags
def main():
'''
main function.
'''
parser = argparse.ArgumentParser(description='Dispatcher command line parser')
parser.add_argument('--exp_params', type=str, required=True)
args, _ = parser.parse_known_args()
exp_params_decode = base64.b64decode(args.exp_params).decode('utf-8')
logger.debug('decoded exp_params: [%s]', exp_params_decode)
exp_params = json.loads(exp_params_decode)
logger.debug('exp_params json obj: [%s]', json.dumps(exp_params, indent=4))
args = parse_args()
if args.multi_thread:
if exp_params.get('multiThread'):
enable_multi_thread()
if args.multi_phase:
if exp_params.get('multiPhase'):
enable_multi_phase()
if args.advisor_class_name:
if exp_params.get('advisor') is not None:
# advisor is enabled and starts to run
_run_advisor(args)
_run_advisor(exp_params)
else:
# tuner (and assessor) is enabled and starts to run
tuner = _create_tuner(args)
if args.assessor_class_name:
assessor = _create_assessor(args)
assert exp_params.get('tuner') is not None
tuner = _create_tuner(exp_params)
if exp_params.get('assessor') is not None:
assessor = _create_assessor(exp_params)
else:
assessor = None
dispatcher = MsgDispatcher(tuner, assessor)
......@@ -139,17 +105,14 @@ def main():
raise
def _run_advisor(args):
if args.advisor_class_name in AdvisorModuleName:
def _run_advisor(exp_params):
if exp_params.get('advisor').get('builtinAdvisorName') in AdvisorModuleName:
dispatcher = create_builtin_class_instance(
args.advisor_class_name,
args.advisor_args, True)
exp_params.get('advisor').get('builtinAdvisorName'),
exp_params.get('advisor').get('classArgs'),
AdvisorModuleName, AdvisorClassName)
else:
dispatcher = create_customized_class_instance(
args.advisor_directory,
args.advisor_class_filename,
args.advisor_class_name,
args.advisor_args)
dispatcher = create_customized_class_instance(exp_params.get('advisor'))
if dispatcher is None:
raise AssertionError('Failed to create Advisor instance')
try:
......@@ -159,33 +122,27 @@ def _run_advisor(args):
raise
def _create_tuner(args):
if args.tuner_class_name in ModuleName:
def _create_tuner(exp_params):
if exp_params.get('tuner').get('builtinTunerName') in ModuleName:
tuner = create_builtin_class_instance(
args.tuner_class_name,
args.tuner_args)
exp_params.get('tuner').get('builtinTunerName'),
exp_params.get('tuner').get('classArgs'),
ModuleName, ClassName)
else:
tuner = create_customized_class_instance(
args.tuner_directory,
args.tuner_class_filename,
args.tuner_class_name,
args.tuner_args)
tuner = create_customized_class_instance(exp_params.get('tuner'))
if tuner is None:
raise AssertionError('Failed to create Tuner instance')
return tuner
def _create_assessor(args):
if args.assessor_class_name in ModuleName:
def _create_assessor(exp_params):
if exp_params.get('assessor').get('builtinAssessorName') in ModuleName:
assessor = create_builtin_class_instance(
args.assessor_class_name,
args.assessor_args)
exp_params.get('assessor').get('builtinAssessorName'),
exp_params.get('assessor').get('classArgs'),
ModuleName, ClassName)
else:
assessor = create_customized_class_instance(
args.assessor_directory,
args.assessor_class_filename,
args.assessor_class_name,
args.assessor_args)
assessor = create_customized_class_instance(exp_params.get('assessor'))
if assessor is None:
raise AssertionError('Failed to create Assessor instance')
return assessor
......
......@@ -100,7 +100,7 @@ def get_bits_length(config, quant_type):
class QAT_Quantizer(Quantizer):
"""Quantizer using the DoReFa scheme, as defined in:
"""Quantizer defined in:
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""
......@@ -227,20 +227,17 @@ class DoReFaQuantizer(Quantizer):
(https://arxiv.org/abs/1606.06160)
"""
def __init__(self, model, config_list):
"""
config_list: supported keys:
- q_bits
"""
super().__init__(model, config_list)
def quantize_weight(self, weight, config, **kwargs):
weight_bits = get_bits_length(config, 'weight')
out = weight.tanh()
out = out / (2 * out.abs().max()) + 0.5
out = self.quantize(out, config['q_bits'])
out = self.quantize(out, weight_bits)
out = 2 * out -1
return out
def quantize(self, input_ri, q_bits):
    """
    Round `input_ri` to the nearest multiple of 1 / (2**q_bits - 1).

    Parameters
    ----------
    input_ri : Tensor
        tensor to quantize
    q_bits : int
        bit width; the rounding step is 1 / (2**q_bits - 1)

    Returns
    -------
    Tensor
        `input_ri` scaled by (2**q_bits - 1), rounded, and scaled back
    """
    levels = 2 ** q_bits - 1
    return torch.round(input_ri * levels) / levels
return output
\ No newline at end of file
......@@ -250,6 +250,10 @@ class Quantizer(Compressor):
Base quantizer for pytorch quantizer
"""
def __init__(self, model, config_list):
super().__init__(model, config_list)
self.quant_grad = QuantGrad
def quantize_weight(self, weight, config, op, op_type, op_name):
"""
quantize should overload this method to quantize weight.
......@@ -262,7 +266,7 @@ class Quantizer(Compressor):
config : dict
the configuration for weight quantization
"""
raise NotImplementedError("Quantizer must overload quantize_weight()")
raise NotImplementedError('Quantizer must overload quantize_weight()')
def quantize_output(self, output, config, op, op_type, op_name):
"""
......@@ -276,7 +280,7 @@ class Quantizer(Compressor):
config : dict
the configuration for output quantization
"""
raise NotImplementedError("Quantizer must overload quantize_output()")
raise NotImplementedError('Quantizer must overload quantize_output()')
def quantize_input(self, *inputs, config, op, op_type, op_name):
"""
......@@ -290,7 +294,7 @@ class Quantizer(Compressor):
config : dict
the configuration for inputs quantization
"""
raise NotImplementedError("Quantizer must overload quantize_input()")
raise NotImplementedError('Quantizer must overload quantize_input()')
def _instrument_layer(self, layer, config):
......@@ -305,62 +309,93 @@ class Quantizer(Compressor):
the configuration for quantization
"""
assert layer._forward is None, 'Each model can only be compressed once'
assert "quant_types" in config, 'must provide quant_types in config'
assert isinstance(config["quant_types"], list), 'quant_types must be list type'
assert "quant_bits" in config, 'must provide quant_bits in config'
assert isinstance(config["quant_bits"], int) or isinstance(config["quant_bits"], dict), 'quant_bits must be dict type or int type'
assert 'quant_types' in config, 'must provide quant_types in config'
assert isinstance(config['quant_types'], list), 'quant_types must be list type'
assert 'quant_bits' in config, 'must provide quant_bits in config'
assert isinstance(config['quant_bits'], int) or isinstance(config['quant_bits'], dict), 'quant_bits must be dict type or int type'
if isinstance(config["quant_bits"], dict):
for quant_type in config["quant_types"]:
assert quant_type in config["quant_bits"], 'bits length for %s must be specified in quant_bits dict' % quant_type
if isinstance(config['quant_bits'], dict):
for quant_type in config['quant_types']:
assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type
if 'weight' in config["quant_types"]:
if 'weight' in config['quant_types']:
if not _check_weight(layer.module):
_logger.warning('Module %s does not have parameter "weight"', layer.name)
else:
# old_weight stores the original weight, while weight stores the quantized weight.
# weight is registered as a buffer instead of a parameter because, in PyTorch, a parameter is a leaf tensor;
# if weight were a leaf, old_weight could not be updated.
layer.module.register_parameter('old_weight', torch.nn.Parameter(layer.module.weight))
delattr(layer.module, 'weight')
layer.module.register_buffer('weight', layer.module.old_weight)
layer._forward = layer.module.forward
def new_forward(*inputs):
if 'input' in config["quant_types"]:
inputs = straight_through_quantize_input.apply(inputs, self, config, layer)
if 'input' in config['quant_types']:
inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer)
if 'weight' in config["quant_types"] and _check_weight(layer.module):
weight = layer.module.weight.data
new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
layer.module.weight.data = new_weight
if 'weight' in config['quant_types'] and _check_weight(layer.module):
new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer)
layer.module.weight = new_weight
result = layer._forward(*inputs)
layer.module.weight.data = weight
else:
result = layer._forward(*inputs)
if 'output' in config["quant_types"]:
result = straight_through_quantize_output.apply(result, self, config, layer)
if 'output' in config['quant_types']:
result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer)
return result
layer.module.forward = new_forward
class QuantType:
    """
    Integer constants naming what a quantization call operates on:
    a layer's input (0), its weight (1) or its output (2).
    """
    QUANT_INPUT, QUANT_WEIGHT, QUANT_OUTPUT = range(3)
class straight_through_quantize_output(torch.autograd.Function):
class QuantGrad(torch.autograd.Function):
"""
Base class for overriding backward function of quantization operation.
"""
@staticmethod
def forward(ctx, output, quantizer, config, layer):
return quantizer.quantize_output(output, config, op=layer.module, op_type=layer.type, op_name=layer.name)
def quant_backward(tensor, grad_output, quant_type):
"""
This method should be overridden by subclass to provide customized backward function,
default implementation is Straight-Through Estimator
@staticmethod
def backward(ctx, grad_output):
# Straight-through estimator
return grad_output, None, None, None
Parameters
----------
tensor : Tensor
input of quantization operation
grad_output : Tensor
gradient of the output of quantization operation
quant_type : QuantType
the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
you can define different behavior for different types.
class straight_through_quantize_input(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, quantizer, config, layer):
return quantizer.quantize_input(inputs, config, op=layer.module, op_type=layer.type, op_name=layer.name)
Returns
-------
tensor
gradient of the input of quantization operation
"""
return grad_output
@staticmethod
def backward(ctx, grad_output):
# Straight-through estimator
return grad_output, None, None, None
def forward(ctx, tensor, quant_type, quant_func, config, layer):
ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name)
@classmethod
def backward(cls, ctx, grad_output):
tensor, quant_type = ctx.saved_variables
output = cls.quant_backward(tensor, grad_output, quant_type)
return output, None, None, None, None
def _check_weight(module):
try:
return isinstance(module.weight, torch.nn.Parameter) and isinstance(module.weight.data, torch.Tensor)
return isinstance(module.weight.data, torch.Tensor)
except AttributeError:
return False
......@@ -16,6 +16,7 @@ _trial_env_var_names = [
]
_dispatcher_env_var_names = [
'SDK_PROCESS',
'NNI_MODE',
'NNI_CHECKPOINT_DIRECTORY',
'NNI_LOG_DIRECTORY',
......
......@@ -11,7 +11,6 @@ import logging
import hyperopt as hp
import numpy as np
from nni.tuner import Tuner
from nni.nas_utils import rewrite_nas_space
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index
logger = logging.getLogger('hyperopt_AutoML')
......@@ -226,7 +225,6 @@ class HyperoptTuner(Tuner):
return hp.anneal.suggest
raise RuntimeError('Not support tuner algorithm in hyperopt.')
@rewrite_nas_space
def update_search_space(self, search_space):
"""
Update search space definition in tuner by search_space in parameters.
......
......@@ -4,6 +4,9 @@
import logging
import os
import torch
import torch.nn as nn
_logger = logging.getLogger(__name__)
......@@ -44,11 +47,28 @@ class LRSchedulerCallback(Callback):
class ArchitectureCheckpoint(Callback):
def __init__(self, checkpoint_dir, every="epoch"):
def __init__(self, checkpoint_dir):
super().__init__()
self.checkpoint_dir = checkpoint_dir
os.makedirs(self.checkpoint_dir, exist_ok=True)
def on_epoch_end(self, epoch):
dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))
_logger.info("Saving architecture to %s", dest_path)
self.trainer.export(dest_path)
class ModelCheckpoint(Callback):
def __init__(self, checkpoint_dir):
super().__init__()
assert every == "epoch"
self.checkpoint_dir = checkpoint_dir
os.makedirs(self.checkpoint_dir, exist_ok=True)
def on_epoch_end(self, epoch):
self.trainer.export(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch)))
if isinstance(self.model, nn.DataParallel):
state_dict = self.model.module.state_dict()
else:
state_dict = self.model.state_dict()
dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.pth.tar".format(epoch))
_logger.info("Saving model to %s", dest_path)
torch.save(state_dict, dest_path)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import get_and_apply_next_architecture
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment