Unverified commit f9ee589c authored by SparkSnail, committed by GitHub

Merge pull request #222 from microsoft/master

merge master
parents 36e6e350 4f3ee9cb
@@ -19,19 +19,17 @@ export interface ParameterFileMeta {
  * PAI Training service Rest server, provides rest API to support pai job metrics update
  */
-@component.Singleton
 export class PAIJobRestServer extends ClusterJobRestServer {
-    private parameterFileMetaList: ParameterFileMeta[] = [];
-    @Inject
-    private readonly paiTrainingService: PAITrainingService;
+    protected parameterFileMetaList: ParameterFileMeta[] = [];
+    protected readonly paiTrainingService: PAITrainingService;
     /**
      * constructor to provide NNIRestServer's own rest property, e.g. port
      */
-    constructor() {
+    constructor(paiTrainingService: PAITrainingService) {
         super();
-        this.paiTrainingService = component.get(PAITrainingService);
+        this.paiTrainingService = paiTrainingService;
     }
     protected handleTrialMetrics(jobId: string, metrics: any[]): void {
...
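The hunk above replaces the @component.Singleton / @Inject IoC lookup with plain constructor injection, so each concrete training service hands the rest server a typed reference to itself. A minimal wiring sketch under that assumption (the helper function is hypothetical, not part of this PR):

// Each concrete training service now owns its rest server and passes
// itself in, so the server holds a typed back-reference.
function createRestServer(trainingService: PAITrainingService): PAIJobRestServer {
    return new PAIJobRestServer(trainingService);
}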
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import {TrialConfig} from '../../common/trialConfig';
/**
* PAI trial configuration
*/
export class NNIPAIK8STrialConfig extends TrialConfig {
public readonly cpuNum: number;
public readonly memoryMB: number;
public readonly image: string;
public virtualCluster?: string;
public readonly nniManagerNFSMountPath: string;
public readonly containerNFSMountPath: string;
public readonly paiStoragePlugin: string;
constructor(command: string, codeDir: string, gpuNum: number, cpuNum: number, memoryMB: number,
image: string, nniManagerNFSMountPath: string, containerNFSMountPath: string,
paiStoragePlugin: string, virtualCluster?: string) {
super(command, codeDir, gpuNum);
this.cpuNum = cpuNum;
this.memoryMB = memoryMB;
this.image = image;
this.virtualCluster = virtualCluster;
this.nniManagerNFSMountPath = nniManagerNFSMountPath;
this.containerNFSMountPath = containerNFSMountPath;
this.paiStoragePlugin = paiStoragePlugin;
}
}
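For reference, a hedged sketch of constructing this config from a parsed trial section; every literal value below is illustrative, not taken from this PR:

// Hypothetical values for a single-GPU trial on OpenPAI.
const trialConfig: NNIPAIK8STrialConfig = new NNIPAIK8STrialConfig(
    'python3 mnist.py',          // command
    '/home/user/nni/examples',   // codeDir
    1,                           // gpuNum
    4,                           // cpuNum
    8192,                        // memoryMB
    'msranni/nni:latest',        // image
    '/mnt/nfs/nni',              // nniManagerNFSMountPath
    '/mnt/nfs/nni',              // containerNFSMountPath
    'teamwise_storage'           // paiStoragePlugin (plugin name assumed)
);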
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
export const PAI_INSTALL_NNI_SHELL_FORMAT: string =
`#!/bin/bash
if python3 -c 'import nni' > /dev/null 2>&1; then
# nni module is already installed, skip
exit 0
else
# Install nni
python3 -m pip install --user nni
fi`;
export const PAI_K8S_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} \
&& ls $NNI_SYS_DIR \
&& cd $NNI_SYS_DIR && sh install_nni.sh \
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{6}' --nnimanager_ip '{7}' --nnimanager_port '{8}' \
--nni_manager_version '{9}' --log_collection '{10}'`;
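The {0}-{10} placeholders are filled positionally. A sketch of the call paiK8STrainingService.ts makes via typescript-string-operations (all values illustrative):

import { String } from 'typescript-string-operations';

const trialCommand: string = String.Format(
    PAI_K8S_TRIAL_COMMAND_FORMAT,
    '/mnt/nfs/expid/trialid',            // {0} NNI_SYS_DIR
    '/mnt/nfs/expid/trialid/nnioutput',  // {1} NNI_OUTPUT_DIR
    'trialid',                           // {2} NNI_TRIAL_JOB_ID
    'expid',                             // {3} NNI_EXP_ID
    0,                                   // {4} NNI_TRIAL_SEQ_ID
    false,                               // {5} MULTI_PHASE
    'python3 mnist.py',                  // {6} --trial_command
    '10.0.0.1',                          // {7} --nnimanager_ip
    51189,                               // {8} --nnimanager_port
    '1.2',                               // {9} --nni_manager_version
    'none'                               // {10} --log_collection
).replace(/\r\n|\n|\r/gm, '');           // strip newlines, as the service does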
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import * as cpp from 'child-process-promise';
import * as fs from 'fs';
import * as path from 'path';
// tslint:disable-next-line:no-implicit-dependencies
import * as request from 'request';
import * as component from '../../../common/component';
import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
import {
HyperParameters, NNIManagerIpConfig, TrainingService,
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../../common/trainingService';
import { delay, generateParamFileName,
getExperimentRootDir, getIPV4Address, getVersion, uniqueString, unixPathJoin } from '../../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../../common/containerJobData';
import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
import { execMkdir, validateCodeDir, execCopydir } from '../../common/util';
import { PAI_K8S_TRIAL_COMMAND_FORMAT } from './paiK8SData';
import { NNIPAIK8STrialConfig } from './paiK8SConfig';
import { PAITrainingService } from '../paiTrainingService';
import { PAIClusterConfig, PAITrialJobDetail } from '../paiConfig';
import { PAIJobRestServer } from '../paiJobRestServer';
const yaml = require('js-yaml');
/**
* Training Service implementation for OpenPAI (Open Platform for AI)
* Refer https://github.com/Microsoft/pai for more info about OpenPAI
*/
@component.Singleton
class PAIK8STrainingService extends PAITrainingService {
protected paiTrialConfig: NNIPAIK8STrialConfig | undefined;
constructor() {
super();
}
public async setClusterMetadata(key: string, value: string): Promise<void> {
switch (key) {
case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG:
this.paiJobRestServer = new PAIJobRestServer(component.get(PAIK8STrainingService));
this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);
if(this.paiClusterConfig.passWord) {
// Get PAI authentication token
await this.updatePaiToken();
} else if(this.paiClusterConfig.token) {
this.paiToken = this.paiClusterConfig.token;
}
break;
case TrialConfigMetadataKey.TRIAL_CONFIG:
if (this.paiClusterConfig === undefined) {
this.log.error('pai cluster config is not initialized');
break;
}
this.paiTrialConfig = <NNIPAIK8STrialConfig>JSON.parse(value);
// Validate to make sure codeDir doesn't have too many files
await validateCodeDir(this.paiTrialConfig.codeDir);
break;
case TrialConfigMetadataKey.VERSION_CHECK:
this.versionCheck = (value === 'true' || value === 'True');
break;
case TrialConfigMetadataKey.LOG_COLLECTION:
this.logCollection = value;
break;
case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True');
break;
default:
//Reject for unknown keys
this.log.error(`Unknown key: ${key}`);
}
}
//TODO: update trial parameters
public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
return trialJobDetail;
}
public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.paiClusterConfig === undefined) {
throw new Error(`paiClusterConfig not initialized!`);
}
if (this.paiTrialConfig === undefined) {
throw new Error(`paiTrialConfig not initialized!`);
}
this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);
const trialJobId: string = uniqueString(5);
//TODO: use HDFS working folder instead
const trialWorkingFolder: string = path.join(this.expRootDir, 'trials', trialJobId);
const paiJobName: string = `nni_exp_${this.experimentId}_trial_${trialJobId}`;
const logPath: string = path.join(this.paiTrialConfig.nniManagerNFSMountPath, this.experimentId, trialJobId);
const trialJobDetail: PAITrialJobDetail = new PAITrialJobDetail(
trialJobId,
'WAITING',
paiJobName,
Date.now(),
trialWorkingFolder,
form,
logPath);
this.trialJobsMap.set(trialJobId, trialJobDetail);
this.jobQueue.push(trialJobId);
return trialJobDetail;
}
public generateJobConfigInYamlFormat(trialJobId: string, command: string): string {
if (this.paiTrialConfig === undefined) {
throw new Error('trial config is not initialized');
}
const jobName: string = `nni_exp_${this.experimentId}_trial_${trialJobId}`;
const paiJobConfig: any = {
protocolVersion: 2,
name: jobName,
type: 'job',
jobRetryCount: 0,
prerequisites: [
{
type: 'dockerimage',
uri: this.paiTrialConfig.image,
name: 'docker_image_0'
}
],
taskRoles: {
taskrole: {
instances: 1,
completion: {
minFailedInstances: 1,
minSucceededInstances: -1
},
taskRetryCount: 0,
dockerImage: 'docker_image_0',
resourcePerInstance: {
gpu: this.paiTrialConfig.gpuNum,
cpu: this.paiTrialConfig.cpuNum,
memoryMB: this.paiTrialConfig.memoryMB
},
commands: [
command
]
}
},
extras: {
'com.microsoft.pai.runtimeplugin': [
{
plugin: this.paiTrialConfig.paiStoragePlugin
}
],
submitFrom: 'submit-job-v2'
}
}
if (this.paiTrialConfig.virtualCluster) {
paiJobConfig.defaults = {
virtualCluster: this.paiTrialConfig.virtualCluster
}
}
return yaml.safeDump(paiJobConfig);
}
protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
const deferred: Deferred<boolean> = new Deferred<boolean>();
const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`Failed to find PAITrialJobDetail for job ${trialJobId}`);
}
if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized');
}
if (this.paiTrialConfig === undefined) {
throw new Error('trial config is not initialized');
}
if (this.paiToken === undefined) {
throw new Error('PAI token is not initialized');
}
if (this.paiJobRestServer === undefined) {
throw new Error('paiJobRestServer is not initialized');
}
this.paiRestServerPort = this.paiJobRestServer.clusterRestServerPort;
// Step 1. Prepare PAI job configuration
const trialLocalFolder: string = path.join(this.paiTrialConfig.nniManagerNFSMountPath, this.experimentId, trialJobId);
//create trial local working folder locally.
await execMkdir(trialLocalFolder);
const runScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
// Write NNI installation file to local files
await fs.promises.writeFile(path.join(trialLocalFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
// Write file content ( parameter.cfg ) to local working folders
if (trialJobDetail.form !== undefined) {
await fs.promises.writeFile(
path.join(trialLocalFolder, generateParamFileName(trialJobDetail.form.hyperParameters)),
trialJobDetail.form.hyperParameters.value, { encoding: 'utf8' }
);
}
//Copy codeDir files to local working folder
await execCopydir(this.paiTrialConfig.codeDir, trialLocalFolder);
const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
const version: string = this.versionCheck ? await getVersion() : '';
const containerWorkingDir: string = `${this.paiTrialConfig.containerNFSMountPath}/${this.experimentId}/${trialJobId}`;
const nniPaiTrialCommand: string = String.Format(
PAI_K8S_TRIAL_COMMAND_FORMAT,
`${containerWorkingDir}`,
`${containerWorkingDir}/nnioutput`,
trialJobId,
this.experimentId,
trialJobDetail.form.sequenceId,
this.isMultiPhase,
this.paiTrialConfig.command,
nniManagerIp,
this.paiRestServerPort,
version,
this.logCollection
)
.replace(/\r\n|\n|\r/gm, '');
this.log.info(`nniPAItrial command is ${nniPaiTrialCommand.trim()}`);
const paiJobConfig = this.generateJobConfigInYamlFormat(trialJobId, nniPaiTrialCommand);
// Step 3. Submit PAI job via Rest call
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const submitJobRequest: request.Options = {
uri: `http://${this.paiClusterConfig.host}/rest-server/api/v2/jobs`,
method: 'POST',
body: paiJobConfig,
headers: {
'Content-Type': 'text/yaml',
Authorization: `Bearer ${this.paiToken}`
}
};
request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${body}`;
this.log.error(errorMessage);
trialJobDetail.status = 'FAILED';
} else {
trialJobDetail.submitTime = Date.now();
}
deferred.resolve(true);
});
return deferred.promise;
}
}
export { PAIK8STrainingService };
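For orientation, a sketch of the OpenPAI protocol v2 YAML that generateJobConfigInYamlFormat emits for one trial; resource numbers, image, and plugin name are illustrative:

const exampleJobConfigYaml: string =
`protocolVersion: 2
name: nni_exp_expid_trial_trialid
type: job
jobRetryCount: 0
prerequisites:
  - type: dockerimage
    uri: msranni/nni:latest
    name: docker_image_0
taskRoles:
  taskrole:
    instances: 1
    completion:
      minFailedInstances: 1
      minSucceededInstances: -1
    taskRetryCount: 0
    dockerImage: docker_image_0
    resourcePerInstance:
      gpu: 1
      cpu: 4
      memoryMB: 8192
    commands:
      - <trial keeper command>
extras:
  com.microsoft.pai.runtimeplugin:
    - plugin: teamwise_storage
  submitFrom: submit-job-v2`;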
@@ -22,70 +22,96 @@ import { delay, generateParamFileName,
 import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
 import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
 import { execMkdir, validateCodeDir } from '../common/util';
-import { HDFSClientUtility } from './hdfsClientUtility';
-import { NNIPAITrialConfig, PAIClusterConfig, PAIJobConfig, PAITaskRole } from './paiConfig';
-import { PAI_LOG_PATH_FORMAT, PAI_TRIAL_COMMAND_FORMAT, PAITrialJobDetail } from './paiData';
 import { PAIJobInfoCollector } from './paiJobInfoCollector';
 import { PAIJobRestServer, ParameterFileMeta } from './paiJobRestServer';
-import * as WebHDFS from 'webhdfs';
+import { PAIClusterConfig, PAITrialJobDetail } from './paiConfig';
 /**
  * Training Service implementation for OpenPAI (Open Platform for AI)
  * Refer https://github.com/Microsoft/pai for more info about OpenPAI
  */
 @component.Singleton
-class PAITrainingService implements TrainingService {
-    private readonly log!: Logger;
-    private readonly metricsEmitter: EventEmitter;
-    private readonly trialJobsMap: Map<string, PAITrialJobDetail>;
-    private readonly expRootDir: string;
-    private paiTrialConfig: NNIPAITrialConfig | undefined;
-    private paiClusterConfig?: PAIClusterConfig;
-    private readonly jobQueue: string[];
-    private stopping: boolean = false;
-    private hdfsClient: any;
-    private paiToken? : string;
-    private paiTokenUpdateTime?: number;
-    private readonly paiTokenUpdateInterval: number;
-    private readonly experimentId!: string;
-    private readonly paiJobCollector: PAIJobInfoCollector;
-    private paiRestServerPort?: number;
-    private nniManagerIpConfig?: NNIManagerIpConfig;
-    private copyExpCodeDirPromise?: Promise<void>;
-    private copyAuthFilePromise?: Promise<void>;
-    private versionCheck: boolean = true;
-    private logCollection: string;
-    private isMultiPhase: boolean = false;
-    private authFileHdfsPath: string | undefined = undefined;
-    private portList?: string | undefined;
+abstract class PAITrainingService implements TrainingService {
+    protected readonly log!: Logger;
+    protected readonly metricsEmitter: EventEmitter;
+    protected readonly trialJobsMap: Map<string, PAITrialJobDetail>;
+    protected readonly expRootDir: string;
+    protected paiClusterConfig?: PAIClusterConfig;
+    protected readonly jobQueue: string[];
+    protected stopping: boolean = false;
+    protected paiToken? : string;
+    protected paiTokenUpdateTime?: number;
+    protected readonly paiTokenUpdateInterval: number;
+    protected readonly experimentId!: string;
+    protected readonly paiJobCollector: PAIJobInfoCollector;
+    protected paiRestServerPort?: number;
+    protected nniManagerIpConfig?: NNIManagerIpConfig;
+    protected versionCheck: boolean = true;
+    protected logCollection: string;
+    protected isMultiPhase: boolean = false;
+    protected authFileHdfsPath: string | undefined = undefined;
+    protected portList?: string | undefined;
+    protected paiJobRestServer?: PAIJobRestServer;
     constructor() {
         this.log = getLogger();
         this.metricsEmitter = new EventEmitter();
         this.trialJobsMap = new Map<string, PAITrialJobDetail>();
         this.jobQueue = [];
-        // Root dir on HDFS
         this.expRootDir = path.join('/nni', 'experiments', getExperimentId());
         this.experimentId = getExperimentId();
         this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
         this.paiTokenUpdateInterval = 7200000; //2hours
         this.logCollection = 'none';
-        this.log.info('Construct OpenPAI training service.');
+        this.log.info('Construct paiBase training service.');
     }
     public async run(): Promise<void> {
         this.log.info('Run PAI training service.');
-        const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
-        await restServer.start();
-        restServer.setEnableVersionCheck = this.versionCheck;
-        this.log.info(`PAI Training service rest server listening on: ${restServer.endPoint}`);
+        if (this.paiJobRestServer === undefined) {
+            throw new Error('paiJobRestServer not initialized!');
+        }
+        await this.paiJobRestServer.start();
+        this.paiJobRestServer.setEnableVersionCheck = this.versionCheck;
+        this.log.info(`PAI Training service rest server listening on: ${this.paiJobRestServer.endPoint}`);
         await Promise.all([
             this.statusCheckingLoop(),
             this.submitJobLoop()]);
         this.log.info('PAI training service exit.');
     }
+    public async submitTrialJob(form: TrialJobApplicationForm): Promise<any> {
+        throw new Error('Not implemented!');
+    }
+    public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
+        throw new Error('Not implemented!');
+    }
+    protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
+        throw new Error('Not implemented!');
+    }
+    protected async submitJobLoop(): Promise<void> {
+        while (!this.stopping) {
+            while (!this.stopping && this.jobQueue.length > 0) {
+                const trialJobId: string = this.jobQueue[0];
+                if (await this.submitTrialJobToPAI(trialJobId)) {
+                    // Remove trial job with trialJobId from job queue
+                    this.jobQueue.shift();
+                } else {
+                    // Break the while loop since failed to submitJob
+                    break;
+                }
+            }
+            await delay(3000);
+        }
+    }
+    public async setClusterMetadata(key: string, value: string): Promise<void> {
+        throw new Error('Not implemented!');
+    }
     public async listTrialJobs(): Promise<TrialJobDetail[]> {
         const jobs: TrialJobDetail[] = [];
@@ -93,7 +119,7 @@ class PAITrainingService implements TrainingService {
             jobs.push(await this.getTrialJob(key));
         }
-        return Promise.resolve(jobs);
+        return jobs;
     }
     public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
@@ -104,10 +130,10 @@ class PAITrainingService implements TrainingService {
         const paiTrialJob: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
         if (paiTrialJob === undefined) {
-            return Promise.reject(`trial job ${trialJobId} not found`);
+            throw new Error(`trial job ${trialJobId} not found`);
         }
-        return Promise.resolve(paiTrialJob);
+        return paiTrialJob;
     }
     public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
@@ -118,53 +144,6 @@ class PAITrainingService implements TrainingService {
         this.metricsEmitter.off('metric', listener);
     }
-    public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
-        if (this.paiClusterConfig === undefined) {
-            throw new Error(`paiClusterConfig not initialized!`);
-        }
-        const deferred: Deferred<PAITrialJobDetail> = new Deferred<PAITrialJobDetail>();
-        this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);
-        const trialJobId: string = uniqueString(5);
-        //TODO: use HDFS working folder instead
-        const trialWorkingFolder: string = path.join(this.expRootDir, 'trials', trialJobId);
-        const paiJobName: string = `nni_exp_${this.experimentId}_trial_${trialJobId}`;
-        const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
-        const hdfsOutputDir: string = unixPathJoin(hdfsCodeDir, 'nnioutput');
-        const hdfsLogPath: string = String.Format(
-            PAI_LOG_PATH_FORMAT,
-            this.paiClusterConfig.host,
-            hdfsOutputDir
-        );
-        const trialJobDetail: PAITrialJobDetail = new PAITrialJobDetail(
-            trialJobId,
-            'WAITING',
-            paiJobName,
-            Date.now(),
-            trialWorkingFolder,
-            form,
-            hdfsLogPath);
-        this.trialJobsMap.set(trialJobId, trialJobDetail);
-        this.jobQueue.push(trialJobId);
-        deferred.resolve(trialJobDetail);
-        return deferred.promise;
-    }
-    public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
-        const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
-        if (trialJobDetail === undefined) {
-            throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
-        }
-        await this.writeParameterFile(trialJobId, form.hyperParameters);
-        return trialJobDetail;
-    }
     public get isMultiPhaseJobSupported(): boolean {
         return true;
     }
@@ -213,265 +192,31 @@ class PAITrainingService implements TrainingService {
         return deferred.promise;
     }
-    public async setClusterMetadata(key: string, value: string): Promise<void> {
-        const deferred: Deferred<void> = new Deferred<void>();
-        switch (key) {
-            case TrialConfigMetadataKey.NNI_MANAGER_IP:
-                this.nniManagerIpConfig = <NNIManagerIpConfig>JSON.parse(value);
-                deferred.resolve();
-                break;
-            case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG:
-                this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);
-                this.hdfsClient = WebHDFS.createClient({
-                    user: this.paiClusterConfig.userName,
-                    // Refer PAI document for Pylon mapping https://github.com/Microsoft/pai/tree/master/docs/pylon
-                    port: 80,
-                    path: '/webhdfs/api/v1',
-                    host: this.paiClusterConfig.host
-                });
-                if(this.paiClusterConfig.passWord) {
-                    // Get PAI authentication token
-                    await this.updatePaiToken();
-                } else if(this.paiClusterConfig.token) {
-                    this.paiToken = this.paiClusterConfig.token;
-                } else {
-                    deferred.reject(new Error('pai cluster config format error, please set password or token!'));
-                }
-                deferred.resolve();
-                break;
-            case TrialConfigMetadataKey.TRIAL_CONFIG:
-                if (this.paiClusterConfig === undefined) {
-                    this.log.error('pai cluster config is not initialized');
-                    deferred.reject(new Error('pai cluster config is not initialized'));
-                    break;
-                }
-                this.paiTrialConfig = <NNIPAITrialConfig>JSON.parse(value);
-                // Validate to make sure codeDir doesn't have too many files
-                try {
-                    await validateCodeDir(this.paiTrialConfig.codeDir);
-                } catch (error) {
-                    this.log.error(error);
-                    deferred.reject(new Error(error));
-                    break;
-                }
-                // Copy experiment files from local folder to HDFS
-                this.copyExpCodeDirPromise = HDFSClientUtility.copyDirectoryToHdfs(
-                    this.paiTrialConfig.codeDir,
-                    HDFSClientUtility.getHdfsExpCodeDir(this.paiClusterConfig.userName),
-                    this.hdfsClient
-                );
-                // Upload authFile to hdfs
-                if (this.paiTrialConfig.authFile) {
-                    this.authFileHdfsPath = unixPathJoin(HDFSClientUtility.hdfsExpRootDir(this.paiClusterConfig.userName), 'authFile');
-                    this.copyAuthFilePromise = HDFSClientUtility.copyFileToHdfs(this.paiTrialConfig.authFile, this.authFileHdfsPath, this.hdfsClient);
-                }
-                deferred.resolve();
-                break;
-            case TrialConfigMetadataKey.VERSION_CHECK:
-                this.versionCheck = (value === 'true' || value === 'True');
-                break;
-            case TrialConfigMetadataKey.LOG_COLLECTION:
-                this.logCollection = value;
-                break;
-            case TrialConfigMetadataKey.MULTI_PHASE:
-                this.isMultiPhase = (value === 'true' || value === 'True');
-                break;
-            default:
-                //Reject for unknown keys
-                throw new Error(`Uknown key: ${key}`);
-        }
-        return deferred.promise;
-    }
     public getClusterMetadata(key: string): Promise<string> {
-        const deferred: Deferred<string> = new Deferred<string>();
-        deferred.resolve();
-        return deferred.promise;
+        throw new Error('Not implemented!');
     }
     public async cleanUp(): Promise<void> {
         this.log.info('Stopping PAI training service...');
         this.stopping = true;
-        const deferred: Deferred<void> = new Deferred<void>();
-        const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
+        if (this.paiJobRestServer === undefined) {
+            throw new Error('paiJobRestServer not initialized!');
+        }
         try {
-            await restServer.stop();
-            deferred.resolve();
+            await this.paiJobRestServer.stop();
             this.log.info('PAI Training service rest server stopped successfully.');
         } catch (error) {
             this.log.error(`PAI Training service rest server stopped failed, error: ${error.message}`);
-            deferred.reject(error);
         }
-        return deferred.promise;
     }
     public get MetricsEmitter(): EventEmitter {
         return this.metricsEmitter;
     }
-    private async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
-        const deferred: Deferred<boolean> = new Deferred<boolean>();
-        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
-        if (trialJobDetail === undefined) {
-            throw new Error(`Failed to find PAITrialJobDetail for job ${trialJobId}`);
-        }
-        if (this.paiClusterConfig === undefined) {
-            throw new Error('PAI Cluster config is not initialized');
-        }
-        if (this.paiTrialConfig === undefined) {
-            throw new Error('trial config is not initialized');
-        }
-        if (this.paiToken === undefined) {
-            throw new Error('PAI token is not initialized');
-        }
-        if (this.paiRestServerPort === undefined) {
-            const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
-            this.paiRestServerPort = restServer.clusterRestServerPort;
-        }
-        // Make sure experiment code files is copied from local to HDFS
-        if (this.copyExpCodeDirPromise !== undefined) {
-            await this.copyExpCodeDirPromise;
-        }
-        //Make sure authFile is copied from local to HDFS
-        if (this.paiTrialConfig.authFile) {
-            await this.copyAuthFilePromise;
-        }
-        // Step 1. Prepare PAI job configuration
-        const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
-        //create tmp trial working folder locally.
-        await execMkdir(trialLocalTempFolder);
-        const runScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
-        // Write NNI installation file to local tmp files
-        await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
-        // Write file content ( parameter.cfg ) to local tmp folders
-        if (trialJobDetail.form !== undefined) {
-            await fs.promises.writeFile(
-                path.join(trialLocalTempFolder, generateParamFileName(trialJobDetail.form.hyperParameters)),
-                trialJobDetail.form.hyperParameters.value, { encoding: 'utf8' }
-            );
-        }
-        const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
-        const hdfsOutputDir: string = unixPathJoin(hdfsCodeDir, 'nnioutput');
-        const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
-        const version: string = this.versionCheck ? await getVersion() : '';
-        const nniPaiTrialCommand: string = String.Format(
-            PAI_TRIAL_COMMAND_FORMAT,
-            // PAI will copy job's codeDir into /root directory
-            `$PWD/${trialJobId}`,
-            `$PWD/${trialJobId}/nnioutput`,
-            trialJobId,
-            this.experimentId,
-            trialJobDetail.form.sequenceId,
-            this.isMultiPhase,
-            this.paiTrialConfig.command,
-            nniManagerIp,
-            this.paiRestServerPort,
-            hdfsOutputDir,
-            this.paiClusterConfig.host,
-            this.paiClusterConfig.userName,
-            HDFSClientUtility.getHdfsExpCodeDir(this.paiClusterConfig.userName),
-            version,
-            this.logCollection
-        )
-            .replace(/\r\n|\n|\r/gm, '');
-        this.log.info(`nniPAItrial command is ${nniPaiTrialCommand.trim()}`);
-        const paiTaskRoles: PAITaskRole[] = [
-            new PAITaskRole(
-                `nni_trail_${trialJobId}`,
-                // Task role number
-                1,
-                // Task CPU number
-                this.paiTrialConfig.cpuNum,
-                // Task memory
-                this.paiTrialConfig.memoryMB,
-                // Task GPU number
-                this.paiTrialConfig.gpuNum,
-                // Task command
-                nniPaiTrialCommand,
-                // Task shared memory
-                this.paiTrialConfig.shmMB,
-                // Task portList
-                this.paiTrialConfig.portList
-            )
-        ];
-        const paiJobConfig: PAIJobConfig = new PAIJobConfig(
-            // Job name
-            trialJobDetail.paiJobName,
-            // Docker image
-            this.paiTrialConfig.image,
-            // codeDir
-            `$PAI_DEFAULT_FS_URI${hdfsCodeDir}`,
-            // PAI Task roles
-            paiTaskRoles,
-            // Add Virutal Cluster
-            this.paiTrialConfig.virtualCluster === undefined ? 'default' : this.paiTrialConfig.virtualCluster.toString(),
-            //Task auth File
-            this.authFileHdfsPath
-        );
-        // Step 2. Upload code files in codeDir onto HDFS
-        try {
-            await HDFSClientUtility.copyDirectoryToHdfs(trialLocalTempFolder, hdfsCodeDir, this.hdfsClient);
-        } catch (error) {
-            this.log.error(`PAI Training service: copy ${this.paiTrialConfig.codeDir} to HDFS ${hdfsCodeDir} failed, error is ${error}`);
-            trialJobDetail.status = 'FAILED'; // eslint-disable-line require-atomic-updates
-            deferred.resolve(true);
-            return deferred.promise;
-        }
-        // Step 3. Submit PAI job via Rest call
-        // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
-        const submitJobRequest: request.Options = {
-            uri: `http://${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}/jobs`,
-            method: 'POST',
-            json: true,
-            body: paiJobConfig,
-            headers: {
-                'Content-Type': 'application/json',
-                Authorization: `Bearer ${this.paiToken}`
-            }
-        };
-        request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
-            if ((error !== undefined && error !== null) || response.statusCode >= 400) {
-                const errorMessage: string = (error !== undefined && error !== null) ? error.message :
-                    `Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${response.body.message}`;
-                trialJobDetail.status = 'FAILED';
-                deferred.resolve(true);
-            } else {
-                trialJobDetail.submitTime = Date.now();
-                deferred.resolve(true);
-            }
-        });
-        return deferred.promise;
-    }
-    private async statusCheckingLoop(): Promise<void> {
+    protected async statusCheckingLoop(): Promise<void> {
         while (!this.stopping) {
             if(this.paiClusterConfig && this.paiClusterConfig.passWord) {
                 try {
@@ -485,25 +230,11 @@ class PAITrainingService implements TrainingService {
                 }
             }
             await this.paiJobCollector.retrieveTrialStatus(this.paiToken, this.paiClusterConfig);
-            const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
-            if (restServer.getErrorMessage !== undefined) {
-                throw new Error(restServer.getErrorMessage);
+            if (this.paiJobRestServer === undefined) {
+                throw new Error('paiBaseJobRestServer not implemented!');
             }
-            await delay(3000);
-        }
-    }
-    private async submitJobLoop(): Promise<void> {
-        while (!this.stopping) {
-            while (!this.stopping && this.jobQueue.length > 0) {
-                const trialJobId: string = this.jobQueue[0];
-                if (await this.submitTrialJobToPAI(trialJobId)) {
-                    // Remove trial job with trialJobId from job queue
-                    this.jobQueue.shift();
-                } else {
-                    // Break the while loop since failed to submitJob
-                    break;
-                }
+            if (this.paiJobRestServer.getErrorMessage !== undefined) {
+                throw new Error(this.paiJobRestServer.getErrorMessage);
             }
             await delay(3000);
         }
@@ -512,7 +243,7 @@ class PAITrainingService implements TrainingService {
     /**
      * Update pai token by the interval time or initialize the pai token
      */
-    private async updatePaiToken(): Promise<void> {
+    protected async updatePaiToken(): Promise<void> {
         const deferred: Deferred<void> = new Deferred<void>();
         const currentTime: number = new Date().getTime();
@@ -563,50 +294,6 @@ class PAITrainingService implements TrainingService {
         return Promise.race([timeoutDelay, deferred.promise])
             .finally(() => { clearTimeout(timeoutId); });
     }
-    private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> {
-        if (this.paiClusterConfig === undefined) {
-            throw new Error('PAI Cluster config is not initialized');
-        }
-        if (this.paiTrialConfig === undefined) {
-            throw new Error('PAI trial config is not initialized');
-        }
-        const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
-        const hpFileName: string = generateParamFileName(hyperParameters);
-        const localFilepath: string = path.join(trialLocalTempFolder, hpFileName);
-        await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' });
-        const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
-        const hdfsHpFilePath: string = path.join(hdfsCodeDir, hpFileName);
-        await HDFSClientUtility.copyFileToHdfs(localFilepath, hdfsHpFilePath, this.hdfsClient);
-        await this.postParameterFileMeta({
-            experimentId: this.experimentId,
-            trialId: trialJobId,
-            filePath: hdfsHpFilePath
-        });
-    }
-    private postParameterFileMeta(parameterFileMeta: ParameterFileMeta): Promise<void> {
-        const deferred: Deferred<void> = new Deferred<void>();
-        const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
-        const req: request.Options = {
-            uri: `${restServer.endPoint}${restServer.apiRootUrl}/parameter-file-meta`,
-            method: 'POST',
-            json: true,
-            body: parameterFileMeta
-        };
-        request(req, (err: Error, res: request.Response) => {
-            if (err) {
-                deferred.reject(err);
-            } else {
-                deferred.resolve();
-            }
-        });
-        return deferred.promise;
-    }
 }
 export { PAITrainingService };
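The net effect of this file's diff is an abstract base: shared state, the rest server lifecycle, the submit/status loops, and token refresh stay here, while job submission is deferred to subclasses. A sketch of the resulting extension pattern (class name hypothetical):

class SomePAIService extends PAITrainingService {
    public async setClusterMetadata(key: string, value: string): Promise<void> {
        // parse configs, then create this.paiJobRestServer = new PAIJobRestServer(this), ...
    }
    protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
        // build a platform-specific job config and POST it to the PAI rest server
        return true;
    }
}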
@@ -4,9 +4,9 @@
 import * as fs from 'fs';
 import * as path from 'path';
 import { Deferred } from 'ts-deferred';
-import { getExperimentId } from '../../common/experimentStartupInfo';
-import { getLogger } from '../../common/log';
-import { unixPathJoin } from '../../common/utils';
+import { getExperimentId } from '../../../common/experimentStartupInfo';
+import { getLogger } from '../../../common/log';
+import { unixPathJoin } from '../../../common/utils';
 /**
  * HDFS client utility, including copy file/directory
...
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import {TrialConfig} from '../../common/trialConfig';
/**
* Task role for PAI
*/
export class PAITaskRole {
// Name for the task role
public readonly name: string;
// Number of tasks for the task role, no less than 1
public readonly taskNumber: number;
// CPU number for one task in the task role, no less than 1
public readonly cpuNumber: number;
// Memory for one task in the task role, no less than 100
public readonly memoryMB: number;
// GPU number for one task in the task role, no less than 0
public readonly gpuNumber: number;
// Executable command for tasks in the task role, can not be empty
public readonly command: string;
//Shared memory for one task in the task role
public readonly shmMB?: number;
//portList to specify the port used in container
public portList?: PortListMetaData[];
/**
* Constructor
* @param name Name for the task role
* @param taskNumber Number of tasks for the task role, no less than 1
* @param cpuNumber CPU number for one task in the task role, no less than 1
* @param memoryMB Memory for one task in the task role, no less than 100
* @param gpuNumber GPU number for one task in the task role, no less than 0
* @param command Executable command for tasks in the task role, can not be empty
* @param shmMB Shared memory for one task in the task role
* @param portList List of ports used in the container
*/
constructor(name: string, taskNumber: number, cpuNumber: number, memoryMB: number, gpuNumber: number,
command: string, shmMB?: number, portList?: PortListMetaData[]) {
this.name = name;
this.taskNumber = taskNumber;
this.cpuNumber = cpuNumber;
this.memoryMB = memoryMB;
this.gpuNumber = gpuNumber;
this.command = command;
this.shmMB = shmMB;
this.portList = portList;
}
}
/**
* Trial job configuration submitted to PAI
*/
export class PAIJobConfig {
// Name for the job, need to be unique
public readonly jobName: string;
// URL pointing to the Docker image for all tasks in the job
public readonly image: string;
// Code directory on HDFS
public readonly codeDir: string;
//authentication file used for private Docker registry
public readonly authFile?: string;
// List of taskRole, one task role at least
public taskRoles: PAITaskRole[];
//The virtual cluster job runs on.
public readonly virtualCluster: string;
/**
* Constructor
* @param jobName Name for the job, need to be unique
* @param image URL pointing to the Docker image for all tasks in the job
* @param codeDir Code directory on HDFS
* @param taskRoles List of taskRole, one task role at least
* @param virtualCluster The virtual cluster the job runs on
* @param authFile Authentication file used for private Docker registry
*/
constructor(jobName: string, image: string, codeDir: string,
taskRoles: PAITaskRole[], virtualCluster: string, authFile?: string) {
this.jobName = jobName;
this.image = image;
this.codeDir = codeDir;
this.taskRoles = taskRoles;
this.virtualCluster = virtualCluster;
this.authFile = authFile;
}
}
/**
* portList data structure used in PAI taskRole
*/
export class PortListMetaData {
public readonly label: string = '';
public readonly beginAt: number = 0;
public readonly portNumber: number = 0;
}
/**
* PAI trial configuration
*/
export class NNIPAITrialConfig extends TrialConfig {
public readonly cpuNum: number;
public readonly memoryMB: number;
public readonly image: string;
//The virtual cluster job runs on. If omitted, the job will run on default virtual cluster
public virtualCluster?: string;
//Shared memory for one task in the task role
public shmMB?: number;
//authentication file used for private Docker registry
public authFile?: string;
//portList to specify the port used in container
public portList?: PortListMetaData[];
constructor(command: string, codeDir: string, gpuNum: number, cpuNum: number, memoryMB: number,
image: string, virtualCluster?: string, shmMB?: number, authFile?: string, portList?: PortListMetaData[]) {
super(command, codeDir, gpuNum);
this.cpuNum = cpuNum;
this.memoryMB = memoryMB;
this.image = image;
this.virtualCluster = virtualCluster;
this.shmMB = shmMB;
this.authFile = authFile;
this.portList = portList;
}
}
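A hedged sketch of assembling a v1 (Yarn) job config from these classes, mirroring what paiYarnTrainingService.ts does below; all literal values are illustrative:

const taskRole: PAITaskRole = new PAITaskRole(
    'nni_trail_abc12',   // name (the nni_trail_ spelling follows the source)
    1,                   // taskNumber
    4,                   // cpuNumber
    8192,                // memoryMB
    1,                   // gpuNumber
    'python3 -m nni_trial_tool.trial_keeper ...',  // command (abridged)
    1024                 // shmMB, optional
);
const jobConfig: PAIJobConfig = new PAIJobConfig(
    'nni_exp_expid_trial_abc12',            // jobName
    'msranni/nni:latest',                   // image
    '$PAI_DEFAULT_FS_URI/users/me/abc12',   // codeDir on HDFS
    [taskRole],                             // taskRoles
    'default'                               // virtualCluster
);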
@@ -3,37 +3,7 @@
 'use strict';
-import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
+import { TrialJobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../../common/trainingService';
-/**
- * PAI trial job detail
- */
-export class PAITrialJobDetail implements TrialJobDetail {
-    public id: string;
-    public status: TrialJobStatus;
-    public paiJobName: string;
-    public submitTime: number;
-    public startTime?: number;
-    public endTime?: number;
-    public tags?: string[];
-    public url?: string;
-    public workingDirectory: string;
-    public form: TrialJobApplicationForm;
-    public hdfsLogPath: string;
-    public isEarlyStopped?: boolean;
-    constructor(id: string, status: TrialJobStatus, paiJobName: string,
-                submitTime: number, workingDirectory: string, form: TrialJobApplicationForm, hdfsLogPath: string) {
-        this.id = id;
-        this.status = status;
-        this.paiJobName = paiJobName;
-        this.submitTime = submitTime;
-        this.workingDirectory = workingDirectory;
-        this.form = form;
-        this.tags = [];
-        this.hdfsLogPath = hdfsLogPath;
-    }
-}
 export const PAI_INSTALL_NNI_SHELL_FORMAT: string =
 `#!/bin/bash
@@ -46,7 +16,7 @@ else
 fi`;
 export const PAI_TRIAL_COMMAND_FORMAT: string =
-`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} \
+`export NNI_PLATFORM=paiYarn NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} \
 && cd $NNI_SYS_DIR && sh install_nni.sh \
 && python3 -m nni_trial_tool.trial_keeper --trial_command '{6}' --nnimanager_ip '{7}' --nnimanager_port '{8}' \
 --pai_hdfs_output_dir '{9}' --pai_hdfs_host '{10}' --pai_user_name {11} --nni_hdfs_exp_dir '{12}' --webhdfs_path '/webhdfs/api/v1' \
...
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as fs from 'fs';
import * as path from 'path';
import * as request from 'request';
import * as component from '../../../common/component';
import { EventEmitter } from 'events';
import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log';
import {
HyperParameters, NNIManagerIpConfig, TrainingService,
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../../common/trainingService';
import { delay, generateParamFileName,
getExperimentRootDir, getIPV4Address, getVersion, uniqueString, unixPathJoin } from '../../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../../common/containerJobData';
import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
import { execMkdir, validateCodeDir } from '../../common/util';
import { HDFSClientUtility } from './hdfsClientUtility';
import { NNIPAITrialConfig, PAIJobConfig, PAITaskRole } from './paiYarnConfig';
import { PAI_LOG_PATH_FORMAT, PAI_TRIAL_COMMAND_FORMAT } from './paiYarnData';
import { PAIJobInfoCollector } from '../paiJobInfoCollector';
import { PAITrainingService } from '../paiTrainingService';
import { PAIClusterConfig, PAITrialJobDetail } from '../paiConfig';
import * as WebHDFS from 'webhdfs';
import { PAIJobRestServer, ParameterFileMeta } from '../paiJobRestServer';
/**
* Training Service implementation for OpenPAI (Open Platform for AI)
* Refer https://github.com/Microsoft/pai for more info about OpenPAI
*/
@component.Singleton
class PAIYarnTrainingService extends PAITrainingService {
private hdfsClient: any;
private copyExpCodeDirPromise?: Promise<void>;
private copyAuthFilePromise?: Promise<void>;
private paiTrialConfig?: NNIPAITrialConfig;
constructor() {
super();
}
public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.paiClusterConfig === undefined) {
throw new Error(`paiBaseClusterConfig not initialized!`);
}
this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);
const trialJobId: string = uniqueString(5);
//TODO: use HDFS working folder instead
const trialWorkingFolder: string = path.join(this.expRootDir, 'trials', trialJobId);
const paiJobName: string = `nni_exp_${this.experimentId}_trial_${trialJobId}`;
const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
const hdfsOutputDir: string = unixPathJoin(hdfsCodeDir, 'nnioutput');
const hdfsLogPath: string = String.Format(
PAI_LOG_PATH_FORMAT,
this.paiClusterConfig.host,
hdfsOutputDir
);
const trialJobDetail: PAITrialJobDetail = new PAITrialJobDetail(
trialJobId,
'WAITING',
paiJobName,
Date.now(),
trialWorkingFolder,
form,
hdfsLogPath);
this.trialJobsMap.set(trialJobId, trialJobDetail);
this.jobQueue.push(trialJobId);
return trialJobDetail;
}
public async setClusterMetadata(key: string, value: string): Promise<void> {
switch (key) {
case TrialConfigMetadataKey.NNI_MANAGER_IP:
this.nniManagerIpConfig = <NNIManagerIpConfig>JSON.parse(value);
break;
case TrialConfigMetadataKey.PAI_YARN_CLUSTER_CONFIG:
this.paiJobRestServer = new PAIJobRestServer(component.get(PAIYarnTrainingService));
this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);
this.hdfsClient = WebHDFS.createClient({
user: this.paiClusterConfig.userName,
// Refer PAI document for Pylon mapping https://github.com/Microsoft/pai/tree/master/docs/pylon
port: 80,
path: '/webhdfs/api/v1',
host: this.paiClusterConfig.host
});
if(this.paiClusterConfig.passWord) {
// Get PAI authentication token
await this.updatePaiToken();
} else if(this.paiClusterConfig.token) {
this.paiToken = this.paiClusterConfig.token;
} else {
throw new Error('pai cluster config format error, please set password or token!');
}
break;
case TrialConfigMetadataKey.TRIAL_CONFIG:
if (this.paiClusterConfig === undefined) {
this.log.error('pai cluster config is not initialized');
break;
}
this.paiTrialConfig = <NNIPAITrialConfig>JSON.parse(value);
// Validate to make sure codeDir doesn't have too many files
await validateCodeDir(this.paiTrialConfig.codeDir);
// Copy experiment files from local folder to HDFS
this.copyExpCodeDirPromise = HDFSClientUtility.copyDirectoryToHdfs(
this.paiTrialConfig.codeDir,
HDFSClientUtility.getHdfsExpCodeDir(this.paiClusterConfig.userName),
this.hdfsClient
);
// Upload authFile to hdfs
if (this.paiTrialConfig.authFile) {
this.authFileHdfsPath = unixPathJoin(HDFSClientUtility.hdfsExpRootDir(this.paiClusterConfig.userName), 'authFile');
this.copyAuthFilePromise = HDFSClientUtility.copyFileToHdfs(this.paiTrialConfig.authFile, this.authFileHdfsPath, this.hdfsClient);
}
break;
case TrialConfigMetadataKey.VERSION_CHECK:
this.versionCheck = (value === 'true' || value === 'True');
break;
case TrialConfigMetadataKey.LOG_COLLECTION:
this.logCollection = value;
break;
case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True');
break;
default:
//Reject for unknown keys
throw new Error(`Unknown key: ${key}`);
}
}
protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
const deferred: Deferred<boolean> = new Deferred<boolean>();
const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`Failed to find PAITrialJobDetail for job ${trialJobId}`);
}
if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized');
}
if (this.paiTrialConfig === undefined) {
throw new Error('trial config is not initialized');
}
if (this.paiToken === undefined) {
throw new Error('PAI token is not initialized');
}
if (this.paiJobRestServer === undefined) {
throw new Error('paiJobRestServer is not initialized');
}
this.paiRestServerPort = this.paiJobRestServer.clusterRestServerPort;
// Make sure experiment code files is copied from local to HDFS
if (this.copyExpCodeDirPromise !== undefined) {
await this.copyExpCodeDirPromise;
}
//Make sure authFile is copied from local to HDFS
if (this.paiTrialConfig.authFile) {
await this.copyAuthFilePromise;
}
// Step 1. Prepare PAI job configuration
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
//create tmp trial working folder locally.
await execMkdir(trialLocalTempFolder);
const runScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
// Write NNI installation file to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
// Write file content ( parameter.cfg ) to local tmp folders
if (trialJobDetail.form !== undefined) {
await fs.promises.writeFile(
path.join(trialLocalTempFolder, generateParamFileName(trialJobDetail.form.hyperParameters)),
trialJobDetail.form.hyperParameters.value, { encoding: 'utf8' }
);
}
const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
const hdfsOutputDir: string = unixPathJoin(hdfsCodeDir, 'nnioutput');
const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
const version: string = this.versionCheck ? await getVersion() : '';
const nniPaiTrialCommand: string = String.Format(
PAI_TRIAL_COMMAND_FORMAT,
// PAI will copy job's codeDir into /root directory
`$PWD/${trialJobId}`,
`$PWD/${trialJobId}/nnioutput`,
trialJobId,
this.experimentId,
trialJobDetail.form.sequenceId,
this.isMultiPhase,
this.paiTrialConfig.command,
nniManagerIp,
this.paiRestServerPort,
hdfsOutputDir,
this.paiClusterConfig.host,
this.paiClusterConfig.userName,
HDFSClientUtility.getHdfsExpCodeDir(this.paiClusterConfig.userName),
version,
this.logCollection
)
.replace(/\r\n|\n|\r/gm, '');
this.log.info(`nniPAItrial command is ${nniPaiTrialCommand.trim()}`);
const paiTaskRoles: PAITaskRole[] = [
new PAITaskRole(
`nni_trail_${trialJobId}`,
// Task role number
1,
// Task CPU number
this.paiTrialConfig.cpuNum,
// Task memory
this.paiTrialConfig.memoryMB,
// Task GPU number
this.paiTrialConfig.gpuNum,
// Task command
nniPaiTrialCommand,
// Task shared memory
this.paiTrialConfig.shmMB,
// Task portList
this.paiTrialConfig.portList
)
];
const paiJobConfig: PAIJobConfig = new PAIJobConfig(
// Job name
trialJobDetail.paiJobName,
// Docker image
this.paiTrialConfig.image,
// codeDir
`$PAI_DEFAULT_FS_URI${hdfsCodeDir}`,
// PAI Task roles
paiTaskRoles,
// Add Virutal Cluster
this.paiTrialConfig.virtualCluster === undefined ? 'default' : this.paiTrialConfig.virtualCluster.toString(),
//Task auth File
this.authFileHdfsPath
);
// Step 2. Upload code files in codeDir onto HDFS
try {
await HDFSClientUtility.copyDirectoryToHdfs(trialLocalTempFolder, hdfsCodeDir, this.hdfsClient);
} catch (error) {
this.log.error(`PAI Training service: copy ${this.paiTrialConfig.codeDir} to HDFS ${hdfsCodeDir} failed, error is ${error}`);
trialJobDetail.status = 'FAILED'; // eslint-disable-line require-atomic-updates
return true;
}
// Step 3. Submit PAI job via Rest call
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const submitJobRequest: request.Options = {
uri: `http://${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}/jobs`,
method: 'POST',
json: true,
body: paiJobConfig,
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${this.paiToken}`
}
};
request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
if ((error !== undefined && error !== null) || response.statusCode >= 400) {
const errorMessage: string = (error !== undefined && error !== null) ? error.message :
`Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${response.body.message}`;
trialJobDetail.status = 'FAILED';
deferred.resolve(true);
} else {
trialJobDetail.submitTime = Date.now();
deferred.resolve(true);
}
});
return deferred.promise;
}
public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
await this.writeParameterFile(trialJobId, form.hyperParameters);
return trialJobDetail;
}
protected async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> {
if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized');
}
if (this.paiTrialConfig === undefined) {
throw new Error('PAI trial config is not initialized');
}
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
const hpFileName: string = generateParamFileName(hyperParameters);
const localFilepath: string = path.join(trialLocalTempFolder, hpFileName);
await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' });
const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
const hdfsHpFilePath: string = path.join(hdfsCodeDir, hpFileName);
await HDFSClientUtility.copyFileToHdfs(localFilepath, hdfsHpFilePath, this.hdfsClient);
await this.postParameterFileMeta({
experimentId: this.experimentId,
trialId: trialJobId,
filePath: hdfsHpFilePath
});
}
protected postParameterFileMeta(parameterFileMeta: ParameterFileMeta): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
if (this.paiJobRestServer === undefined) {
throw new Error('paiJobRestServer not implemented!');
}
const req: request.Options = {
uri: `${this.paiJobRestServer.endPoint}${this.paiJobRestServer.apiRootUrl}/parameter-file-meta`,
method: 'POST',
json: true,
body: parameterFileMeta
};
request(req, (err: Error, res: request.Response) => {
if (err) {
deferred.reject(err);
} else {
deferred.resolve();
}
});
return deferred.promise;
}
}
export { PAIYarnTrainingService };
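For reference, a sketch of the metadata that postParameterFileMeta sends to the rest server's parameter-file-meta route; the HDFS path is illustrative:

const meta: ParameterFileMeta = {
    experimentId: 'expid',
    trialId: 'abc12',
    filePath: '/users/me/nni/experiments/expid/trials/abc12/parameter.cfg'
};
// POSTed as JSON to `${paiJobRestServer.endPoint}${paiJobRestServer.apiRootUrl}/parameter-file-meta`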
@@ -3,12 +3,12 @@
 'use strict';
-import {TrialConfig} from '../common/trialConfig';
+import {TrialConfig} from '../../common/trialConfig';
 /**
  * PAI configuration to run trials
  */
-export class PAITrialConfig extends TrialConfig {
+export class PAIYarnTrialConfig extends TrialConfig {
     public readonly cpuNum: number;
     public readonly memoryMB: number;
     public readonly image: string;
...
...
@@ -9,7 +9,7 @@ import * as os from 'os';
 import * as path from 'path';
 import * as tmp from 'tmp';
 import { cleanupUnitTest, prepareUnitTest, uniqueString } from '../../common/utils';
-import { HDFSClientUtility } from '../pai/hdfsClientUtility';
+import { HDFSClientUtility } from '../pai/paiYarn/hdfsClientUtility';
 var WebHDFS = require('webhdfs');
 var rmdir = require('rmdir');
...
...
@@ -11,14 +11,14 @@ import * as component from '../../common/component';
 import { TrialJobApplicationForm } from '../../common/trainingService';
 import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
 import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
-import { PAITrainingService } from '../pai/paiTrainingService';
+import { PAIYarnTrainingService } from '../pai/paiYarn/paiYarnTrainingService';
 // TODO: copy mockedTrail.py to local folder
 const localCodeDir: string = tmp.dirSync().name
 const mockedTrialPath: string = './training_service/test/mockedTrial.py'
 fs.copyFileSync(mockedTrialPath, localCodeDir + '/mockedTrial.py')
-describe('Unit Test for PAITrainingService', () => {
+describe('Unit Test for PAIYarnTrainingService', () => {
     let skip: boolean = false;
     let testPaiClusterInfo: any;
     let paiCluster: any;
@@ -33,7 +33,7 @@ describe('Unit Test for PAITrainingService', () => {
         skip = true;
     }
-    let paiTrainingService: PAITrainingService;
+    let paiYarnTrainingService: PAIYarnTrainingService;
     console.log(tmp.dirSync().name);
@@ -51,15 +51,15 @@ describe('Unit Test for PAITrainingService', () => {
         if (skip) {
             return;
         }
-        paiTrainingService = component.get(PAITrainingService);
-        paiTrainingService.run();
+        paiYarnTrainingService = component.get(PAIYarnTrainingService);
+        paiYarnTrainingService.run();
     });
     afterEach(() => {
         if (skip) {
             return;
         }
-        paiTrainingService.cleanUp();
+        paiYarnTrainingService.cleanUp();
     });
     it('Get PAI token', async () => {
@@ -67,14 +67,14 @@ describe('Unit Test for PAITrainingService', () => {
             return;
         }
         console.log(`paiCluster is ${paiCluster}`)
-        await paiTrainingService.setClusterMetadata(TrialConfigMetadataKey.PAI_CLUSTER_CONFIG, paiCluster);
-        await paiTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, paiTrialConfig);
+        await paiYarnTrainingService.setClusterMetadata(TrialConfigMetadataKey.PAI_YARN_CLUSTER_CONFIG, paiCluster);
+        await paiYarnTrainingService.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, paiTrialConfig);
         const form: TrialJobApplicationForm = {
             sequenceId: 0,
             hyperParameters: { value: '', index: 0 }
         };
         try {
-            const trialDetail = await paiTrainingService.submitTrialJob(form);
+            const trialDetail = await paiYarnTrainingService.submitTrialJob(form);
             chai.expect(trialDetail.status).to.be.equals('WAITING');
         } catch(error) {
             console.log('Submit job failed:' + error);
...
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
-from .trial import *
-from .smartparam import *
-from .nas_utils import training_update
+from .env_vars import dispatcher_env_vars
+
+if dispatcher_env_vars.SDK_PROCESS != 'dispatcher':
+    from .trial import *
+    from .smartparam import *
+    from .nas_utils import training_update
 class NoMoreTrialError(Exception):
     def __init__(self, ErrorInfo):
...
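(The SDK's top level now imports trial helpers only in non-dispatcher processes, so the dispatcher does not pull in trial-side modules. The same guard pattern in isolation; the stand-in import is illustrative, and NNI reads the flag through dispatcher_env_vars rather than os.environ directly.)

import os

# Trial-only imports are skipped when this process is the dispatcher.
if os.environ.get('SDK_PROCESS') != 'dispatcher':
    import json as trial_only_helper  # stand-in for trial-side modules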
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
-'''
-__main__.py
-'''
 import os
 import sys
 import argparse
 import logging
 import json
 import importlib
+import base64
 from .common import enable_multi_thread, enable_multi_phase
 from .constants import ModuleName, ClassName, ClassArgs, AdvisorModuleName, AdvisorClassName
@@ -29,99 +27,67 @@ def augment_classargs(input_class_args, classname):
             input_class_args[key] = value
     return input_class_args
-def create_builtin_class_instance(classname, jsonstr_args, is_advisor=False):
-    if is_advisor:
-        if classname not in AdvisorModuleName or \
-            importlib.util.find_spec(AdvisorModuleName[classname]) is None:
-            raise RuntimeError('Advisor module is not found: {}'.format(classname))
-        class_module = importlib.import_module(AdvisorModuleName[classname])
-        class_constructor = getattr(class_module, AdvisorClassName[classname])
-    else:
-        if classname not in ModuleName or \
-            importlib.util.find_spec(ModuleName[classname]) is None:
-            raise RuntimeError('Tuner module is not found: {}'.format(classname))
-        class_module = importlib.import_module(ModuleName[classname])
-        class_constructor = getattr(class_module, ClassName[classname])
-    if jsonstr_args:
-        class_args = augment_classargs(json.loads(jsonstr_args), classname)
-    else:
-        class_args = augment_classargs({}, classname)
-    if class_args:
-        instance = class_constructor(**class_args)
-    else:
-        instance = class_constructor()
+def create_builtin_class_instance(class_name, class_args, builtin_module_dict, builtin_class_dict):
+    if class_name not in builtin_module_dict or \
+        importlib.util.find_spec(builtin_module_dict[class_name]) is None:
+        raise RuntimeError('Builtin module is not found: {}'.format(class_name))
+    class_module = importlib.import_module(builtin_module_dict[class_name])
+    class_constructor = getattr(class_module, builtin_class_dict[class_name])
+
+    if class_args is None:
+        class_args = {}
+    class_args = augment_classargs(class_args, class_name)
+    instance = class_constructor(**class_args)
     return instance
-def create_customized_class_instance(class_dir, class_filename, classname, jsonstr_args):
-    if not os.path.isfile(os.path.join(class_dir, class_filename)):
-        raise ValueError('Class file not found: {}'.format(
-            os.path.join(class_dir, class_filename)))
-    sys.path.append(class_dir)
-    module_name = os.path.splitext(class_filename)[0]
-    class_module = importlib.import_module(module_name)
-    class_constructor = getattr(class_module, classname)
-    if jsonstr_args:
-        class_args = json.loads(jsonstr_args)
-        instance = class_constructor(**class_args)
-    else:
-        instance = class_constructor()
+def create_customized_class_instance(class_params):
+    code_dir = class_params.get('codeDir')
+    class_filename = class_params.get('classFileName')
+    class_name = class_params.get('className')
+    class_args = class_params.get('classArgs')
+
+    if not os.path.isfile(os.path.join(code_dir, class_filename)):
+        raise ValueError('Class file not found: {}'.format(
+            os.path.join(code_dir, class_filename)))
+    sys.path.append(code_dir)
+    module_name = os.path.splitext(class_filename)[0]
+    class_module = importlib.import_module(module_name)
+    class_constructor = getattr(class_module, class_name)
+
+    if class_args is None:
+        class_args = {}
+    instance = class_constructor(**class_args)
     return instance
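(The customized-class path above boils down to: put the code directory on sys.path, import the module by file name, and instantiate the named class with keyword args. A minimal standalone sketch of that pattern; the directory, file, and class names in the usage comment are hypothetical.)

import importlib
import os
import sys

def load_class_instance(code_dir, class_filename, class_name, class_args=None):
    """Import `class_name` from `code_dir/class_filename` and instantiate it."""
    if not os.path.isfile(os.path.join(code_dir, class_filename)):
        raise ValueError('Class file not found: {}'.format(os.path.join(code_dir, class_filename)))
    sys.path.append(code_dir)
    module_name = os.path.splitext(class_filename)[0]
    module = importlib.import_module(module_name)
    constructor = getattr(module, class_name)
    return constructor(**(class_args or {}))

# Hypothetical usage: a MyTuner class defined in /tmp/my_code/my_tuner.py
# tuner = load_class_instance('/tmp/my_code', 'my_tuner.py', 'MyTuner', {'optimize_mode': 'maximize'})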
-def parse_args():
-    parser = argparse.ArgumentParser(description='parse command line parameters.')
-    parser.add_argument('--advisor_class_name', type=str, required=False,
-                        help='Advisor class name, the class must be a subclass of nni.MsgDispatcherBase')
-    parser.add_argument('--advisor_class_filename', type=str, required=False,
-                        help='Advisor class file path')
-    parser.add_argument('--advisor_args', type=str, required=False,
-                        help='Parameters pass to advisor __init__ constructor')
-    parser.add_argument('--advisor_directory', type=str, required=False,
-                        help='Advisor directory')
-    parser.add_argument('--tuner_class_name', type=str, required=False,
-                        help='Tuner class name, the class must be a subclass of nni.Tuner')
-    parser.add_argument('--tuner_class_filename', type=str, required=False,
-                        help='Tuner class file path')
-    parser.add_argument('--tuner_args', type=str, required=False,
-                        help='Parameters pass to tuner __init__ constructor')
-    parser.add_argument('--tuner_directory', type=str, required=False,
-                        help='Tuner directory')
-    parser.add_argument('--assessor_class_name', type=str, required=False,
-                        help='Assessor class name, the class must be a subclass of nni.Assessor')
-    parser.add_argument('--assessor_args', type=str, required=False,
-                        help='Parameters pass to assessor __init__ constructor')
-    parser.add_argument('--assessor_directory', type=str, required=False,
-                        help='Assessor directory')
-    parser.add_argument('--assessor_class_filename', type=str, required=False,
-                        help='Assessor class file path')
-    parser.add_argument('--multi_phase', action='store_true')
-    parser.add_argument('--multi_thread', action='store_true')
-    flags, _ = parser.parse_known_args()
-    return flags
 def main():
-    '''
-    main function.
-    '''
-    args = parse_args()
-    if args.multi_thread:
+    parser = argparse.ArgumentParser(description='Dispatcher command line parser')
+    parser.add_argument('--exp_params', type=str, required=True)
+    args, _ = parser.parse_known_args()
+
+    exp_params_decode = base64.b64decode(args.exp_params).decode('utf-8')
+    logger.debug('decoded exp_params: [%s]', exp_params_decode)
+    exp_params = json.loads(exp_params_decode)
+    logger.debug('exp_params json obj: [%s]', json.dumps(exp_params, indent=4))
+
+    if exp_params.get('multiThread'):
         enable_multi_thread()
-    if args.multi_phase:
+    if exp_params.get('multiPhase'):
         enable_multi_phase()
-    if args.advisor_class_name:
+    if exp_params.get('advisor') is not None:
         # advisor is enabled and starts to run
-        _run_advisor(args)
+        _run_advisor(exp_params)
     else:
         # tuner (and assessor) is enabled and starts to run
-        tuner = _create_tuner(args)
-        if args.assessor_class_name:
-            assessor = _create_assessor(args)
+        assert exp_params.get('tuner') is not None
+        tuner = _create_tuner(exp_params)
+        if exp_params.get('assessor') is not None:
+            assessor = _create_assessor(exp_params)
         else:
             assessor = None
         dispatcher = MsgDispatcher(tuner, assessor)
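(The dispatcher now receives its whole configuration as one base64-encoded JSON blob instead of a dozen flags. A sketch of the encoding side and the decode performed in main() above; the multiThread/tuner/builtinTunerName/classArgs keys follow the accessors in this diff, while the example values are made up.)

import base64
import json

exp_params = {
    'multiThread': False,
    'multiPhase': False,
    'tuner': {'builtinTunerName': 'TPE', 'classArgs': {'optimize_mode': 'maximize'}},
}
# Encode the way a caller would build the --exp_params argument ...
encoded = base64.b64encode(json.dumps(exp_params).encode('utf-8')).decode('utf-8')

# ... and decode exactly as main() does above.
decoded = json.loads(base64.b64decode(encoded).decode('utf-8'))
assert decoded == exp_params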
...
@@ -139,17 +105,14 @@ def main():
         raise
-def _run_advisor(args):
-    if args.advisor_class_name in AdvisorModuleName:
+def _run_advisor(exp_params):
+    if exp_params.get('advisor').get('builtinAdvisorName') in AdvisorModuleName:
         dispatcher = create_builtin_class_instance(
-            args.advisor_class_name,
-            args.advisor_args, True)
+            exp_params.get('advisor').get('builtinAdvisorName'),
+            exp_params.get('advisor').get('classArgs'),
+            AdvisorModuleName, AdvisorClassName)
     else:
-        dispatcher = create_customized_class_instance(
-            args.advisor_directory,
-            args.advisor_class_filename,
-            args.advisor_class_name,
-            args.advisor_args)
+        dispatcher = create_customized_class_instance(exp_params.get('advisor'))
     if dispatcher is None:
         raise AssertionError('Failed to create Advisor instance')
     try:
...
@@ -159,33 +122,27 @@ def _run_advisor(args):
         raise
-def _create_tuner(args):
-    if args.tuner_class_name in ModuleName:
+def _create_tuner(exp_params):
+    if exp_params.get('tuner').get('builtinTunerName') in ModuleName:
         tuner = create_builtin_class_instance(
-            args.tuner_class_name,
-            args.tuner_args)
+            exp_params.get('tuner').get('builtinTunerName'),
+            exp_params.get('tuner').get('classArgs'),
+            ModuleName, ClassName)
     else:
-        tuner = create_customized_class_instance(
-            args.tuner_directory,
-            args.tuner_class_filename,
-            args.tuner_class_name,
-            args.tuner_args)
+        tuner = create_customized_class_instance(exp_params.get('tuner'))
     if tuner is None:
         raise AssertionError('Failed to create Tuner instance')
     return tuner
-def _create_assessor(args):
-    if args.assessor_class_name in ModuleName:
+def _create_assessor(exp_params):
+    if exp_params.get('assessor').get('builtinAssessorName') in ModuleName:
         assessor = create_builtin_class_instance(
-            args.assessor_class_name,
-            args.assessor_args)
+            exp_params.get('assessor').get('builtinAssessorName'),
+            exp_params.get('assessor').get('classArgs'),
+            ModuleName, ClassName)
     else:
-        assessor = create_customized_class_instance(
-            args.assessor_directory,
-            args.assessor_class_filename,
-            args.assessor_class_name,
-            args.assessor_args)
+        assessor = create_customized_class_instance(exp_params.get('assessor'))
     if assessor is None:
         raise AssertionError('Failed to create Assessor instance')
     return assessor
...
...
@@ -100,7 +100,7 @@ def get_bits_length(config, quant_type):
 class QAT_Quantizer(Quantizer):
-    """Quantizer using the DoReFa scheme, as defined in:
+    """Quantizer defined in:
     Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
     http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
     """
...
@@ -227,20 +227,17 @@ class DoReFaQuantizer(Quantizer):
     (https://arxiv.org/abs/1606.06160)
     """
     def __init__(self, model, config_list):
-        """
-        config_list: supported keys:
-            - q_bits
-        """
         super().__init__(model, config_list)
     def quantize_weight(self, weight, config, **kwargs):
+        weight_bits = get_bits_length(config, 'weight')
         out = weight.tanh()
         out = out / (2 * out.abs().max()) + 0.5
-        out = self.quantize(out, config['q_bits'])
+        out = self.quantize(out, weight_bits)
         out = 2 * out -1
         return out
     def quantize(self, input_ri, q_bits):
         scale = pow(2, q_bits)-1
         output = torch.round(input_ri*scale)/scale
         return output
\ No newline at end of file
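(The weight path above squashes weights with tanh, rescales them into [0, 1], rounds onto a grid of 2**q_bits - 1 uniform steps, and maps back to [-1, 1]. A standalone PyTorch sketch of the same arithmetic; the bit width and input tensor are arbitrary examples.)

import torch

def dorefa_quantize_weight(weight: torch.Tensor, q_bits: int) -> torch.Tensor:
    # Squash to (-1, 1), then shift/scale into [0, 1].
    out = weight.tanh()
    out = out / (2 * out.abs().max()) + 0.5
    # Round onto 2**q_bits - 1 uniform steps.
    scale = 2 ** q_bits - 1
    out = torch.round(out * scale) / scale
    # Map back to [-1, 1].
    return 2 * out - 1

w = torch.randn(4, 4)
print(dorefa_quantize_weight(w, q_bits=2))  # values land on a small discrete grid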
...
@@ -250,6 +250,10 @@ class Quantizer(Compressor):
     Base quantizer for pytorch quantizer
     """
+    def __init__(self, model, config_list):
+        super().__init__(model, config_list)
+        self.quant_grad = QuantGrad
+
     def quantize_weight(self, weight, config, op, op_type, op_name):
         """
         quantize should overload this method to quantize weight.
@@ -262,7 +266,7 @@ class Quantizer(Compressor):
         config : dict
             the configuration for weight quantization
         """
-        raise NotImplementedError("Quantizer must overload quantize_weight()")
+        raise NotImplementedError('Quantizer must overload quantize_weight()')
     def quantize_output(self, output, config, op, op_type, op_name):
         """
@@ -276,7 +280,7 @@ class Quantizer(Compressor):
         config : dict
             the configuration for output quantization
         """
-        raise NotImplementedError("Quantizer must overload quantize_output()")
+        raise NotImplementedError('Quantizer must overload quantize_output()')
     def quantize_input(self, *inputs, config, op, op_type, op_name):
         """
@@ -290,7 +294,7 @@ class Quantizer(Compressor):
         config : dict
             the configuration for inputs quantization
         """
-        raise NotImplementedError("Quantizer must overload quantize_input()")
+        raise NotImplementedError('Quantizer must overload quantize_input()')
     def _instrument_layer(self, layer, config):
@@ -305,62 +309,93 @@ class Quantizer(Compressor):
             the configuration for quantization
         """
         assert layer._forward is None, 'Each model can only be compressed once'
-        assert "quant_types" in config, 'must provide quant_types in config'
-        assert isinstance(config["quant_types"], list), 'quant_types must be list type'
-        assert "quant_bits" in config, 'must provide quant_bits in config'
-        assert isinstance(config["quant_bits"], int) or isinstance(config["quant_bits"], dict), 'quant_bits must be dict type or int type'
-        if isinstance(config["quant_bits"], dict):
-            for quant_type in config["quant_types"]:
-                assert quant_type in config["quant_bits"], 'bits length for %s must be specified in quant_bits dict' % quant_type
-        if 'weight' in config["quant_types"]:
+        assert 'quant_types' in config, 'must provide quant_types in config'
+        assert isinstance(config['quant_types'], list), 'quant_types must be list type'
+        assert 'quant_bits' in config, 'must provide quant_bits in config'
+        assert isinstance(config['quant_bits'], int) or isinstance(config['quant_bits'], dict), 'quant_bits must be dict type or int type'
+        if isinstance(config['quant_bits'], dict):
+            for quant_type in config['quant_types']:
+                assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type
+        if 'weight' in config['quant_types']:
             if not _check_weight(layer.module):
                 _logger.warning('Module %s does not have parameter "weight"', layer.name)
+            else:
+                # old_weight is used to store origin weight and weight is used to store quantized weight
+                # the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf
+                # if weight is leaf , then old_weight can not be updated.
+                layer.module.register_parameter('old_weight', torch.nn.Parameter(layer.module.weight))
+                delattr(layer.module, 'weight')
+                layer.module.register_buffer('weight', layer.module.old_weight)
         layer._forward = layer.module.forward
         def new_forward(*inputs):
-            if 'input' in config["quant_types"]:
-                inputs = straight_through_quantize_input.apply(inputs, self, config, layer)
-            if 'weight' in config["quant_types"] and _check_weight(layer.module):
-                weight = layer.module.weight.data
-                new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
-                layer.module.weight.data = new_weight
-                result = layer._forward(*inputs)
-                layer.module.weight.data = weight
+            if 'input' in config['quant_types']:
+                inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer)
+            if 'weight' in config['quant_types'] and _check_weight(layer.module):
+                new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer)
+                layer.module.weight = new_weight
+                result = layer._forward(*inputs)
             else:
                 result = layer._forward(*inputs)
-            if 'output' in config["quant_types"]:
-                result = straight_through_quantize_output.apply(result, self, config, layer)
+            if 'output' in config['quant_types']:
+                result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer)
             return result
         layer.module.forward = new_forward
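(The old_weight trick above keeps the full-precision tensor as the trainable leaf while weight becomes a re-derived buffer, so gradients flow through the quantization function into old_weight. A toy illustration on a bare nn.Linear, outside NNI's wrapper machinery; the 0.5 scaling stands in for a real quantization function.)

import torch
import torch.nn as nn

layer = nn.Linear(3, 2)
# Keep the full-precision copy as the trainable parameter ...
layer.register_parameter('old_weight', nn.Parameter(layer.weight))
# ... and make `weight` a buffer so it is no longer an autograd leaf.
delattr(layer, 'weight')
layer.register_buffer('weight', layer.old_weight)

# Stand-in for the quantized recomputation done in new_forward (a real
# quantizer routes this through QuantGrad so round() gets an STE gradient).
layer.weight = layer.old_weight * 0.5
loss = layer(torch.randn(1, 3)).sum()
loss.backward()
print(layer.old_weight.grad)  # non-None: gradients reach the full-precision leaf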
+class QuantType:
+    """
+    Enum class for quantization type.
+    """
+    QUANT_INPUT = 0
+    QUANT_WEIGHT = 1
+    QUANT_OUTPUT = 2
+
-class straight_through_quantize_output(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, output, quantizer, config, layer):
-        return quantizer.quantize_output(output, config, op=layer.module, op_type=layer.type, op_name=layer.name)
-    @staticmethod
-    def backward(ctx, grad_output):
-        # Straight-through estimator
-        return grad_output, None, None, None
-class straight_through_quantize_input(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, inputs, quantizer, config, layer):
-        return quantizer.quantize_input(inputs, config, op=layer.module, op_type=layer.type, op_name=layer.name)
-    @staticmethod
-    def backward(ctx, grad_output):
-        # Straight-through estimator
-        return grad_output, None, None, None
+class QuantGrad(torch.autograd.Function):
+    """
+    Base class for overriding backward function of quantization operation.
+    """
+    @staticmethod
+    def quant_backward(tensor, grad_output, quant_type):
+        """
+        This method should be overrided by subclass to provide customized backward function,
+        default implementation is Straight-Through Estimator
+        Parameters
+        ----------
+        tensor : Tensor
+            input of quantization operation
+        grad_output : Tensor
+            gradient of the output of quantization operation
+        quant_type : QuantType
+            the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
+            you can define different behavior for different types.
+        Returns
+        -------
+        tensor
+            gradient of the input of quantization operation
+        """
+        return grad_output
+    @staticmethod
+    def forward(ctx, tensor, quant_type, quant_func, config, layer):
+        ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
+        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+    @classmethod
+    def backward(cls, ctx, grad_output):
+        tensor, quant_type = ctx.saved_variables
+        output = cls.quant_backward(tensor, grad_output, quant_type)
+        return output, None, None, None, None
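(quant_backward is the extension point: a quantizer subclass overrides it and assigns its QuantGrad subclass to self.quant_grad. A self-contained sketch of the same idea, a round with a clipped straight-through gradient; the [-1, 1] clipping range is an arbitrary example, not from this diff.)

import torch

class ClippedRoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor):
        ctx.save_for_backward(tensor)
        return torch.round(tensor)
    @staticmethod
    def backward(ctx, grad_output):
        tensor, = ctx.saved_tensors
        # Pass the gradient through only where the saved input was in [-1, 1].
        return grad_output * (tensor.abs() <= 1).type_as(grad_output)

x = torch.tensor([0.3, 2.5, -0.7], requires_grad=True)
ClippedRoundSTE.apply(x).sum().backward()
print(x.grad)  # tensor([1., 0., 1.])

# With the API above, the equivalent is a QuantGrad subclass overriding
# quant_backward, installed via `self.quant_grad = MyQuantGrad` in a Quantizer.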
 def _check_weight(module):
     try:
-        return isinstance(module.weight, torch.nn.Parameter) and isinstance(module.weight.data, torch.Tensor)
+        return isinstance(module.weight.data, torch.Tensor)
     except AttributeError:
         return False
...
@@ -16,6 +16,7 @@ _trial_env_var_names = [
 ]
 _dispatcher_env_var_names = [
+    'SDK_PROCESS',
     'NNI_MODE',
     'NNI_CHECKPOINT_DIRECTORY',
     'NNI_LOG_DIRECTORY',
...
...
@@ -11,7 +11,6 @@ import logging
 import hyperopt as hp
 import numpy as np
 from nni.tuner import Tuner
-from nni.nas_utils import rewrite_nas_space
 from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index
 logger = logging.getLogger('hyperopt_AutoML')
...
@@ -226,7 +225,6 @@ class HyperoptTuner(Tuner):
             return hp.anneal.suggest
         raise RuntimeError('Not support tuner algorithm in hyperopt.')
-    @rewrite_nas_space
     def update_search_space(self, search_space):
         """
         Update search space definition in tuner by search_space in parameters.
...
...
@@ -4,6 +4,9 @@
 import logging
 import os
+import torch
+import torch.nn as nn
 _logger = logging.getLogger(__name__)
...
@@ -44,11 +47,28 @@ class LRSchedulerCallback(Callback):
 class ArchitectureCheckpoint(Callback):
-    def __init__(self, checkpoint_dir, every="epoch"):
+    def __init__(self, checkpoint_dir):
         super().__init__()
-        assert every == "epoch"
         self.checkpoint_dir = checkpoint_dir
         os.makedirs(self.checkpoint_dir, exist_ok=True)
     def on_epoch_end(self, epoch):
-        self.trainer.export(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch)))
+        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))
+        _logger.info("Saving architecture to %s", dest_path)
+        self.trainer.export(dest_path)
+class ModelCheckpoint(Callback):
+    def __init__(self, checkpoint_dir):
+        super().__init__()
+        self.checkpoint_dir = checkpoint_dir
+        os.makedirs(self.checkpoint_dir, exist_ok=True)
+    def on_epoch_end(self, epoch):
+        if isinstance(self.model, nn.DataParallel):
+            state_dict = self.model.module.state_dict()
+        else:
+            state_dict = self.model.state_dict()
+        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.pth.tar".format(epoch))
+        _logger.info("Saving model to %s", dest_path)
+        torch.save(state_dict, dest_path)
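(Unwrapping nn.DataParallel before saving matters because the wrapper prefixes every key with "module.", so saving model.module.state_dict() keeps checkpoints loadable into the plain model. A quick demonstration; this toy model is CPU-only, whereas DataParallel is normally used with GPUs.)

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)

print(list(wrapped.state_dict()))         # ['module.weight', 'module.bias']
print(list(wrapped.module.state_dict()))  # ['weight', 'bias'], which is what ModelCheckpoint saves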
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import get_and_apply_next_architecture