Unverified commit 39085789 authored by chicm-ms, committed by GitHub

Multi-phase training service (#148)

* Dev enas  - multi-phase hyper parameters support (#96)

* Multi-phase support

* Updates

* Updates

* updates

* updates

* updates

* Merge master to dev-enas (#117)

* Multi-phase support

* update document (#92)

* Edit readme.md

* updated a word

* Update GetStarted.md

* Update GetStarted.md

* refactor readme, getstarted and write your trial md.

* Update README.md

* Update WriteYourTrial.md

* Update WriteYourTrial.md

* Update WriteYourTrial.md

* Update WriteYourTrial.md

* Fix nnictl bugs and add new feature (#75)

* fix nnictl bug

* fix nnictl create bug

* add experiment status logic

* add more information for nnictl

* fix Evolution Tuner bug

* refactor code

* fix code in updater.py

* fix nnictl --help

* fix classArgs bug

* update check response.status_code logic

* Updates

* remove Buffer warning (#100)

* update readme in ga_squad

* update readme

* fix typo

* Update README.md

* Update README.md

* Update README.md

* Updates

* updates

* updates

* updates

* Add support for debugging mode

* fix setup.py (#115)

* Add DAG model configuration format for SQuAD example.

* Explain config format for SQuAD QA model.

* Add more detailed introduction about the evolution algorithm.

* Merge master to dev-enas (#118)

* update document (#92)

* Edit readme.md

* updated a word

* Update GetStarted.md

* Update GetStarted.md

* refactor readme, getstarted and write your trial md.

* Update README.md

* Update WriteYourTrial.md

* Update WriteYourTrial.md

* Update WriteYourTrial.md

* Update WriteYourTrial.md

* Fix nnictl bugs and add new feature (#75)

* fix nnictl bug

* fix nnictl create bug

* add experiment status logic

* add more information for nnictl

* fix Evolution Tuner bug

* refactor code

* fix code in updater.py

* fix nnictl --help

* fix classArgs bug

* update check response.status_code logic

* remove Buffer warning (#100)

* update readme in ga_squad

* update readme

* fix typo

* Update README.md

* Update README.md

* Update README.md

* Add support for debugging mode

* fix setup.py (#115)

* Add DAG model configuration format for SQuAD example.

* Explain config format for SQuAD QA model.

* Add more detailed introduction about the evolution algorithm.

* Fix install.sh add add trial log path (#109)

* fix nnictl bug

* fix nnictl create bug

* add experiment status logic

* add more information for nnictl

* fix Evolution Tuner bug

* refactor code

* fix code in updater.py

* fix nnictl --help

* fix classArgs bug

* update check response.status_code logic

* show trial log path

* update document

* fix install.sh

* set default value for maxTrialNum and maxExecDuration

* fix nnictl

* support multiPhase (#127)

* fix nnictl bug

* support multiPhase

* Fix multiphase datastore problem (#125)

* Fix multiphase datastore problem

* updates

* updates

* updates

* updates

* Pull latest code (#2)

* webui logpath and document (#135)

* Add webui document and logpath as a href

* fix tslint

* fix comments by Chengmin

* Pai training service bug fix and enhancement (#136)

* Add NNI installation scripts

* Update pai script, update NNI_out_dir

* Update NNI dir in nni sdk local.py

* Create .nni folder in nni sdk local.py

* Add check before creating .nni folder

* Fix typo for PAI_INSTALL_NNI_SHELL_FORMAT

* Improve annotation (#138)

* Improve annotation

* Minor bugfix

* Selectively install through pip (#139)

Selectively install through pip

* update setup.py

* fix paiTrainingService bugs (#137)

* fix nnictl bug

* add hdfs host validation

* fix bugs

* fix dockerfile

* fix install.sh

* update install.sh

* fix dockerfile

* Set timeout for HDFSUtility exists function

* remove unused TODO

* fix sdk

* add optional for outputDir and dataDir

* refactor dockerfile.base

* Remove unused import in hdfsclientUtility

* Add documentation for NNI PAI mode experiment (#141)

* Add documentation for NNI PAI mode

* Fix typo based on PR comments

* Exit with subprocess return code of trial keeper

* Remove additional exit code

* Fix typo based on PR comments

* update doc for smac tuner (#140)

* Revert "Selectively install through pip (#139)" due to potential pip install issue (#142)

* Revert "Selectively install through pip (#139)"

This reverts commit 1d174836.

* Add exit code of subprocess for trial_keeper

* Update README, add link to PAImode doc

* fix bug (#147)

* Refactor nnictl and add config_pai.yml (#144)

* fix nnictl bug

* add hdfs host validation

* fix bugs

* fix dockerfile

* fix install.sh

* update install.sh

* fix dockerfile

* Set timeout for HDFSUtility exists function

* remove unused TODO

* fix sdk

* add optional for outputDir and dataDir

* refactor dockerfile.base

* Remove unused import in hdfsclientUtility

* add config_pai.yml

* refactor nnictl create logic and add colorful print

* fix nnictl stop logic

* add annotation for config_pai.yml

* add document for start experiment

* fix config.yml

* fix document

* Fix trial keeper wrongly exit issue (#152)

* Fix trial keeper bug, use actual exitcode to exit rather than 1

* Fix bug of table sort (#145)

* Update doc for PAIMode and v0.2 release notes (#153)

* Update v0.2 documentation regards to release note and PAI training service

* Update document to describe NNI docker image

* Bug fix for SQuAD example tuner. (#134)

* Update Makefile (#151)

* test

* update setup.py

* update Makefile and install.sh

* revert setup.py

* change color

* update doc

* update doc

* fix auto-completion's extra space

* update Makefile

* update webui

* Update doc image (#163)

* update doc

* trivial

* trivial

* trivial

* trivial

* trivial

* trivial

* update image

* update image size

* Update ga squad (#104)

* update readme in ga_squad

* update readme

* fix typo

* Update README.md

* Update README.md

* Update README.md

* update readme

* sklearn examples (#169)

* fix nnictl bug

* fix install.sh

* add sklearn-regression example

* add sklearn classification

* update sklearn

* update example

* remove additional code

* Update batch tuner (#158)

* update readme in ga_squad

* update readme

* fix typo

* Update README.md

* Update README.md

* Update README.md

* update readme

* update batch tuner

* Quickly fix cascading search space bug in tuner (#156)

* update readme in ga_squad

* update readme

* fix typo

* Update README.md

* Update README.md

* Update README.md

* update readme

* quickly fix cascading searchspace bug in tuner

* Add iterative search space example (#119)

* update readme in ga_squad

* update readme

* fix typo

* Update README.md

* Update README.md

* Update README.md

* update readme

* add iterative search space example

* update

* update readme

* change name

* updates

* updates

* Updates CI

* updates
parent 6ef65117
@@ -11,7 +11,7 @@ before_install:
   - sudo sh -c 'PATH=/usr/local/node/bin:$PATH yarn global add serve'
 install:
   - make
-  - make install
+  - make easy-install
   - export PATH=$HOME/.nni/bin:$PATH
 before_script:
   - cd test/naive
...
@@ -22,8 +22,8 @@
 import { ExperimentProfile, TrialJobStatistics } from './manager';
 import { TrialJobDetail, TrialJobStatus } from './trainingService';

-type TrialJobEvent = TrialJobStatus | 'USER_TO_CANCEL' | 'ADD_CUSTOMIZED';
-type MetricType = 'PERIODICAL' | 'FINAL' | 'CUSTOM';
+type TrialJobEvent = TrialJobStatus | 'USER_TO_CANCEL' | 'ADD_CUSTOMIZED' | 'ADD_HYPERPARAMETER';
+type MetricType = 'PERIODICAL' | 'FINAL' | 'CUSTOM' | 'REQUEST_PARAMETER';

 interface ExperimentProfileRecord {
     readonly timestamp: number;
@@ -62,7 +62,7 @@ interface TrialJobInfo {
     status: TrialJobStatus;
     startTime?: number;
     endTime?: number;
-    hyperParameters?: string;
+    hyperParameters?: string[];
     logPath?: string;
     finalMetricData?: string;
     stderrPath?: string;
...
@@ -31,6 +31,7 @@ interface ExperimentParams {
     maxExecDuration: number; //seconds
     maxTrialNum: number;
     searchSpace: string;
+    multiPhase?: boolean;
     tuner: {
         className: string;
         builtinTunerName?: string;
...
@@ -37,11 +37,16 @@ interface JobApplicationForm {
     readonly jobType: JobType;
 }

+interface HyperParameters {
+    readonly value: string;
+    readonly index: number;
+}
+
 /**
  * define TrialJobApplicationForm
  */
 interface TrialJobApplicationForm extends JobApplicationForm {
-    readonly hyperParameters: string;
+    readonly hyperParameters: HyperParameters;
 }

 /**
@@ -116,6 +121,6 @@ abstract class TrainingService {
 export {
     TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm,
-    TrainingServiceMetadata, TrialJobDetail, TrialJobMetric,
+    TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters,
     HostJobApplicationForm, JobApplicationForm, JobType
 };
...
@@ -158,8 +158,11 @@ function parseArg(names: string[]): string {
  * @param assessor: similiar as tuner
  *
  */
-function getMsgDispatcherCommand(tuner: any, assessor: any): string {
+function getMsgDispatcherCommand(tuner: any, assessor: any, multiPhase: boolean = false): string {
     let command: string = `python3 -m nni --tuner_class_name ${tuner.className}`;
+    if (multiPhase) {
+        command += ' --multi_phase';
+    }
     if (tuner.classArgs !== undefined) {
         command += ` --tuner_args ${JSON.stringify(JSON.stringify(tuner.classArgs))}`;
...
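The change to getMsgDispatcherCommand is easiest to see as a plain string-building rule: the dispatcher command gains a `--multi_phase` flag when the experiment enables multiPhase. Below is a rough Python rendition of that rule (not the actual TypeScript helper; tuner/assessor arguments are omitted for brevity):

```python
def get_msg_dispatcher_command(tuner_class_name: str, multi_phase: bool = False) -> str:
    """Sketch of getMsgDispatcherCommand: build the dispatcher launch
    command, appending --multi_phase only when the experiment's
    multiPhase option is set."""
    command = f"python3 -m nni --tuner_class_name {tuner_class_name}"
    if multi_phase:
        command += " --multi_phase"
    return command
```

NNIManager passes `expParams.multiPhase` straight into this helper, so a single config field flips the dispatcher into multi-phase mode.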
@@ -27,6 +27,7 @@ const TRIAL_END = 'EN';
 const TERMINATE = 'TE';
 const NEW_TRIAL_JOB = 'TR';
+const SEND_TRIAL_JOB_PARAMETER = 'SP';
 const NO_MORE_TRIAL_JOBS = 'NO';
 const KILL_TRIAL_JOB = 'KI';
@@ -39,6 +40,7 @@ const TUNER_COMMANDS: Set<string> = new Set([
     TERMINATE,
     NEW_TRIAL_JOB,
+    SEND_TRIAL_JOB_PARAMETER,
     NO_MORE_TRIAL_JOBS
 ]);
@@ -63,5 +65,6 @@ export {
     NO_MORE_TRIAL_JOBS,
     KILL_TRIAL_JOB,
     TUNER_COMMANDS,
-    ASSESSOR_COMMANDS
+    ASSESSOR_COMMANDS,
+    SEND_TRIAL_JOB_PARAMETER
 };
...
@@ -26,6 +26,7 @@ import * as component from '../common/component';
 import { Database, DataStore, MetricData, MetricDataRecord, MetricType,
     TrialJobEvent, TrialJobEventRecord, TrialJobInfo } from '../common/datastore';
 import { isNewExperiment } from '../common/experimentStartupInfo';
+import { getExperimentId } from '../common/experimentStartupInfo';
 import { getLogger, Logger } from '../common/log';
 import { ExperimentProfile, TrialJobStatistics } from '../common/manager';
 import { TrialJobStatus } from '../common/trainingService';
@@ -35,6 +36,7 @@ class NNIDataStore implements DataStore {
     private db: Database = component.get(Database);
     private log: Logger = getLogger();
     private initTask!: Deferred<void>;
+    private multiPhase: boolean | undefined;

     public init(): Promise<void> {
         if (this.initTask !== undefined) {
@@ -112,13 +114,19 @@
     }

     public async getTrialJob(trialJobId: string): Promise<TrialJobInfo> {
-        const trialJobs = await this.queryTrialJobs(undefined, trialJobId);
+        const trialJobs: TrialJobInfo[] = await this.queryTrialJobs(undefined, trialJobId);
         return trialJobs[0];
     }

     public async storeMetricData(trialJobId: string, data: string): Promise<void> {
-        const metrics = JSON.parse(data) as MetricData;
+        const metrics: MetricData = JSON.parse(data);
+        // REQUEST_PARAMETER is used to request new parameters for multiphase trial job,
+        // it is not metrics, so it is skipped here.
+        if (metrics.type === 'REQUEST_PARAMETER') {
+            return;
+        }
         assert(trialJobId === metrics.trial_job_id);
         await this.db.storeMetricData(trialJobId, JSON.stringify({
             trialJobId: metrics.trial_job_id,
@@ -160,25 +168,56 @@
     private async getFinalMetricData(trialJobId: string): Promise<any> {
         const metrics: MetricDataRecord[] = await this.getMetricData(trialJobId, 'FINAL');
-        if (metrics.length > 1) {
-            this.log.error(`Found multiple final results for trial job: ${trialJobId}`);
+
+        const multiPhase: boolean = await this.isMultiPhase();
+
+        if (metrics.length > 1 && !multiPhase) {
+            this.log.error(`Found multiple FINAL results for trial job ${trialJobId}`);
         }
-        return metrics[0];
+
+        return metrics[metrics.length - 1];
+    }
+
+    private async isMultiPhase(): Promise<boolean> {
+        if (this.multiPhase === undefined) {
+            this.multiPhase = (await this.getExperimentProfile(getExperimentId())).params.multiPhase;
+        }
+
+        if (this.multiPhase !== undefined) {
+            return this.multiPhase;
+        } else {
+            return false;
+        }
     }

-    private getJobStatusByLatestEvent(event: TrialJobEvent): TrialJobStatus {
+    private getJobStatusByLatestEvent(oldStatus: TrialJobStatus, event: TrialJobEvent): TrialJobStatus {
         switch (event) {
             case 'USER_TO_CANCEL':
                 return 'USER_CANCELED';
             case 'ADD_CUSTOMIZED':
                 return 'WAITING';
+            case 'ADD_HYPERPARAMETER':
+                return oldStatus;
             default:
         }

         return <TrialJobStatus>event;
     }

+    private mergeHyperParameters(hyperParamList: string[], newParamStr: string): string[] {
+        const mergedHyperParams: any[] = [];
+        const newParam: any = JSON.parse(newParamStr);
+        for (const hyperParamStr of hyperParamList) {
+            const hyperParam: any = JSON.parse(hyperParamStr);
+            mergedHyperParams.push(hyperParam);
+        }
+        if (mergedHyperParams.filter((value: any) => value.parameter_index === newParam.parameter_index).length <= 0) {
+            mergedHyperParams.push(newParam);
+        }
+
+        return mergedHyperParams.map<string>((value: any) => { return JSON.stringify(value); });
+    }
+
     private getTrialJobsByReplayEvents(trialJobEvents: TrialJobEventRecord[]): Map<string, TrialJobInfo> {
         const map: Map<string, TrialJobInfo> = new Map();
         // assume data is stored by time ASC order
@@ -192,7 +231,8 @@
             } else {
                 jobInfo = {
                     id: record.trialJobId,
-                    status: this.getJobStatusByLatestEvent(record.event)
+                    status: this.getJobStatusByLatestEvent('UNKNOWN', record.event),
+                    hyperParameters: []
                 };
             }
             if (!jobInfo) {
@@ -221,9 +261,13 @@
                 }
                 default:
             }
-            jobInfo.status = this.getJobStatusByLatestEvent(record.event);
+            jobInfo.status = this.getJobStatusByLatestEvent(jobInfo.status, record.event);
             if (record.data !== undefined && record.data.trim().length > 0) {
-                jobInfo.hyperParameters = record.data;
+                if (jobInfo.hyperParameters !== undefined) {
+                    jobInfo.hyperParameters = this.mergeHyperParameters(jobInfo.hyperParameters, record.data);
+                } else {
+                    assert(false, 'jobInfo.hyperParameters is undefined');
+                }
             }
             map.set(record.trialJobId, jobInfo);
         }
...
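The core datastore change is mergeHyperParameters: a trial may now accumulate several hyperparameter sets, one per phase, and a replayed event is appended only if no stored entry already carries the same `parameter_index`. A rough Python sketch of that dedup rule (illustrative, not the shipped TypeScript):

```python
import json

def merge_hyper_parameters(hyper_param_list, new_param_str):
    """Sketch of NNIDataStore.mergeHyperParameters: parse the stored
    JSON strings, append the new parameter set only when its
    parameter_index is not already present, and re-serialize."""
    merged = [json.loads(s) for s in hyper_param_list]
    new_param = json.loads(new_param_str)
    if not any(p["parameter_index"] == new_param["parameter_index"] for p in merged):
        merged.append(new_param)
    return [json.dumps(p) for p in merged]
```

Because event replay may deliver the same phase twice, keying on `parameter_index` keeps the per-trial hyperparameter list idempotent.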
@@ -37,7 +37,7 @@ import {
 import { delay , getLogDir, getMsgDispatcherCommand} from '../common/utils';
 import {
     ADD_CUSTOMIZED_TRIAL_JOB, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, REPORT_METRIC_DATA,
-    REQUEST_TRIAL_JOBS, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE
+    REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE
 } from './commands';
 import { createDispatcherInterface, IpcInterface } from './ipcInterface';
 import { TrialJobMaintainerEvent, TrialJobs } from './trialJobs';
@@ -116,7 +116,7 @@
         await this.storeExperimentProfile();
         this.log.debug('Setup tuner...');
-        const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor);
+        const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
         console.log(`dispatcher command: ${dispatcherCommand}`);
         this.setupTuner(
             //expParams.tuner.tunerCommand,
@@ -140,7 +140,7 @@
         this.experimentProfile = await this.dataStore.getExperimentProfile(experimentId);
         const expParams: ExperimentParams = this.experimentProfile.params;
-        const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor);
+        const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
         console.log(`dispatcher command: ${dispatcherCommand}`);
         this.setupTuner(
             dispatcherCommand,
@@ -462,7 +462,10 @@
                 this.currSubmittedTrialNum++;
                 const trialJobAppForm: TrialJobApplicationForm = {
                     jobType: 'TRIAL',
-                    hyperParameters: content
+                    hyperParameters: {
+                        value: content,
+                        index: 0
+                    }
                 };
                 const trialJobDetail: TrialJobDetail = await this.trainingService.submitTrialJob(trialJobAppForm);
                 this.trialJobsMaintainer.setTrialJob(trialJobDetail.id, Object.assign({}, trialJobDetail));
@@ -474,6 +477,22 @@
                     }
                 }
                 break;
+            case SEND_TRIAL_JOB_PARAMETER:
+                const tunerCommand: any = JSON.parse(content);
+                assert(tunerCommand.parameter_index >= 0);
+                assert(tunerCommand.trial_job_id !== undefined);
+
+                const trialJobForm: TrialJobApplicationForm = {
+                    jobType: 'TRIAL',
+                    hyperParameters: {
+                        value: content,
+                        index: tunerCommand.parameter_index
+                    }
+                };
+                await this.trainingService.updateTrialJob(tunerCommand.trial_job_id, trialJobForm);
+                await this.dataStore.storeTrialJobEvent(
+                    'ADD_HYPERPARAMETER', tunerCommand.trial_job_id, content, undefined);
+                break;
             case NO_MORE_TRIAL_JOBS:
                 this.trialJobsMaintainer.setNoMoreTrials();
                 break;
...
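The new SEND_TRIAL_JOB_PARAMETER branch turns a tuner command (JSON carrying `trial_job_id` and `parameter_index`) into a trial-job application form whose hyperparameter payload is the raw command content. A small Python sketch of that translation, using plain dicts in place of the TypeScript interfaces (the shapes are taken from the diff; the helper name is made up for illustration):

```python
import json

def build_trial_job_form(content: str) -> dict:
    """Sketch of the SEND_TRIAL_JOB_PARAMETER case in NNIManager:
    validate the tuner command, then wrap the raw content as the
    hyperparameters for the phase identified by parameter_index."""
    cmd = json.loads(content)
    assert cmd["parameter_index"] >= 0
    assert "trial_job_id" in cmd
    return {
        "jobType": "TRIAL",
        "hyperParameters": {"value": content, "index": cmd["parameter_index"]},
    }
```

The manager then hands this form to `trainingService.updateTrialJob` and records an ADD_HYPERPARAMETER event, which is what the datastore replays later.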
@@ -57,6 +57,7 @@ export namespace ValidationSchemas {
             trialConcurrency: joi.number().min(0).required(),
             searchSpace: joi.string().required(),
             maxExecDuration: joi.number().min(0).required(),
+            multiPhase: joi.boolean(),
             tuner: joi.object({
                 builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner'),
                 codeDir: joi.string(),
...
@@ -30,10 +30,11 @@ import { getLogger, Logger } from '../../common/log';
 import { TrialConfig } from '../common/trialConfig';
 import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
 import {
-    HostJobApplicationForm, JobApplicationForm, TrainingService, TrialJobApplicationForm,
+    HostJobApplicationForm, JobApplicationForm, HyperParameters, TrainingService, TrialJobApplicationForm,
     TrialJobDetail, TrialJobMetric, TrialJobStatus
 } from '../../common/trainingService';
 import { delay, getExperimentRootDir, uniqueString } from '../../common/utils';
+import { file } from 'tmp';

 const tkill = require('tree-kill');
@@ -210,8 +211,18 @@
      * @param trialJobId trial job id
      * @param form job application form
      */
-    public updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
-        throw new MethodNotImplementedError();
+    public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
+        const trialJobDetail: undefined | TrialJobDetail = this.jobMap.get(trialJobId);
+        if (trialJobDetail === undefined) {
+            throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
+        }
+        if (form.jobType === 'TRIAL') {
+            await this.writeParameterFile(trialJobDetail.workingDirectory, (<TrialJobApplicationForm>form).hyperParameters);
+        } else {
+            throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
+        }
+
+        return trialJobDetail;
     }

     /**
@@ -332,10 +343,7 @@
         await cpp.exec(`mkdir -p ${path.join(trialJobDetail.workingDirectory, '.nni')}`);
         await cpp.exec(`touch ${path.join(trialJobDetail.workingDirectory, '.nni', 'metrics')}`);
         await fs.promises.writeFile(path.join(trialJobDetail.workingDirectory, 'run.sh'), runScriptLines.join('\n'), { encoding: 'utf8' });
-        await fs.promises.writeFile(
-            path.join(trialJobDetail.workingDirectory, 'parameter.cfg'),
-            (<TrialJobApplicationForm>trialJobDetail.form).hyperParameters,
-            { encoding: 'utf8' });
+        await this.writeParameterFile(trialJobDetail.workingDirectory, (<TrialJobApplicationForm>trialJobDetail.form).hyperParameters);
         const process: cp.ChildProcess = cp.exec(`bash ${path.join(trialJobDetail.workingDirectory, 'run.sh')}`);

         this.setTrialJobStatus(trialJobDetail, 'RUNNING');
@@ -402,6 +410,11 @@
             }
         }
     }

+    private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
+        const filepath: string = path.join(directory, `parameter_${hyperParameters.index}.cfg`);
+        await fs.promises.writeFile(filepath, hyperParameters.value, { encoding: 'utf8' });
+    }
 }

 export { LocalTrainingService };
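Both training services now write one parameter file per phase instead of a single `parameter.cfg`: phase 0 goes to `parameter_0.cfg`, phase 1 to `parameter_1.cfg`, and so on, so a new phase never overwrites the parameters a running trial may still be reading. A minimal Python sketch of the local service's writeParameterFile (illustrative only):

```python
import os

def write_parameter_file(directory: str, index: int, value: str) -> str:
    """Sketch of LocalTrainingService.writeParameterFile: each phase
    gets its own parameter_<index>.cfg in the trial's working
    directory, and the serialized hyperparameters are written to it."""
    filepath = os.path.join(directory, f"parameter_{index}.cfg")
    with open(filepath, "w", encoding="utf8") as f:
        f.write(value)
    return filepath
```

The remote service follows the same naming scheme, writing the file locally first and then copying it to the remote trial working folder over SSH.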
@@ -34,7 +34,7 @@ import { getExperimentId } from '../../common/experimentStartupInfo';
 import { getLogger, Logger } from '../../common/log';
 import { ObservableTimer } from '../../common/observableTimer';
 import {
-    HostJobApplicationForm, JobApplicationForm, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
+    HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
 } from '../../common/trainingService';
 import { delay, getExperimentRootDir, uniqueString } from '../../common/utils';
 import { GPUSummary } from '../common/gpuData';
@@ -198,8 +198,24 @@
      * @param trialJobId trial job id
      * @param form job application form
      */
-    public updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
-        throw new MethodNotImplementedError();
+    public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
+        this.log.info(`updateTrialJob: form: ${JSON.stringify(form)}`);
+        const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
+        if (trialJobDetail === undefined) {
+            throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
+        }
+        if (form.jobType === 'TRIAL') {
+            const rmMeta: RemoteMachineMeta | undefined = (<RemoteMachineTrialJobDetail>trialJobDetail).rmMeta;
+            if (rmMeta !== undefined) {
+                await this.writeParameterFile(trialJobId, (<TrialJobApplicationForm>form).hyperParameters, rmMeta);
+            } else {
+                throw new Error(`updateTrialJob failed: ${trialJobId} rmMeta not found`);
+            }
+        } else {
+            throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
+        }
+
+        return trialJobDetail;
     }

     /**
@@ -442,15 +458,13 @@
         //create tmp trial working folder locally.
         await cpp.exec(`mkdir -p ${trialLocalTempFolder}`);

-        // Write file content ( run.sh and parameter.cfg ) to local tmp files
+        // Write file content ( run.sh and parameter_0.cfg ) to local tmp files
         await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run.sh'), runScriptContent, { encoding: 'utf8' });
-        await fs.promises.writeFile(path.join(trialLocalTempFolder, 'parameter.cfg'), form.hyperParameters, { encoding: 'utf8' });

         // Copy local tmp files to remote machine
         await SSHClientUtility.copyFileToRemote(
             path.join(trialLocalTempFolder, 'run.sh'), path.join(trialWorkingFolder, 'run.sh'), sshClient);
-        await SSHClientUtility.copyFileToRemote(
-            path.join(trialLocalTempFolder, 'parameter.cfg'), path.join(trialWorkingFolder, 'parameter.cfg'), sshClient);
+        await this.writeParameterFile(trialJobId, form.hyperParameters, rmScheduleInfo.rmMeta);

         // Copy files in codeDir to remote working directory
         await SSHClientUtility.copyDirectoryToRemote(this.trialConfig.codeDir, trialWorkingFolder, sshClient);
@@ -562,6 +576,22 @@
         return jobpidPath;
     }

+    private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters, rmMeta: RemoteMachineMeta): Promise<void> {
+        const sshClient: Client | undefined = this.machineSSHClientMap.get(rmMeta);
+        if (sshClient === undefined) {
+            throw new Error('sshClient is undefined.');
+        }
+
+        const trialWorkingFolder: string = path.join(this.remoteExpRootDir, 'trials', trialJobId);
+        const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials-local', trialJobId);
+
+        const fileName: string = `parameter_${hyperParameters.index}.cfg`;
+        const localFilepath: string = path.join(trialLocalTempFolder, fileName);
+        await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' });
+
+        await SSHClientUtility.copyFileToRemote(localFilepath, path.join(trialWorkingFolder, fileName), sshClient);
+    }
 }

 export { RemoteMachineTrainingService };
...@@ -100,7 +100,10 @@ describe('Unit Test for RemoteMachineTrainingService', () => {
TrialConfigMetadataKey.TRIAL_CONFIG, `{"command":"sleep 1h && echo ","codeDir":"${localCodeDir}","gpuNum":1}`);
const form: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: {
value: 'mock hyperparameters',
index: 0
}
};
const trialJob = await remoteMachineTrainingService.submitTrialJob(form);
...@@ -135,7 +138,10 @@ describe('Unit Test for RemoteMachineTrainingService', () => {
// submit job
const form: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: {
value: 'mock hyperparameters',
index: 0
}
};
const jobDetail: TrialJobDetail = await remoteMachineTrainingService.submitTrialJob(form);
// Add metrics listeners
...
...@@ -29,7 +29,7 @@ import importlib
from .constants import ModuleName, ClassName, ClassArgs
from nni.msg_dispatcher import MsgDispatcher
from nni.multi_phase.multi_phase_dispatcher import MultiPhaseMsgDispatcher
logger = logging.getLogger('nni.main')
logger.debug('START')
...@@ -90,6 +90,7 @@ def parse_args():
help='Assessor directory')
parser.add_argument('--assessor_class_filename', type=str, required=False,
help='Assessor class file path')
parser.add_argument('--multi_phase', action='store_true')
flags, _ = parser.parse_known_args()
return flags
...@@ -132,7 +133,10 @@ def main():
if assessor is None:
raise AssertionError('Failed to create Assessor instance')
if args.multi_phase:
dispatcher = MultiPhaseMsgDispatcher(tuner, assessor)
else:
dispatcher = MsgDispatcher(tuner, assessor)
try:
dispatcher.run()
...
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import logging
from collections import defaultdict
import json_tricks
from nni.protocol import CommandType, send
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.assessor import AssessResult
_logger = logging.getLogger(__name__)
# Assessor global variables
_trial_history = defaultdict(dict)
'''key: trial job ID; value: intermediate results, mapping from sequence number to data'''
_ended_trials = set()
'''trial_job_id of all ended trials.
We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''
def _sort_history(history):
ret = [ ]
for i, _ in enumerate(history):
if i in history:
ret.append(history[i])
else:
break
return ret
# Tuner global variables
_next_parameter_id = 0
_trial_params = {}
'''key: trial job ID; value: parameters'''
_customized_parameter_ids = set()
def _create_parameter_id():
global _next_parameter_id # pylint: disable=global-statement
_next_parameter_id += 1
return _next_parameter_id - 1
def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
_trial_params[parameter_id] = params
ret = {
'parameter_id': parameter_id,
'parameter_source': 'customized' if customized else 'algorithm',
'parameters': params
}
if trial_job_id is not None:
ret['trial_job_id'] = trial_job_id
if parameter_index is not None:
ret['parameter_index'] = parameter_index
else:
ret['parameter_index'] = 0
return json_tricks.dumps(ret)
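`_pack_parameter` builds the JSON payload sent back to NNI manager; for multi-phase it additionally carries `trial_job_id` and `parameter_index` so a parameter set can be routed to the trial that requested it. A standalone sketch of the payload shape (plain `json` here instead of `json_tricks`, and without the `_trial_params` bookkeeping):

```python
import json

def pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
    ret = {
        'parameter_id': parameter_id,
        'parameter_source': 'customized' if customized else 'algorithm',
        'parameters': params,
    }
    if trial_job_id is not None:
        ret['trial_job_id'] = trial_job_id
    # The first phase defaults to parameter_index 0.
    ret['parameter_index'] = parameter_index if parameter_index is not None else 0
    return json.dumps(ret)

payload = json.loads(pack_parameter(7, {'lr': 0.01}, trial_job_id='abc', parameter_index=2))
print(payload['trial_job_id'], payload['parameter_index'])  # abc 2
```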
class MultiPhaseMsgDispatcher(MsgDispatcherBase):
def __init__(self, tuner, assessor=None):
super().__init__()
self.tuner = tuner
self.assessor = assessor
if assessor is None:
_logger.debug('Assessor is not configured')
def load_checkpoint(self):
self.tuner.load_checkpoint()
if self.assessor is not None:
self.assessor.load_checkpoint()
def save_checkpoint(self):
self.tuner.save_checkpoint()
if self.assessor is not None:
self.assessor.save_checkpoint()
def handle_request_trial_jobs(self, data):
# data: number of trial jobs requested
ids = [_create_parameter_id() for _ in range(data)]
params_list = self.tuner.generate_multiple_parameters(ids)
assert len(ids) == len(params_list)
for i, _ in enumerate(ids):
send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
return True
def handle_update_search_space(self, data):
self.tuner.update_search_space(data)
return True
def handle_add_customized_trial(self, data):
# data: parameters
id_ = _create_parameter_id()
_customized_parameter_ids.add(id_)
send(CommandType.NewTrialJob, _pack_parameter(id_, data, customized=True))
return True
def handle_report_metric_data(self, data):
trial_job_id = data['trial_job_id']
if data['type'] == 'FINAL':
id_ = data['parameter_id']
if id_ in _customized_parameter_ids:
self.tuner.receive_customized_trial_result(id_, _trial_params[id_], data['value'], trial_job_id)
else:
self.tuner.receive_trial_result(id_, _trial_params[id_], data['value'], trial_job_id)
elif data['type'] == 'PERIODICAL':
if self.assessor is not None:
self._handle_intermediate_metric_data(data)
else:
pass
elif data['type'] == 'REQUEST_PARAMETER':
assert data['trial_job_id'] is not None
assert data['parameter_index'] is not None
param_id = _create_parameter_id()
param = self.tuner.generate_parameters(param_id, trial_job_id)
send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'], parameter_index=data['parameter_index']))
else:
raise ValueError('Data type not supported: {}'.format(data['type']))
return True
def handle_trial_end(self, data):
trial_job_id = data['trial_job_id']
_ended_trials.add(trial_job_id)
if trial_job_id in _trial_history:
_trial_history.pop(trial_job_id)
if self.assessor is not None:
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
return True
def _handle_intermediate_metric_data(self, data):
if data['type'] != 'PERIODICAL':
return True
if self.assessor is None:
return True
trial_job_id = data['trial_job_id']
if trial_job_id in _ended_trials:
return True
history = _trial_history[trial_job_id]
history[data['sequence']] = data['value']
ordered_history = _sort_history(history)
if len(ordered_history) < data['sequence']: # no user-visible update since last time
return True
try:
result = self.assessor.assess_trial(trial_job_id, ordered_history)
except Exception:
_logger.exception('Assessor error')
return True
if isinstance(result, bool):
result = AssessResult.Good if result else AssessResult.Bad
elif not isinstance(result, AssessResult):
msg = 'Result of Assessor.assess_trial must be an object of AssessResult, not %s'
raise RuntimeError(msg % type(result))
if result is AssessResult.Bad:
_logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
else:
_logger.debug('GOOD')
# Copyright (c) Microsoft Corporation. All rights reserved.
# MIT License (full license text as in the header above)
# ==================================================================================================
import logging
from nni.recoverable import Recoverable
_logger = logging.getLogger(__name__)
class MultiPhaseTuner(Recoverable):
# pylint: disable=no-self-use,unused-argument
def generate_parameters(self, parameter_id, trial_job_id=None):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: int
"""
raise NotImplementedError('Tuner: generate_parameters not implemented')
def generate_multiple_parameters(self, parameter_id_list):
"""Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
Calls 'generate_parameters()' once per parameter id by default.
User code must override either this function or 'generate_parameters()'.
parameter_id_list: list of int
"""
return [self.generate_parameters(parameter_id) for parameter_id in parameter_id_list]
def receive_trial_result(self, parameter_id, parameters, reward, trial_job_id):
"""Invoked when a trial reports its final result. Must override.
parameter_id: int
parameters: object created by 'generate_parameters()'
reward: object reported by trial
"""
raise NotImplementedError('Tuner: receive_trial_result not implemented')
def receive_customized_trial_result(self, parameter_id, parameters, reward, trial_job_id):
"""Invoked when a trial added by WebUI reports its final result. Do nothing by default.
parameter_id: int
parameters: object created by user
reward: object reported by trial
"""
_logger.info('Customized trial job %s ignored by tuner', parameter_id)
def update_search_space(self, search_space):
"""Update the search space of tuner. Must override.
search_space: JSON object
"""
raise NotImplementedError('Tuner: update_search_space not implemented')
def load_checkpoint(self):
"""Load the checkpoint of tuner. Does nothing by default.
"""
checkpoint_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by tuner, checkpoint path: %s', checkpoint_path)
def save_checkpoint(self):
"""Save the checkpoint of tuner. Does nothing by default.
"""
checkpoint_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by tuner, checkpoint path: %s', checkpoint_path)
def _on_exit(self):
pass
def _on_error(self):
pass
...@@ -18,11 +18,12 @@
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import os
import json
import time
import json_tricks
from ..common import init_logger, env_args
_sysdir = os.environ['NNI_SYS_DIR']
if not os.path.exists(os.path.join(_sysdir, '.nni')):
...@@ -35,10 +36,28 @@ if not os.path.exists(_outputdir):
_log_file_path = os.path.join(_outputdir, 'trial.log')
init_logger(_log_file_path)
_param_index = 0
def request_next_parameter():
metric = json_tricks.dumps({
'trial_job_id': env_args.trial_job_id,
'type': 'REQUEST_PARAMETER',
'sequence': 0,
'parameter_index': _param_index
})
send_metric(metric)
def get_parameters():
global _param_index
params_filepath = os.path.join(_sysdir, 'parameter_{}.cfg'.format(_param_index))
if not os.path.isfile(params_filepath):
request_next_parameter()
while not os.path.isfile(params_filepath):
time.sleep(3)
with open(params_filepath, 'r') as params_file:
params = json.load(params_file)
_param_index += 1
return params
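The reworked trial-side `get_parameters()` looks for `parameter_<index>.cfg`, sends a `REQUEST_PARAMETER` metric if the file is missing, then polls until the training service copies it in. A self-contained simulation of that loop (the dispatcher side is stubbed by a callback that writes the file; the helper names here are illustrative):

```python
import json
import os
import tempfile
import time

def get_parameters(sysdir, param_index, request_next_parameter, poll_interval=0.01):
    # Ask once if the file is absent, then poll until it appears.
    path = os.path.join(sysdir, 'parameter_{}.cfg'.format(param_index))
    if not os.path.isfile(path):
        request_next_parameter()
    while not os.path.isfile(path):
        time.sleep(poll_interval)
    with open(path) as f:
        return json.load(f)

sysdir = tempfile.mkdtemp()

def fake_request():
    # Stand-in for request_next_parameter(): the "training service"
    # drops the next phase's file into the system directory.
    with open(os.path.join(sysdir, 'parameter_1.cfg'), 'w') as f:
        json.dump({'parameter_id': 1, 'parameters': {'lr': 0.001}}, f)

params = get_parameters(sysdir, 1, fake_request)
print(params['parameters']['lr'])  # 0.001
```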
def send_metric(string):
data = (string + '\n').encode('utf8')
...
...@@ -34,6 +34,7 @@ class CommandType(Enum):
# out
NewTrialJob = b'TR'
SendTrialJobParameter = b'SP'
NoMoreTrialJobs = b'NO'
KillTrialJob = b'KI'
...@@ -55,7 +56,7 @@ def send(command, data):
data = data.encode('utf8')
assert len(data) < 1000000, 'Command too long'
msg = b'%b%06d%b' % (command.value, len(data), data)
logging.getLogger(__name__).debug('Sending command, data: [%s]' % msg)
_out_file.write(msg)
_out_file.flush()
...
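The protocol frames every command as a 2-byte type, a 6-digit zero-padded payload length, then the UTF-8 payload; the new `SendTrialJobParameter` command (`b'SP'`) reuses this framing. A standalone sketch of the encoder:

```python
def frame(command_value, data):
    # 2-byte command + 6-digit length + payload, as in protocol.send().
    payload = data.encode('utf8')
    assert len(payload) < 1000000, 'Command too long'
    return b'%b%06d%b' % (command_value, len(payload), payload)

msg = frame(b'SP', '{"parameter_id": 7}')
print(msg)  # b'SP000019{"parameter_id": 7}'
```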
...@@ -32,11 +32,13 @@ __all__ = [
]
_params = None
def get_parameters():
"""Returns a set of (hyper-)parameters generated by Tuner."""
global _params
_params = platform.get_parameters()
return _params['parameters']
...@@ -51,6 +53,7 @@ def report_intermediate_result(metric):
metric: serializable object.
"""
global _intermediate_seq
assert _params is not None, 'nni.get_parameters() needs to be called before report_intermediate_result'
metric = json_tricks.dumps({
'parameter_id': _params['parameter_id'],
'trial_job_id': env_args.trial_job_id,
...@@ -66,6 +69,7 @@ def report_final_result(metric):
"""Reports final result to tuner.
metric: serializable object.
"""
assert _params is not None, 'nni.get_parameters() needs to be called before report_final_result'
metric = json_tricks.dumps({
'parameter_id': _params['parameter_id'],
'trial_job_id': env_args.trial_job_id,
...
import logging
import random
from io import BytesIO
import nni
import nni.protocol
from nni.protocol import CommandType, send, receive
from nni.multi_phase.multi_phase_tuner import MultiPhaseTuner
from nni.multi_phase.multi_phase_dispatcher import MultiPhaseMsgDispatcher
from unittest import TestCase, main
class NaiveMultiPhaseTuner(MultiPhaseTuner):
'''
supports only choices
'''
def __init__(self):
self.search_space = None
def generate_parameters(self, parameter_id, trial_job_id=None):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: int
"""
generated_parameters = {}
if self.search_space is None:
raise AssertionError('Search space not specified')
for k in self.search_space:
param = self.search_space[k]
if not param['_type'] == 'choice':
raise ValueError('Only choice type is supported')
param_values = param['_value']
generated_parameters[k] = param_values[random.randint(0, len(param_values)-1)]
logging.getLogger(__name__).debug(generated_parameters)
return generated_parameters
def receive_trial_result(self, parameter_id, parameters, reward, trial_job_id):
logging.getLogger(__name__).debug('receive_trial_result: {},{},{},{}'.format(parameter_id, parameters, reward, trial_job_id))
def receive_customized_trial_result(self, parameter_id, parameters, reward, trial_job_id):
pass
def update_search_space(self, search_space):
self.search_space = search_space
_in_buf = BytesIO()
_out_buf = BytesIO()
def _reverse_io():
_in_buf.seek(0)
_out_buf.seek(0)
nni.protocol._out_file = _in_buf
nni.protocol._in_file = _out_buf
def _restore_io():
_in_buf.seek(0)
_out_buf.seek(0)
nni.protocol._in_file = _in_buf
nni.protocol._out_file = _out_buf
def _test_tuner():
_reverse_io() # now we are sending to Tuner's incoming stream
send(CommandType.UpdateSearchSpace, "{\"learning_rate\": {\"_value\": [0.0001, 0.001, 0.002, 0.005, 0.01], \"_type\": \"choice\"}, \"optimizer\": {\"_value\": [\"Adam\", \"SGD\"], \"_type\": \"choice\"}}")
send(CommandType.RequestTrialJobs, '2')
send(CommandType.ReportMetricData, '{"parameter_id":0,"type":"PERIODICAL","value":10,"trial_job_id":"abc"}')
send(CommandType.ReportMetricData, '{"parameter_id":1,"type":"FINAL","value":11,"trial_job_id":"abc"}')
send(CommandType.AddCustomizedTrialJob, '{"param":-1}')
send(CommandType.ReportMetricData, '{"parameter_id":2,"type":"FINAL","value":22,"trial_job_id":"abc"}')
send(CommandType.RequestTrialJobs, '1')
send(CommandType.TrialEnd, '{"trial_job_id":"abc"}')
_restore_io()
tuner = NaiveMultiPhaseTuner()
dispatcher = MultiPhaseMsgDispatcher(tuner)
dispatcher.run()
_reverse_io() # now we are receiving from Tuner's outgoing stream
command, data = receive() # this one is customized
print(command, data)
class MultiPhaseTestCase(TestCase):
def test_tuner(self):
_test_tuner()
if __name__ == '__main__':
main()
\ No newline at end of file