Unverified Commit f4ee9f8a authored by fishyds's avatar fishyds Committed by GitHub
Browse files

Fix OpenPAI training service failed issue after multiphase training code merged (#206)

* fix parameter file name issue for multi-phase training

* Updated based on comments
parent 06413f1a
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
*The Assessor receives intermediate results from a Trial and decides whether the Trial should be killed. Once the Trial experiment meets the early stop conditions, the Assessor will kill the Trial.* *The Assessor receives intermediate results from a Trial and decides whether the Trial should be killed. Once the Trial experiment meets the early stop conditions, the Assessor will kill the Trial.*
So, if user want to implement a customized Assessor, she/he only need to: So, if users want to implement a customized Assessor, they only need to:
**1) Inherit a tuner of a base Tuner class** **1) Inherit an assessor of a base Assessor class**
```python ```python
from nni.assessor import Assessor from nni.assessor import Assessor
...@@ -31,7 +31,7 @@ class CustomizedAssessor(Assessor): ...@@ -31,7 +31,7 @@ class CustomizedAssessor(Assessor):
# your code implementation here. # your code implementation here.
... ...
``` ```
**3) Write a script to run Tuner** **3) Write a script to run Assessor**
```python ```python
import argparse import argparse
......
...@@ -31,7 +31,7 @@ import * as util from 'util'; ...@@ -31,7 +31,7 @@ import * as util from 'util';
import { Database, DataStore } from './datastore'; import { Database, DataStore } from './datastore';
import { ExperimentStartupInfo, getExperimentId, setExperimentStartupInfo } from './experimentStartupInfo'; import { ExperimentStartupInfo, getExperimentId, setExperimentStartupInfo } from './experimentStartupInfo';
import { Manager } from './manager'; import { Manager } from './manager';
import { TrainingService } from './trainingService'; import { HyperParameters, TrainingService } from './trainingService';
function getExperimentRootDir(): string { function getExperimentRootDir(): string {
return path.join(os.homedir(), 'nni', 'experiments', getExperimentId()); return path.join(os.homedir(), 'nni', 'experiments', getExperimentId());
...@@ -194,6 +194,23 @@ function getMsgDispatcherCommand(tuner: any, assessor: any, multiPhase: boolean ...@@ -194,6 +194,23 @@ function getMsgDispatcherCommand(tuner: any, assessor: any, multiPhase: boolean
return command; return command;
} }
/**
 * Generate the parameter file name for a HyperParameters object.
 *
 * The parameter set with index 0 is stored as `parameter.cfg`; every other
 * index is stored as `parameter_<index>.cfg`.
 *
 * @param hyperParameters HyperParameters instance; its `index` must be >= 0
 * @returns the bare file name (no directory component)
 * @throws AssertionError if `hyperParameters` is undefined or its index is negative
 */
function generateParamFileName(hyperParameters: HyperParameters): string {
    assert(hyperParameters !== undefined, 'hyperParameters must be defined');
    // NaN also fails this comparison, so a non-numeric index is rejected here.
    assert(hyperParameters.index >= 0, 'hyperParameters.index must be a non-negative number');

    // Strict equality: index is a number, so no type coercion is wanted.
    return hyperParameters.index === 0
        ? 'parameter.cfg'
        : `parameter_${hyperParameters.index}.cfg`;
}
/** /**
* Initialize a pseudo experiment environment for unit test. * Initialize a pseudo experiment environment for unit test.
* Must be paired with `cleanupUnitTest()`. * Must be paired with `cleanupUnitTest()`.
...@@ -242,5 +259,5 @@ function getIPV4Address(): string { ...@@ -242,5 +259,5 @@ function getIPV4Address(): string {
return ipv4Address; return ipv4Address;
} }
export { getMsgDispatcherCommand, getLogDir, getExperimentRootDir, getDefaultDatabaseDir, getIPV4Address, export { generateParamFileName, getMsgDispatcherCommand, getLogDir, getExperimentRootDir,
mkDirP, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomSelect }; getDefaultDatabaseDir, getIPV4Address, mkDirP, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomSelect };
...@@ -33,7 +33,7 @@ import { ...@@ -33,7 +33,7 @@ import {
HostJobApplicationForm, JobApplicationForm, HyperParameters, TrainingService, TrialJobApplicationForm, HostJobApplicationForm, JobApplicationForm, HyperParameters, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, TrialJobStatus TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../../common/trainingService'; } from '../../common/trainingService';
import { delay, getExperimentRootDir, uniqueString } from '../../common/utils'; import { delay, generateParamFileName, getExperimentRootDir, uniqueString } from '../../common/utils';
import { file } from 'tmp'; import { file } from 'tmp';
const tkill = require('tree-kill'); const tkill = require('tree-kill');
...@@ -412,7 +412,7 @@ class LocalTrainingService implements TrainingService { ...@@ -412,7 +412,7 @@ class LocalTrainingService implements TrainingService {
} }
private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> { private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
const filepath: string = path.join(directory, `parameter_${hyperParameters.index}.cfg`); const filepath: string = path.join(directory, generateParamFileName(hyperParameters));
await fs.promises.writeFile(filepath, hyperParameters.value, { encoding: 'utf8' }); await fs.promises.writeFile(filepath, hyperParameters.value, { encoding: 'utf8' });
} }
} }
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
'use strict' 'use strict'
import * as assert from 'assert';
import * as component from '../../common/component'; import * as component from '../../common/component';
import * as cpp from 'child-process-promise'; import * as cpp from 'child-process-promise';
import * as fs from 'fs'; import * as fs from 'fs';
...@@ -37,7 +38,7 @@ import { ...@@ -37,7 +38,7 @@ import {
JobApplicationForm, TrainingService, TrialJobApplicationForm, JobApplicationForm, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric TrialJobDetail, TrialJobMetric
} from '../../common/trainingService'; } from '../../common/trainingService';
import { delay, getExperimentRootDir, getIPV4Address, uniqueString } from '../../common/utils'; import { delay, generateParamFileName, getExperimentRootDir, getIPV4Address, uniqueString } from '../../common/utils';
import { PAIJobRestServer } from './paiJobRestServer' import { PAIJobRestServer } from './paiJobRestServer'
import { PAITrialJobDetail, PAI_INSTALL_NNI_SHELL_FORMAT, PAI_TRIAL_COMMAND_FORMAT, PAI_OUTPUT_DIR_FORMAT, PAI_LOG_PATH_FORMAT } from './paiData'; import { PAITrialJobDetail, PAI_INSTALL_NNI_SHELL_FORMAT, PAI_TRIAL_COMMAND_FORMAT, PAI_OUTPUT_DIR_FORMAT, PAI_LOG_PATH_FORMAT } from './paiData';
import { PAIJobInfoCollector } from './paiJobInfoCollector'; import { PAIJobInfoCollector } from './paiJobInfoCollector';
...@@ -156,11 +157,12 @@ class PAITrainingService implements TrainingService { ...@@ -156,11 +157,12 @@ class PAITrainingService implements TrainingService {
const runScriptContent : string = PAI_INSTALL_NNI_SHELL_FORMAT; const runScriptContent : string = PAI_INSTALL_NNI_SHELL_FORMAT;
// Write NNI installation file to local tmp files // Write NNI installation file to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
// Write file content ( parameter.cfg ) to local tmp folders // Write file content ( parameter.cfg ) to local tmp folders
const trialForm : TrialJobApplicationForm = (<TrialJobApplicationForm>form) const trialForm : TrialJobApplicationForm = (<TrialJobApplicationForm>form)
if(trialForm) { if(trialForm) {
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'parameter.cfg'), trialForm.hyperParameters, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(trialForm.hyperParameters)),
trialForm.hyperParameters.value, { encoding: 'utf8' });
} }
// Step 1. Prepare PAI job configuration // Step 1. Prepare PAI job configuration
......
...@@ -36,7 +36,7 @@ import { ObservableTimer } from '../../common/observableTimer'; ...@@ -36,7 +36,7 @@ import { ObservableTimer } from '../../common/observableTimer';
import { import {
HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService'; } from '../../common/trainingService';
import { delay, getExperimentRootDir, uniqueString } from '../../common/utils'; import { delay, generateParamFileName, getExperimentRootDir, uniqueString } from '../../common/utils';
import { GPUSummary } from '../common/gpuData'; import { GPUSummary } from '../common/gpuData';
import { TrialConfig } from '../common/trialConfig'; import { TrialConfig } from '../common/trialConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
...@@ -458,7 +458,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -458,7 +458,7 @@ class RemoteMachineTrainingService implements TrainingService {
//create tmp trial working folder locally. //create tmp trial working folder locally.
await cpp.exec(`mkdir -p ${trialLocalTempFolder}`); await cpp.exec(`mkdir -p ${trialLocalTempFolder}`);
// Write file content ( run.sh and parameter_0.cfg ) to local tmp files // Write file content ( run.sh and parameter.cfg ) to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run.sh'), runScriptContent, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run.sh'), runScriptContent, { encoding: 'utf8' });
// Copy local tmp files to remote machine // Copy local tmp files to remote machine
...@@ -586,7 +586,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -586,7 +586,7 @@ class RemoteMachineTrainingService implements TrainingService {
const trialWorkingFolder: string = path.join(this.remoteExpRootDir, 'trials', trialJobId); const trialWorkingFolder: string = path.join(this.remoteExpRootDir, 'trials', trialJobId);
const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials-local', trialJobId); const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials-local', trialJobId);
const fileName: string = `parameter_${hyperParameters.index}.cfg`; const fileName: string = generateParamFileName(hyperParameters);
const localFilepath: string = path.join(trialLocalTempFolder, fileName); const localFilepath: string = path.join(trialLocalTempFolder, fileName);
await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' }); await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' });
......
...@@ -49,7 +49,7 @@ def request_next_parameter(): ...@@ -49,7 +49,7 @@ def request_next_parameter():
def get_parameters(): def get_parameters():
global _param_index global _param_index
params_filepath = os.path.join(_sysdir, 'parameter_{}.cfg'.format(_param_index)) params_filepath = os.path.join(_sysdir, ('parameter_{}.cfg'.format(_param_index), 'parameter.cfg')[_param_index == 0])
if not os.path.isfile(params_filepath): if not os.path.isfile(params_filepath):
request_next_parameter() request_next_parameter()
while not os.path.isfile(params_filepath): while not os.path.isfile(params_filepath):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment