"server/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "20c5fd39c8b275c0c7d7e7be8ce03d48aa32c64e"
Unverified Commit f4ee9f8a authored by fishyds's avatar fishyds Committed by GitHub
Browse files

Fix OpenPAI training service failed issue after multiphase training code merged (#206)

* fix parameter file name issue for multi-phase training

* Updated based on comments
parent 06413f1a
......@@ -2,10 +2,10 @@
*The Assessor receives intermediate results from a Trial and decides whether the Trial should be killed. Once the Trial experiment meets the early stop conditions, the Assessor will kill the Trial.*
So, if user want to implement a customized Assessor, she/he only need to:
So, if users want to implement a customized Assessor, they only need to:
**1) Inherit a tuner of a base Tuner class**
**1) Inherit an assessor of a base Assessor class**
```python
from nni.assessor import Assessor
......@@ -31,7 +31,7 @@ class CustomizedAssessor(Assessor):
# your code implementation here.
...
```
**3) Write a script to run Tuner**
**3) Write a script to run Assessor**
```python
import argparse
......
......@@ -31,7 +31,7 @@ import * as util from 'util';
import { Database, DataStore } from './datastore';
import { ExperimentStartupInfo, getExperimentId, setExperimentStartupInfo } from './experimentStartupInfo';
import { Manager } from './manager';
import { TrainingService } from './trainingService';
import { HyperParameters, TrainingService } from './trainingService';
function getExperimentRootDir(): string {
return path.join(os.homedir(), 'nni', 'experiments', getExperimentId());
......@@ -194,6 +194,23 @@ function getMsgDispatcherCommand(tuner: any, assessor: any, multiPhase: boolean
return command;
}
/**
 * Generate the parameter file name for a HyperParameters object.
 *
 * Index 0 maps to `parameter.cfg`; any other index `n` maps to `parameter_<n>.cfg`.
 *
 * @param hyperParameters HyperParameters instance; its `index` must be >= 0
 * @returns the file name only (no directory component)
 * @throws the assertion fails when `hyperParameters` is null/undefined or `index` is negative
 */
function generateParamFileName(hyperParameters : HyperParameters): string {
    // `!= null` rejects both null and undefined, so a null argument fails the assertion
    // here instead of raising a TypeError on the property access below.
    assert(hyperParameters != null, 'hyperParameters must be provided');
    assert(hyperParameters.index >= 0, 'hyperParameters.index must be non-negative');

    // Strict comparison: only the numeric index 0 selects the un-suffixed file name.
    return hyperParameters.index === 0
        ? 'parameter.cfg'
        : `parameter_${hyperParameters.index}.cfg`;
}
/**
* Initialize a pseudo experiment environment for unit test.
* Must be paired with `cleanupUnitTest()`.
......@@ -242,5 +259,5 @@ function getIPV4Address(): string {
return ipv4Address;
}
export { getMsgDispatcherCommand, getLogDir, getExperimentRootDir, getDefaultDatabaseDir, getIPV4Address,
mkDirP, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomSelect };
export { generateParamFileName, getMsgDispatcherCommand, getLogDir, getExperimentRootDir,
getDefaultDatabaseDir, getIPV4Address, mkDirP, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomSelect };
......@@ -33,7 +33,7 @@ import {
HostJobApplicationForm, JobApplicationForm, HyperParameters, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../../common/trainingService';
import { delay, getExperimentRootDir, uniqueString } from '../../common/utils';
import { delay, generateParamFileName, getExperimentRootDir, uniqueString } from '../../common/utils';
import { file } from 'tmp';
const tkill = require('tree-kill');
......@@ -412,7 +412,7 @@ class LocalTrainingService implements TrainingService {
}
/**
 * Write the `value` of a HyperParameters object to a parameter file inside `directory`.
 *
 * @param directory folder that the parameter file path is joined onto
 * @param hyperParameters object supplying the file content (`value`) and the data used to name the file
 */
private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
// NOTE(review): the two `const filepath` declarations below look like the before/after
// lines of a diff hunk (index-suffixed name vs. generateParamFileName); compilable code
// can keep only one of them — confirm which.
const filepath: string = path.join(directory, `parameter_${hyperParameters.index}.cfg`);
const filepath: string = path.join(directory, generateParamFileName(hyperParameters));
// Content is written as UTF-8 text.
await fs.promises.writeFile(filepath, hyperParameters.value, { encoding: 'utf8' });
}
}
......
......@@ -20,6 +20,7 @@
'use strict'
import * as assert from 'assert';
import * as component from '../../common/component';
import * as cpp from 'child-process-promise';
import * as fs from 'fs';
......@@ -37,7 +38,7 @@ import {
JobApplicationForm, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric
} from '../../common/trainingService';
import { delay, getExperimentRootDir, getIPV4Address, uniqueString } from '../../common/utils';
import { delay, generateParamFileName, getExperimentRootDir, getIPV4Address, uniqueString } from '../../common/utils';
import { PAIJobRestServer } from './paiJobRestServer'
import { PAITrialJobDetail, PAI_INSTALL_NNI_SHELL_FORMAT, PAI_TRIAL_COMMAND_FORMAT, PAI_OUTPUT_DIR_FORMAT, PAI_LOG_PATH_FORMAT } from './paiData';
import { PAIJobInfoCollector } from './paiJobInfoCollector';
......@@ -156,11 +157,12 @@ class PAITrainingService implements TrainingService {
const runScriptContent : string = PAI_INSTALL_NNI_SHELL_FORMAT;
// Write NNI installation file to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });
// Write file content ( parameter.cfg ) to local tmp folders
const trialForm : TrialJobApplicationForm = (<TrialJobApplicationForm>form)
if(trialForm) {
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'parameter.cfg'), trialForm.hyperParameters, { encoding: 'utf8' });
await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(trialForm.hyperParameters)),
trialForm.hyperParameters.value, { encoding: 'utf8' });
}
// Step 1. Prepare PAI job configuration
......
......@@ -36,7 +36,7 @@ import { ObservableTimer } from '../../common/observableTimer';
import {
HostJobApplicationForm, HyperParameters, JobApplicationForm, TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService';
import { delay, getExperimentRootDir, uniqueString } from '../../common/utils';
import { delay, generateParamFileName, getExperimentRootDir, uniqueString } from '../../common/utils';
import { GPUSummary } from '../common/gpuData';
import { TrialConfig } from '../common/trialConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
......@@ -458,7 +458,7 @@ class RemoteMachineTrainingService implements TrainingService {
//create tmp trial working folder locally.
await cpp.exec(`mkdir -p ${trialLocalTempFolder}`);
// Write file content ( run.sh and parameter_0.cfg ) to local tmp files
// Write file content ( run.sh and parameter.cfg ) to local tmp files
await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run.sh'), runScriptContent, { encoding: 'utf8' });
// Copy local tmp files to remote machine
......@@ -586,7 +586,7 @@ class RemoteMachineTrainingService implements TrainingService {
const trialWorkingFolder: string = path.join(this.remoteExpRootDir, 'trials', trialJobId);
const trialLocalTempFolder: string = path.join(this.expRootDir, 'trials-local', trialJobId);
const fileName: string = `parameter_${hyperParameters.index}.cfg`;
const fileName: string = generateParamFileName(hyperParameters);
const localFilepath: string = path.join(trialLocalTempFolder, fileName);
await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' });
......
......@@ -49,7 +49,7 @@ def request_next_parameter():
def get_parameters():
global _param_index
params_filepath = os.path.join(_sysdir, 'parameter_{}.cfg'.format(_param_index))
params_filepath = os.path.join(_sysdir, ('parameter_{}.cfg'.format(_param_index), 'parameter.cfg')[_param_index == 0])
if not os.path.isfile(params_filepath):
request_next_parameter()
while not os.path.isfile(params_filepath):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment