Commit d6febf29 authored by suiguoxin

Merge branch 'master' of git://github.com/microsoft/nni

parents 77c95479 c2179921
......@@ -49,12 +49,12 @@ class RandomNASTuner(Tuner):
self.searchspace_json = search_space
self.random_state = np.random.RandomState()
-def generate_parameters(self, parameter_id):
+def generate_parameters(self, parameter_id, **kwargs):
'''generate
'''
return random_archi_generator(self.searchspace_json, self.random_state)
-def receive_trial_result(self, parameter_id, parameters, value):
+def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
'''receive
'''
pass
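The signature change repeated throughout this commit (every tuner's `generate_parameters` and `receive_trial_result` gaining a trailing `**kwargs`) keeps custom tuners forward-compatible: the dispatcher can pass extra per-trial context (in multi-phase runs this would typically include a trial job identifier) without breaking tuners that ignore it. A minimal sketch of a custom tuner written against the updated signatures; `MyRandomTuner` and the `'lr'` search-space key are illustrative, not part of this diff:

```python
import random
from nni.tuner import Tuner

class MyRandomTuner(Tuner):
    """Toy tuner showing the updated **kwargs-tolerant signatures."""

    def __init__(self):
        self.search_space = {}

    def update_search_space(self, search_space):
        self.search_space = search_space

    def generate_parameters(self, parameter_id, **kwargs):
        # kwargs may carry extra context (e.g. a trial job id in multi-phase mode);
        # a tuner that does not need it can simply ignore it.
        choices = self.search_space.get('lr', {}).get('_value', [0.001, 0.01, 0.1])
        return {'lr': random.choice(choices)}

    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        # value is the metric reported by the finished trial.
        print(f'trial {parameter_id}: params={parameters}, result={value}, extra={kwargs}')
```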
......@@ -112,7 +112,7 @@ class CustomerTuner(Tuner):
population.append(Individual(indiv_id=self.generate_new_id(), graph_cfg=graph_tmp, result=None))
return population
-def generate_parameters(self, parameter_id):
+def generate_parameters(self, parameter_id, **kwargs):
"""Returns a set of trial graph config, as a serializable object.
An example configuration:
```json
......@@ -196,7 +196,7 @@ class CustomerTuner(Tuner):
logger.debug("trial {} ready".format(indiv.indiv_id))
return param_json
-def receive_trial_result(self, parameter_id, parameters, value):
+def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
'''
Record an observation of the objective function
parameter_id : int
......
......@@ -375,7 +375,7 @@ function countFilesRecursively(directory: string, timeoutMilliSeconds?: number):
}
function validateFileName(fileName: string): boolean {
-let pattern: string = '^[a-z0-9A-Z\.-_]+$';
+let pattern: string = '^[a-z0-9A-Z\._-]+$';
const validateResult = fileName.match(pattern);
if(validateResult) {
return true;
......
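The one-character move in `validateFileName` above is a character-class fix: inside `[...]`, `\.-_` is parsed as a range from `.` (0x2E) to `_` (0x5F), so characters such as `/`, `?` and `:` were accepted while a literal `-` was rejected; with `-` moved to the last position it becomes a literal. The same class semantics can be checked in Python (illustrative snippet, not part of the commit):

```python
import re

# Old class: "\.-_" is the range 0x2E-0x5F, so "?", "/", ":" slip through
# and a literal "-" is (surprisingly) rejected.
buggy = re.compile(r'^[a-z0-9A-Z\.-_]+$')
# Fixed class: "-" placed last, where it is a literal character.
fixed = re.compile(r'^[a-z0-9A-Z\._-]+$')

for name in ('trial_keeper.log', 'bad?name', 'my-file.txt'):
    print(name, bool(buggy.match(name)), bool(fixed.match(name)))
# trial_keeper.log  True  True
# bad?name          True  False
# my-file.txt       False True
```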
......@@ -51,6 +51,7 @@ export namespace ValidationSchemas {
command: joi.string().min(1),
virtualCluster: joi.string(),
shmMB: joi.number(),
nasMode: joi.string().valid('classic_mode', 'enas_mode', 'oneshot_mode'),
worker: joi.object({
replicas: joi.number().min(1).required(),
image: joi.string().min(1),
......
......@@ -58,6 +58,10 @@ export abstract class ClusterJobRestServer extends RestServer {
this.port = basePort + 1;
}
get apiRootUrl(): string {
return this.API_ROOT_URL;
}
public get clusterRestServerPort(): number {
if (this.port === undefined) {
throw new Error('PAI Rest server port is undefined');
......@@ -87,7 +91,7 @@ export abstract class ClusterJobRestServer extends RestServer {
protected abstract handleTrialMetrics(jobId : string, trialMetrics : any[]) : void;
// tslint:disable: no-unsafe-any no-any
-private createRestHandler() : Router {
+protected createRestHandler() : Router {
const router: Router = Router();
router.use((req: Request, res: Response, next: any) => {
......
......@@ -355,7 +355,8 @@ class LocalTrainingService implements TrainingService {
this.log.info('Stopping local machine training service...');
this.stopping = true;
for (const stream of this.jobStreamMap.values()) {
-stream.destroy();
+stream.end(0)
+stream.emit('end')
}
if (this.gpuScheduler !== undefined) {
await this.gpuScheduler.stop();
......@@ -372,7 +373,9 @@ class LocalTrainingService implements TrainingService {
if (stream === undefined) {
throw new Error(`Could not find stream in trial ${trialJob.id}`);
}
-stream.destroy();
+//Refer https://github.com/Juul/tail-stream/issues/20
+stream.end(0)
+stream.emit('end')
this.jobStreamMap.delete(trialJob.id);
}
}
......@@ -567,7 +570,6 @@ class LocalTrainingService implements TrainingService {
buffer = remain;
}
});
this.jobStreamMap.set(trialJobDetail.id, stream);
}
......
......@@ -64,11 +64,11 @@ else
fi`;
export const PAI_TRIAL_COMMAND_FORMAT: string =
-`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} \
+`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} \
&& cd $NNI_SYS_DIR && sh install_nni.sh \
-&& python3 -m nni_trial_tool.trial_keeper --trial_command '{5}' --nnimanager_ip '{6}' --nnimanager_port '{7}' \
---pai_hdfs_output_dir '{8}' --pai_hdfs_host '{9}' --pai_user_name {10} --nni_hdfs_exp_dir '{11}' --webhdfs_path '/webhdfs/api/v1' \
---nni_manager_version '{12}' --log_collection '{13}'`;
+&& python3 -m nni_trial_tool.trial_keeper --trial_command '{6}' --nnimanager_ip '{7}' --nnimanager_port '{8}' \
+--pai_hdfs_output_dir '{9}' --pai_hdfs_host '{10}' --pai_user_name {11} --nni_hdfs_exp_dir '{12}' --webhdfs_path '/webhdfs/api/v1' \
+--nni_manager_version '{13}' --log_collection '{14}'`;
export const PAI_OUTPUT_DIR_FORMAT: string =
`hdfs://{0}:9000/`;
......
......@@ -19,17 +19,26 @@
'use strict';
import { Request, Response, Router } from 'express';
import { Inject } from 'typescript-ioc';
import * as component from '../../common/component';
import { ClusterJobRestServer } from '../common/clusterJobRestServer';
import { PAITrainingService } from './paiTrainingService';
export interface ParameterFileMeta {
readonly experimentId: string;
readonly trialId: string;
readonly filePath: string;
}
/**
* PAI Training service Rest server, provides rest API to support pai job metrics update
*
*/
@component.Singleton
export class PAIJobRestServer extends ClusterJobRestServer {
private parameterFileMetaList: ParameterFileMeta[] = [];
@Inject
private readonly paiTrainingService : PAITrainingService;
......@@ -52,4 +61,33 @@ export class PAIJobRestServer extends ClusterJobRestServer {
});
}
}
protected createRestHandler(): Router {
const router: Router = super.createRestHandler();
router.post(`/parameter-file-meta`, (req: Request, res: Response) => {
try {
this.log.info(`POST /parameter-file-meta, body is ${JSON.stringify(req.body)}`);
this.parameterFileMetaList.push(req.body);
res.send();
} catch (err) {
this.log.error(`POST parameter-file-meta error: ${err}`);
res.status(500);
res.send(err.message);
}
});
router.get(`/parameter-file-meta`, (req: Request, res: Response) => {
try {
this.log.info(`GET /parameter-file-meta`);
res.send(this.parameterFileMetaList);
} catch (err) {
this.log.error(`GET parameter-file-meta error: ${err}`);
res.status(500);
res.send(err.message);
}
});
return router;
}
}
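The two routes added above give the PAI job REST server a small registry of hyperparameter files: `POST /parameter-file-meta` stores a `ParameterFileMeta` record (`experimentId`, `trialId`, `filePath`) and `GET /parameter-file-meta` returns everything stored so far, presumably so the trial side can locate newly written parameter files during multi-phase runs. A hedged client-side sketch using Python's `requests`; the host, port and the `/api/v1/nni-pai` root path are assumptions, in the real code the base URL is `restServer.endPoint` plus the new `apiRootUrl` getter:

```python
import requests

# Assumed base URL; the server builds it from endPoint + apiRootUrl.
BASE = 'http://127.0.0.1:51189/api/v1/nni-pai'

meta = {
    'experimentId': 'EXP_ID_PLACEHOLDER',
    'trialId': 'TRIAL_ID_PLACEHOLDER',
    'filePath': '/user/<pai_user>/nni/<exp>/trials/<trial>/parameter_1.cfg',  # illustrative HDFS path
}

# Register where the latest parameter file was written.
requests.post(f'{BASE}/parameter-file-meta', json=meta).raise_for_status()

# List every parameter file registered so far.
print(requests.get(f'{BASE}/parameter-file-meta').json())
```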
......@@ -33,7 +33,7 @@ import { MethodNotImplementedError } from '../../common/errors';
import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import {
JobApplicationForm, NNIManagerIpConfig, TrainingService,
HyperParameters, JobApplicationForm, NNIManagerIpConfig, TrainingService,
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService';
import { delay, generateParamFileName,
......@@ -45,7 +45,7 @@ import { HDFSClientUtility } from './hdfsClientUtility';
import { NNIPAITrialConfig, PAIClusterConfig, PAIJobConfig, PAITaskRole } from './paiConfig';
import { PAI_LOG_PATH_FORMAT, PAI_OUTPUT_DIR_FORMAT, PAI_TRIAL_COMMAND_FORMAT, PAITrialJobDetail } from './paiData';
import { PAIJobInfoCollector } from './paiJobInfoCollector';
import { PAIJobRestServer } from './paiJobRestServer';
import { PAIJobRestServer, ParameterFileMeta } from './paiJobRestServer';
import * as WebHDFS from 'webhdfs';
......@@ -79,6 +79,7 @@ class PAITrainingService implements TrainingService {
private copyExpCodeDirPromise?: Promise<void>;
private versionCheck: boolean = true;
private logCollection: string;
private isMultiPhase: boolean = false;
constructor() {
this.log = getLogger();
......@@ -179,12 +180,22 @@ class PAITrainingService implements TrainingService {
return deferred.promise;
}
-public updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
-throw new MethodNotImplementedError();
+public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
if (form.jobType === 'TRIAL') {
await this.writeParameterFile(trialJobId, (<TrialJobApplicationForm>form).hyperParameters);
} else {
throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
}
return trialJobDetail;
}
public get isMultiPhaseJobSupported(): boolean {
-return false;
+return true;
}
// tslint:disable:no-http-string
......@@ -336,6 +347,9 @@ class PAITrainingService implements TrainingService {
case TrialConfigMetadataKey.LOG_COLLECTION:
this.logCollection = value;
break;
case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True');
break;
default:
//Reject for unknown keys
throw new Error(`Uknown key: ${key}`);
......@@ -445,6 +459,7 @@ class PAITrainingService implements TrainingService {
trialJobId,
this.experimentId,
trialJobDetail.sequenceId,
this.isMultiPhase,
this.paiTrialConfig.command,
nniManagerIp,
this.paiRestServerPort,
......@@ -632,7 +647,50 @@ class PAITrainingService implements TrainingService {
return Promise.race([timeoutDelay, deferred.promise])
.finally(() => { clearTimeout(timeoutId); });
}
// tslint:enable:no-any no-unsafe-any no-http-string
private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> {
if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized');
}
if (this.paiTrialConfig === undefined) {
throw new Error('PAI trial config is not initialized');
}
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
const hpFileName: string = generateParamFileName(hyperParameters);
const localFilepath: string = path.join(trialLocalTempFolder, hpFileName);
await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' });
const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
const hdfsHpFilePath: string = path.join(hdfsCodeDir, hpFileName);
await HDFSClientUtility.copyFileToHdfs(localFilepath, hdfsHpFilePath, this.hdfsClient);
await this.postParameterFileMeta({
experimentId: this.experimentId,
trialId: trialJobId,
filePath: hdfsHpFilePath
});
}
private postParameterFileMeta(parameterFileMeta: ParameterFileMeta): Promise<void> {
const deferred : Deferred<void> = new Deferred<void>();
const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
const req: request.Options = {
uri: `${restServer.endPoint}${restServer.apiRootUrl}/parameter-file-meta`,
method: 'POST',
json: true,
body: parameterFileMeta
};
request(req, (err: Error, res: request.Response) => {
if (err) {
deferred.reject(err);
} else {
deferred.resolve();
}
});
return deferred.promise;
}
}
export { PAITrainingService };
declare module 'tail-stream' {
export interface Stream {
on(type: 'data', callback: (data: Buffer) => void): void;
destroy(): void;
end(data: number): void;
emit(data: string): void;
}
export function createReadStream(path: string): Stream;
}
\ No newline at end of file
......@@ -23,6 +23,7 @@
from .trial import *
from .smartparam import *
from .nas_utils import reload_tensorflow_variables
class NoMoreTrialError(Exception):
def __init__(self,ErrorInfo):
......
......@@ -28,9 +28,8 @@ import json
import importlib
from .constants import ModuleName, ClassName, ClassArgs, AdvisorModuleName, AdvisorClassName
-from nni.common import enable_multi_thread
+from nni.common import enable_multi_thread, enable_multi_phase
from nni.msg_dispatcher import MsgDispatcher
-from nni.multi_phase.multi_phase_dispatcher import MultiPhaseMsgDispatcher
logger = logging.getLogger('nni.main')
logger.debug('START')
......@@ -126,6 +125,8 @@ def main():
args = parse_args()
if args.multi_thread:
enable_multi_thread()
if args.multi_phase:
enable_multi_phase()
if args.advisor_class_name:
# advisor is enabled and starts to run
......@@ -180,9 +181,6 @@ def main():
if assessor is None:
raise AssertionError('Failed to create Assessor instance')
-if args.multi_phase:
-dispatcher = MultiPhaseMsgDispatcher(tuner, assessor)
-else:
dispatcher = MsgDispatcher(tuner, assessor)
try:
......
......@@ -78,7 +78,7 @@ class BatchTuner(Tuner):
"""
self.values = self.is_valid(search_space)
-def generate_parameters(self, parameter_id):
+def generate_parameters(self, parameter_id, **kwargs):
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
Parameters
......@@ -90,7 +90,7 @@ class BatchTuner(Tuner):
raise nni.NoMoreTrialError('no more parameters now.')
return self.values[self.count]
-def receive_trial_result(self, parameter_id, parameters, value):
+def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
pass
def import_data(self, data):
......
......@@ -106,7 +106,7 @@ class Bracket():
self.s_max = s_max
self.eta = eta
self.max_budget = max_budget
-self.optimize_mode = optimize_mode
+self.optimize_mode = OptimizeMode(optimize_mode)
self.n = math.ceil((s_max + 1) * eta**s / (s + 1) - _epsilon)
self.r = max_budget / eta**s
......
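Wrapping the incoming string in `OptimizeMode(...)` here (and in the second `Bracket` further down) matters because a plain `Enum` member never compares equal to its raw string value, so later checks such as `self.optimize_mode == OptimizeMode.Maximize` would silently stay false if the bare `'maximize'` string were stored; the constructor also rejects unknown values up front. A small stand-alone illustration (the two-member enum is a stand-in for NNI's `OptimizeMode`):

```python
from enum import Enum

class OptimizeMode(Enum):          # stand-in for the real OptimizeMode
    Minimize = 'minimize'
    Maximize = 'maximize'

raw = 'maximize'                   # what the advisor receives from the experiment config

print(raw == OptimizeMode.Maximize)                 # False: a str never equals an Enum member
print(OptimizeMode(raw) == OptimizeMode.Maximize)   # True: convert once, compare safely
OptimizeMode('maximize_typo')                       # raises ValueError -> bad configs fail fast
```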
......@@ -69,6 +69,7 @@ def init_logger(logger_file_path, log_level_name='info'):
sys.stdout = _LoggerFileWrapper(logger_file)
_multi_thread = False
_multi_phase = False
def enable_multi_thread():
global _multi_thread
......@@ -76,3 +77,10 @@ def enable_multi_thread():
def multi_thread_enabled():
return _multi_thread
def enable_multi_phase():
global _multi_phase
_multi_phase = True
def multi_phase_enabled():
return _multi_phase
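Taken together with the dispatcher change above (the `MultiPhaseMsgDispatcher` branch disappears and `MsgDispatcher` is always used), these two helpers turn multi-phase into a module-level flag that is set once at start-up and queried wherever behaviour differs. A compact sketch of that flow; the helpers are re-declared here only so the snippet is self-contained, and the final `multi_phase_enabled()` check is an assumed usage pattern rather than code from this diff:

```python
# Re-declaration of the helpers added to nni/common.py, for a self-contained example.
_multi_phase = False

def enable_multi_phase():
    global _multi_phase
    _multi_phase = True

def multi_phase_enabled():
    return _multi_phase

# Start-up path (mirrors the __main__.py hunk above): one dispatcher class,
# multi-phase recorded as a flag instead of a separate dispatcher subclass.
class Args:                 # stand-in for the parsed command-line arguments
    multi_phase = True

args = Args()
if args.multi_phase:
    enable_multi_phase()

if multi_phase_enabled():   # assumed downstream check
    print('dispatcher will forward per-trial context to the tuner')
```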
......@@ -188,7 +188,7 @@ class EvolutionTuner(Tuner):
self.searchspace_json, is_rand, self.random_state)
self.population.append(Individual(config=config))
-def generate_parameters(self, parameter_id):
+def generate_parameters(self, parameter_id, **kwargs):
"""Returns a dict of trial (hyper-)parameters, as a serializable object.
Parameters
......@@ -232,7 +232,7 @@ class EvolutionTuner(Tuner):
config = split_index(total_config)
return config
-def receive_trial_result(self, parameter_id, parameters, value):
+def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
'''Record the result from a trial
Parameters
......
......@@ -137,7 +137,7 @@ class GridSearchTuner(Tuner):
'''
self.expanded_search_space = self.json2parameter(search_space)
-def generate_parameters(self, parameter_id):
+def generate_parameters(self, parameter_id, **kwargs):
self.count += 1
while (self.count <= len(self.expanded_search_space)-1):
_params_tuple = convert_dict2tuple(self.expanded_search_space[self.count])
......@@ -147,7 +147,7 @@ class GridSearchTuner(Tuner):
return self.expanded_search_space[self.count]
raise nni.NoMoreTrialError('no more parameters now.')
-def receive_trial_result(self, parameter_id, parameters, value):
+def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
pass
def import_data(self, data):
......
......@@ -144,7 +144,7 @@ class Bracket():
self.configs_perf = [] # [ {id: [seq, acc]}, {}, ... ]
self.num_configs_to_run = [] # [ n, n, n, ... ]
self.num_finished_configs = [] # [ n, n, n, ... ]
-self.optimize_mode = optimize_mode
+self.optimize_mode = OptimizeMode(optimize_mode)
self.no_more_trial = False
def is_completed(self):
......
......@@ -248,7 +248,7 @@ class HyperoptTuner(Tuner):
verbose=0)
self.rval.catch_eval_exceptions = False
-def generate_parameters(self, parameter_id):
+def generate_parameters(self, parameter_id, **kwargs):
"""
Returns a set of trial (hyper-)parameters, as a serializable object.
......@@ -269,7 +269,7 @@ class HyperoptTuner(Tuner):
params = split_index(total_params)
return params
-def receive_trial_result(self, parameter_id, parameters, value):
+def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""
Record an observation of the objective function
......
......@@ -49,15 +49,16 @@ def selection_r(x_bounds,
num_starting_points=100,
minimize_constraints_fun=None):
'''
-Call selection
+Select using different types.
'''
-minimize_starting_points = [lib_data.rand(x_bounds, x_types)\
-for i in range(0, num_starting_points)]
+minimize_starting_points = clusteringmodel_gmm_good.sample(n_samples=num_starting_points)
outputs = selection(x_bounds, x_types,
clusteringmodel_gmm_good,
clusteringmodel_gmm_bad,
-minimize_starting_points,
+minimize_starting_points[0],
minimize_constraints_fun)
return outputs
def selection(x_bounds,
......