Commit d6febf29 authored by suiguoxin's avatar suiguoxin
Browse files

Merge branch 'master' of git://github.com/microsoft/nni

parents 77c95479 c2179921
...@@ -49,12 +49,12 @@ class RandomNASTuner(Tuner): ...@@ -49,12 +49,12 @@ class RandomNASTuner(Tuner):
self.searchspace_json = search_space self.searchspace_json = search_space
self.random_state = np.random.RandomState() self.random_state = np.random.RandomState()
def generate_parameters(self, parameter_id): def generate_parameters(self, parameter_id, **kwargs):
'''generate '''generate
''' '''
return random_archi_generator(self.searchspace_json, self.random_state) return random_archi_generator(self.searchspace_json, self.random_state)
def receive_trial_result(self, parameter_id, parameters, value): def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
'''receive '''receive
''' '''
pass pass
...@@ -112,7 +112,7 @@ class CustomerTuner(Tuner): ...@@ -112,7 +112,7 @@ class CustomerTuner(Tuner):
population.append(Individual(indiv_id=self.generate_new_id(), graph_cfg=graph_tmp, result=None)) population.append(Individual(indiv_id=self.generate_new_id(), graph_cfg=graph_tmp, result=None))
return population return population
def generate_parameters(self, parameter_id): def generate_parameters(self, parameter_id, **kwargs):
"""Returns a set of trial graph config, as a serializable object. """Returns a set of trial graph config, as a serializable object.
An example configuration: An example configuration:
```json ```json
...@@ -196,7 +196,7 @@ class CustomerTuner(Tuner): ...@@ -196,7 +196,7 @@ class CustomerTuner(Tuner):
logger.debug("trial {} ready".format(indiv.indiv_id)) logger.debug("trial {} ready".format(indiv.indiv_id))
return param_json return param_json
def receive_trial_result(self, parameter_id, parameters, value): def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
''' '''
Record an observation of the objective function Record an observation of the objective function
parameter_id : int parameter_id : int
......
...@@ -375,7 +375,7 @@ function countFilesRecursively(directory: string, timeoutMilliSeconds?: number): ...@@ -375,7 +375,7 @@ function countFilesRecursively(directory: string, timeoutMilliSeconds?: number):
} }
function validateFileName(fileName: string): boolean { function validateFileName(fileName: string): boolean {
let pattern: string = '^[a-z0-9A-Z\.-_]+$'; let pattern: string = '^[a-z0-9A-Z\._-]+$';
const validateResult = fileName.match(pattern); const validateResult = fileName.match(pattern);
if(validateResult) { if(validateResult) {
return true; return true;
......
...@@ -51,6 +51,7 @@ export namespace ValidationSchemas { ...@@ -51,6 +51,7 @@ export namespace ValidationSchemas {
command: joi.string().min(1), command: joi.string().min(1),
virtualCluster: joi.string(), virtualCluster: joi.string(),
shmMB: joi.number(), shmMB: joi.number(),
nasMode: joi.string().valid('classic_mode', 'enas_mode', 'oneshot_mode'),
worker: joi.object({ worker: joi.object({
replicas: joi.number().min(1).required(), replicas: joi.number().min(1).required(),
image: joi.string().min(1), image: joi.string().min(1),
......
...@@ -58,6 +58,10 @@ export abstract class ClusterJobRestServer extends RestServer { ...@@ -58,6 +58,10 @@ export abstract class ClusterJobRestServer extends RestServer {
this.port = basePort + 1; this.port = basePort + 1;
} }
get apiRootUrl(): string {
return this.API_ROOT_URL;
}
public get clusterRestServerPort(): number { public get clusterRestServerPort(): number {
if (this.port === undefined) { if (this.port === undefined) {
throw new Error('PAI Rest server port is undefined'); throw new Error('PAI Rest server port is undefined');
...@@ -87,7 +91,7 @@ export abstract class ClusterJobRestServer extends RestServer { ...@@ -87,7 +91,7 @@ export abstract class ClusterJobRestServer extends RestServer {
protected abstract handleTrialMetrics(jobId : string, trialMetrics : any[]) : void; protected abstract handleTrialMetrics(jobId : string, trialMetrics : any[]) : void;
// tslint:disable: no-unsafe-any no-any // tslint:disable: no-unsafe-any no-any
private createRestHandler() : Router { protected createRestHandler() : Router {
const router: Router = Router(); const router: Router = Router();
router.use((req: Request, res: Response, next: any) => { router.use((req: Request, res: Response, next: any) => {
......
...@@ -355,7 +355,8 @@ class LocalTrainingService implements TrainingService { ...@@ -355,7 +355,8 @@ class LocalTrainingService implements TrainingService {
this.log.info('Stopping local machine training service...'); this.log.info('Stopping local machine training service...');
this.stopping = true; this.stopping = true;
for (const stream of this.jobStreamMap.values()) { for (const stream of this.jobStreamMap.values()) {
stream.destroy(); stream.end(0)
stream.emit('end')
} }
if (this.gpuScheduler !== undefined) { if (this.gpuScheduler !== undefined) {
await this.gpuScheduler.stop(); await this.gpuScheduler.stop();
...@@ -372,7 +373,9 @@ class LocalTrainingService implements TrainingService { ...@@ -372,7 +373,9 @@ class LocalTrainingService implements TrainingService {
if (stream === undefined) { if (stream === undefined) {
throw new Error(`Could not find stream in trial ${trialJob.id}`); throw new Error(`Could not find stream in trial ${trialJob.id}`);
} }
stream.destroy(); //Refer https://github.com/Juul/tail-stream/issues/20
stream.end(0)
stream.emit('end')
this.jobStreamMap.delete(trialJob.id); this.jobStreamMap.delete(trialJob.id);
} }
} }
...@@ -567,7 +570,6 @@ class LocalTrainingService implements TrainingService { ...@@ -567,7 +570,6 @@ class LocalTrainingService implements TrainingService {
buffer = remain; buffer = remain;
} }
}); });
this.jobStreamMap.set(trialJobDetail.id, stream); this.jobStreamMap.set(trialJobDetail.id, stream);
} }
......
...@@ -64,11 +64,11 @@ else ...@@ -64,11 +64,11 @@ else
fi`; fi`;
export const PAI_TRIAL_COMMAND_FORMAT: string = export const PAI_TRIAL_COMMAND_FORMAT: string =
`export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} \ `export NNI_PLATFORM=pai NNI_SYS_DIR={0} NNI_OUTPUT_DIR={1} NNI_TRIAL_JOB_ID={2} NNI_EXP_ID={3} NNI_TRIAL_SEQ_ID={4} MULTI_PHASE={5} \
&& cd $NNI_SYS_DIR && sh install_nni.sh \ && cd $NNI_SYS_DIR && sh install_nni.sh \
&& python3 -m nni_trial_tool.trial_keeper --trial_command '{5}' --nnimanager_ip '{6}' --nnimanager_port '{7}' \ && python3 -m nni_trial_tool.trial_keeper --trial_command '{6}' --nnimanager_ip '{7}' --nnimanager_port '{8}' \
--pai_hdfs_output_dir '{8}' --pai_hdfs_host '{9}' --pai_user_name {10} --nni_hdfs_exp_dir '{11}' --webhdfs_path '/webhdfs/api/v1' \ --pai_hdfs_output_dir '{9}' --pai_hdfs_host '{10}' --pai_user_name {11} --nni_hdfs_exp_dir '{12}' --webhdfs_path '/webhdfs/api/v1' \
--nni_manager_version '{12}' --log_collection '{13}'`; --nni_manager_version '{13}' --log_collection '{14}'`;
export const PAI_OUTPUT_DIR_FORMAT: string = export const PAI_OUTPUT_DIR_FORMAT: string =
`hdfs://{0}:9000/`; `hdfs://{0}:9000/`;
......
...@@ -19,17 +19,26 @@ ...@@ -19,17 +19,26 @@
'use strict'; 'use strict';
import { Request, Response, Router } from 'express';
import { Inject } from 'typescript-ioc'; import { Inject } from 'typescript-ioc';
import * as component from '../../common/component'; import * as component from '../../common/component';
import { ClusterJobRestServer } from '../common/clusterJobRestServer'; import { ClusterJobRestServer } from '../common/clusterJobRestServer';
import { PAITrainingService } from './paiTrainingService'; import { PAITrainingService } from './paiTrainingService';
export interface ParameterFileMeta {
readonly experimentId: string;
readonly trialId: string;
readonly filePath: string;
}
/** /**
* PAI Training service Rest server, provides rest API to support pai job metrics update * PAI Training service Rest server, provides rest API to support pai job metrics update
* *
*/ */
@component.Singleton @component.Singleton
export class PAIJobRestServer extends ClusterJobRestServer { export class PAIJobRestServer extends ClusterJobRestServer {
private parameterFileMetaList: ParameterFileMeta[] = [];
@Inject @Inject
private readonly paiTrainingService : PAITrainingService; private readonly paiTrainingService : PAITrainingService;
...@@ -52,4 +61,33 @@ export class PAIJobRestServer extends ClusterJobRestServer { ...@@ -52,4 +61,33 @@ export class PAIJobRestServer extends ClusterJobRestServer {
}); });
} }
} }
protected createRestHandler(): Router {
const router: Router = super.createRestHandler();
router.post(`/parameter-file-meta`, (req: Request, res: Response) => {
try {
this.log.info(`POST /parameter-file-meta, body is ${JSON.stringify(req.body)}`);
this.parameterFileMetaList.push(req.body);
res.send();
} catch (err) {
this.log.error(`POST parameter-file-meta error: ${err}`);
res.status(500);
res.send(err.message);
}
});
router.get(`/parameter-file-meta`, (req: Request, res: Response) => {
try {
this.log.info(`GET /parameter-file-meta`);
res.send(this.parameterFileMetaList);
} catch (err) {
this.log.error(`GET parameter-file-meta error: ${err}`);
res.status(500);
res.send(err.message);
}
});
return router;
}
} }
...@@ -33,7 +33,7 @@ import { MethodNotImplementedError } from '../../common/errors'; ...@@ -33,7 +33,7 @@ import { MethodNotImplementedError } from '../../common/errors';
import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo'; import { getExperimentId, getInitTrialSequenceId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { import {
JobApplicationForm, NNIManagerIpConfig, TrainingService, HyperParameters, JobApplicationForm, NNIManagerIpConfig, TrainingService,
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
} from '../../common/trainingService'; } from '../../common/trainingService';
import { delay, generateParamFileName, import { delay, generateParamFileName,
...@@ -45,7 +45,7 @@ import { HDFSClientUtility } from './hdfsClientUtility'; ...@@ -45,7 +45,7 @@ import { HDFSClientUtility } from './hdfsClientUtility';
import { NNIPAITrialConfig, PAIClusterConfig, PAIJobConfig, PAITaskRole } from './paiConfig'; import { NNIPAITrialConfig, PAIClusterConfig, PAIJobConfig, PAITaskRole } from './paiConfig';
import { PAI_LOG_PATH_FORMAT, PAI_OUTPUT_DIR_FORMAT, PAI_TRIAL_COMMAND_FORMAT, PAITrialJobDetail } from './paiData'; import { PAI_LOG_PATH_FORMAT, PAI_OUTPUT_DIR_FORMAT, PAI_TRIAL_COMMAND_FORMAT, PAITrialJobDetail } from './paiData';
import { PAIJobInfoCollector } from './paiJobInfoCollector'; import { PAIJobInfoCollector } from './paiJobInfoCollector';
import { PAIJobRestServer } from './paiJobRestServer'; import { PAIJobRestServer, ParameterFileMeta } from './paiJobRestServer';
import * as WebHDFS from 'webhdfs'; import * as WebHDFS from 'webhdfs';
...@@ -79,6 +79,7 @@ class PAITrainingService implements TrainingService { ...@@ -79,6 +79,7 @@ class PAITrainingService implements TrainingService {
private copyExpCodeDirPromise?: Promise<void>; private copyExpCodeDirPromise?: Promise<void>;
private versionCheck: boolean = true; private versionCheck: boolean = true;
private logCollection: string; private logCollection: string;
private isMultiPhase: boolean = false;
constructor() { constructor() {
this.log = getLogger(); this.log = getLogger();
...@@ -179,12 +180,22 @@ class PAITrainingService implements TrainingService { ...@@ -179,12 +180,22 @@ class PAITrainingService implements TrainingService {
return deferred.promise; return deferred.promise;
} }
public updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> { public async updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail> {
throw new MethodNotImplementedError(); const trialJobDetail: undefined | TrialJobDetail = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) {
throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
}
if (form.jobType === 'TRIAL') {
await this.writeParameterFile(trialJobId, (<TrialJobApplicationForm>form).hyperParameters);
} else {
throw new Error(`updateTrialJob failed: jobType ${form.jobType} not supported.`);
}
return trialJobDetail;
} }
public get isMultiPhaseJobSupported(): boolean { public get isMultiPhaseJobSupported(): boolean {
return false; return true;
} }
// tslint:disable:no-http-string // tslint:disable:no-http-string
...@@ -336,6 +347,9 @@ class PAITrainingService implements TrainingService { ...@@ -336,6 +347,9 @@ class PAITrainingService implements TrainingService {
case TrialConfigMetadataKey.LOG_COLLECTION: case TrialConfigMetadataKey.LOG_COLLECTION:
this.logCollection = value; this.logCollection = value;
break; break;
case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True');
break;
default: default:
//Reject for unknown keys //Reject for unknown keys
throw new Error(`Uknown key: ${key}`); throw new Error(`Uknown key: ${key}`);
...@@ -445,6 +459,7 @@ class PAITrainingService implements TrainingService { ...@@ -445,6 +459,7 @@ class PAITrainingService implements TrainingService {
trialJobId, trialJobId,
this.experimentId, this.experimentId,
trialJobDetail.sequenceId, trialJobDetail.sequenceId,
this.isMultiPhase,
this.paiTrialConfig.command, this.paiTrialConfig.command,
nniManagerIp, nniManagerIp,
this.paiRestServerPort, this.paiRestServerPort,
...@@ -632,7 +647,50 @@ class PAITrainingService implements TrainingService { ...@@ -632,7 +647,50 @@ class PAITrainingService implements TrainingService {
return Promise.race([timeoutDelay, deferred.promise]) return Promise.race([timeoutDelay, deferred.promise])
.finally(() => { clearTimeout(timeoutId); }); .finally(() => { clearTimeout(timeoutId); });
} }
// tslint:enable:no-any no-unsafe-any no-http-string
private async writeParameterFile(trialJobId: string, hyperParameters: HyperParameters): Promise<void> {
if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized');
}
if (this.paiTrialConfig === undefined) {
throw new Error('PAI trial config is not initialized');
}
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
const hpFileName: string = generateParamFileName(hyperParameters);
const localFilepath: string = path.join(trialLocalTempFolder, hpFileName);
await fs.promises.writeFile(localFilepath, hyperParameters.value, { encoding: 'utf8' });
const hdfsCodeDir: string = HDFSClientUtility.getHdfsTrialWorkDir(this.paiClusterConfig.userName, trialJobId);
const hdfsHpFilePath: string = path.join(hdfsCodeDir, hpFileName);
await HDFSClientUtility.copyFileToHdfs(localFilepath, hdfsHpFilePath, this.hdfsClient);
await this.postParameterFileMeta({
experimentId: this.experimentId,
trialId: trialJobId,
filePath: hdfsHpFilePath
});
}
private postParameterFileMeta(parameterFileMeta: ParameterFileMeta): Promise<void> {
const deferred : Deferred<void> = new Deferred<void>();
const restServer: PAIJobRestServer = component.get(PAIJobRestServer);
const req: request.Options = {
uri: `${restServer.endPoint}${restServer.apiRootUrl}/parameter-file-meta`,
method: 'POST',
json: true,
body: parameterFileMeta
};
request(req, (err: Error, res: request.Response) => {
if (err) {
deferred.reject(err);
} else {
deferred.resolve();
}
});
return deferred.promise;
}
} }
export { PAITrainingService }; export { PAITrainingService };
declare module 'tail-stream' { declare module 'tail-stream' {
export interface Stream { export interface Stream {
on(type: 'data', callback: (data: Buffer) => void): void; on(type: 'data', callback: (data: Buffer) => void): void;
destroy(): void; end(data: number): void;
emit(data: string): void;
} }
export function createReadStream(path: string): Stream; export function createReadStream(path: string): Stream;
} }
\ No newline at end of file
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
from .trial import * from .trial import *
from .smartparam import * from .smartparam import *
from .nas_utils import reload_tensorflow_variables
class NoMoreTrialError(Exception): class NoMoreTrialError(Exception):
def __init__(self,ErrorInfo): def __init__(self,ErrorInfo):
......
...@@ -28,9 +28,8 @@ import json ...@@ -28,9 +28,8 @@ import json
import importlib import importlib
from .constants import ModuleName, ClassName, ClassArgs, AdvisorModuleName, AdvisorClassName from .constants import ModuleName, ClassName, ClassArgs, AdvisorModuleName, AdvisorClassName
from nni.common import enable_multi_thread from nni.common import enable_multi_thread, enable_multi_phase
from nni.msg_dispatcher import MsgDispatcher from nni.msg_dispatcher import MsgDispatcher
from nni.multi_phase.multi_phase_dispatcher import MultiPhaseMsgDispatcher
logger = logging.getLogger('nni.main') logger = logging.getLogger('nni.main')
logger.debug('START') logger.debug('START')
...@@ -126,6 +125,8 @@ def main(): ...@@ -126,6 +125,8 @@ def main():
args = parse_args() args = parse_args()
if args.multi_thread: if args.multi_thread:
enable_multi_thread() enable_multi_thread()
if args.multi_phase:
enable_multi_phase()
if args.advisor_class_name: if args.advisor_class_name:
# advisor is enabled and starts to run # advisor is enabled and starts to run
...@@ -180,10 +181,7 @@ def main(): ...@@ -180,10 +181,7 @@ def main():
if assessor is None: if assessor is None:
raise AssertionError('Failed to create Assessor instance') raise AssertionError('Failed to create Assessor instance')
if args.multi_phase: dispatcher = MsgDispatcher(tuner, assessor)
dispatcher = MultiPhaseMsgDispatcher(tuner, assessor)
else:
dispatcher = MsgDispatcher(tuner, assessor)
try: try:
dispatcher.run() dispatcher.run()
......
...@@ -78,7 +78,7 @@ class BatchTuner(Tuner): ...@@ -78,7 +78,7 @@ class BatchTuner(Tuner):
""" """
self.values = self.is_valid(search_space) self.values = self.is_valid(search_space)
def generate_parameters(self, parameter_id): def generate_parameters(self, parameter_id, **kwargs):
"""Returns a dict of trial (hyper-)parameters, as a serializable object. """Returns a dict of trial (hyper-)parameters, as a serializable object.
Parameters Parameters
...@@ -90,7 +90,7 @@ class BatchTuner(Tuner): ...@@ -90,7 +90,7 @@ class BatchTuner(Tuner):
raise nni.NoMoreTrialError('no more parameters now.') raise nni.NoMoreTrialError('no more parameters now.')
return self.values[self.count] return self.values[self.count]
def receive_trial_result(self, parameter_id, parameters, value): def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
pass pass
def import_data(self, data): def import_data(self, data):
......
...@@ -106,7 +106,7 @@ class Bracket(): ...@@ -106,7 +106,7 @@ class Bracket():
self.s_max = s_max self.s_max = s_max
self.eta = eta self.eta = eta
self.max_budget = max_budget self.max_budget = max_budget
self.optimize_mode = optimize_mode self.optimize_mode = OptimizeMode(optimize_mode)
self.n = math.ceil((s_max + 1) * eta**s / (s + 1) - _epsilon) self.n = math.ceil((s_max + 1) * eta**s / (s + 1) - _epsilon)
self.r = max_budget / eta**s self.r = max_budget / eta**s
......
...@@ -69,6 +69,7 @@ def init_logger(logger_file_path, log_level_name='info'): ...@@ -69,6 +69,7 @@ def init_logger(logger_file_path, log_level_name='info'):
sys.stdout = _LoggerFileWrapper(logger_file) sys.stdout = _LoggerFileWrapper(logger_file)
_multi_thread = False _multi_thread = False
_multi_phase = False
def enable_multi_thread(): def enable_multi_thread():
global _multi_thread global _multi_thread
...@@ -76,3 +77,10 @@ def enable_multi_thread(): ...@@ -76,3 +77,10 @@ def enable_multi_thread():
def multi_thread_enabled(): def multi_thread_enabled():
return _multi_thread return _multi_thread
def enable_multi_phase():
global _multi_phase
_multi_phase = True
def multi_phase_enabled():
return _multi_phase
...@@ -188,7 +188,7 @@ class EvolutionTuner(Tuner): ...@@ -188,7 +188,7 @@ class EvolutionTuner(Tuner):
self.searchspace_json, is_rand, self.random_state) self.searchspace_json, is_rand, self.random_state)
self.population.append(Individual(config=config)) self.population.append(Individual(config=config))
def generate_parameters(self, parameter_id): def generate_parameters(self, parameter_id, **kwargs):
"""Returns a dict of trial (hyper-)parameters, as a serializable object. """Returns a dict of trial (hyper-)parameters, as a serializable object.
Parameters Parameters
...@@ -232,7 +232,7 @@ class EvolutionTuner(Tuner): ...@@ -232,7 +232,7 @@ class EvolutionTuner(Tuner):
config = split_index(total_config) config = split_index(total_config)
return config return config
def receive_trial_result(self, parameter_id, parameters, value): def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
'''Record the result from a trial '''Record the result from a trial
Parameters Parameters
......
...@@ -137,7 +137,7 @@ class GridSearchTuner(Tuner): ...@@ -137,7 +137,7 @@ class GridSearchTuner(Tuner):
''' '''
self.expanded_search_space = self.json2parameter(search_space) self.expanded_search_space = self.json2parameter(search_space)
def generate_parameters(self, parameter_id): def generate_parameters(self, parameter_id, **kwargs):
self.count += 1 self.count += 1
while (self.count <= len(self.expanded_search_space)-1): while (self.count <= len(self.expanded_search_space)-1):
_params_tuple = convert_dict2tuple(self.expanded_search_space[self.count]) _params_tuple = convert_dict2tuple(self.expanded_search_space[self.count])
...@@ -147,7 +147,7 @@ class GridSearchTuner(Tuner): ...@@ -147,7 +147,7 @@ class GridSearchTuner(Tuner):
return self.expanded_search_space[self.count] return self.expanded_search_space[self.count]
raise nni.NoMoreTrialError('no more parameters now.') raise nni.NoMoreTrialError('no more parameters now.')
def receive_trial_result(self, parameter_id, parameters, value): def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
pass pass
def import_data(self, data): def import_data(self, data):
......
...@@ -144,7 +144,7 @@ class Bracket(): ...@@ -144,7 +144,7 @@ class Bracket():
self.configs_perf = [] # [ {id: [seq, acc]}, {}, ... ] self.configs_perf = [] # [ {id: [seq, acc]}, {}, ... ]
self.num_configs_to_run = [] # [ n, n, n, ... ] self.num_configs_to_run = [] # [ n, n, n, ... ]
self.num_finished_configs = [] # [ n, n, n, ... ] self.num_finished_configs = [] # [ n, n, n, ... ]
self.optimize_mode = optimize_mode self.optimize_mode = OptimizeMode(optimize_mode)
self.no_more_trial = False self.no_more_trial = False
def is_completed(self): def is_completed(self):
......
...@@ -248,7 +248,7 @@ class HyperoptTuner(Tuner): ...@@ -248,7 +248,7 @@ class HyperoptTuner(Tuner):
verbose=0) verbose=0)
self.rval.catch_eval_exceptions = False self.rval.catch_eval_exceptions = False
def generate_parameters(self, parameter_id): def generate_parameters(self, parameter_id, **kwargs):
""" """
Returns a set of trial (hyper-)parameters, as a serializable object. Returns a set of trial (hyper-)parameters, as a serializable object.
...@@ -269,7 +269,7 @@ class HyperoptTuner(Tuner): ...@@ -269,7 +269,7 @@ class HyperoptTuner(Tuner):
params = split_index(total_params) params = split_index(total_params)
return params return params
def receive_trial_result(self, parameter_id, parameters, value): def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
""" """
Record an observation of the objective function Record an observation of the objective function
......
...@@ -49,15 +49,16 @@ def selection_r(x_bounds, ...@@ -49,15 +49,16 @@ def selection_r(x_bounds,
num_starting_points=100, num_starting_points=100,
minimize_constraints_fun=None): minimize_constraints_fun=None):
''' '''
Call selection Select using different types.
''' '''
minimize_starting_points = [lib_data.rand(x_bounds, x_types)\ minimize_starting_points = clusteringmodel_gmm_good.sample(n_samples=num_starting_points)
for i in range(0, num_starting_points)]
outputs = selection(x_bounds, x_types, outputs = selection(x_bounds, x_types,
clusteringmodel_gmm_good, clusteringmodel_gmm_good,
clusteringmodel_gmm_bad, clusteringmodel_gmm_bad,
minimize_starting_points, minimize_starting_points[0],
minimize_constraints_fun) minimize_constraints_fun)
return outputs return outputs
def selection(x_bounds, def selection(x_bounds,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment