"vscode:/vscode.git/clone" did not exist on "e5a208bac69d34469652430fe42c220ed47c5a72"
Unverified commit 1500458a, authored by SparkSnail and committed by GitHub

Merge pull request #187 from microsoft/master

merge master
parents 93dd76ba 97829ccd
 declare module 'tail-stream' {
     export interface Stream {
         on(type: 'data', callback: (data: Buffer) => void): void;
-        destroy(): void;
+        end(data: number): void;
+        emit(data: string): void;
     }
     export function createReadStream(path: string): Stream;
 }
\ No newline at end of file
@@ -28,9 +28,8 @@ import json
 import importlib
 from .constants import ModuleName, ClassName, ClassArgs, AdvisorModuleName, AdvisorClassName
-from nni.common import enable_multi_thread
+from nni.common import enable_multi_thread, enable_multi_phase
 from nni.msg_dispatcher import MsgDispatcher
-from nni.multi_phase.multi_phase_dispatcher import MultiPhaseMsgDispatcher

 logger = logging.getLogger('nni.main')
 logger.debug('START')
@@ -126,6 +125,8 @@ def main():
     args = parse_args()
     if args.multi_thread:
         enable_multi_thread()
+    if args.multi_phase:
+        enable_multi_phase()
     if args.advisor_class_name:
         # advisor is enabled and starts to run
@@ -180,10 +181,7 @@ def main():
         if assessor is None:
             raise AssertionError('Failed to create Assessor instance')
-        if args.multi_phase:
-            dispatcher = MultiPhaseMsgDispatcher(tuner, assessor)
-        else:
-            dispatcher = MsgDispatcher(tuner, assessor)
+        dispatcher = MsgDispatcher(tuner, assessor)
         try:
             dispatcher.run()
...
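With this change, multi-phase support is toggled by a process-wide flag instead of by selecting a different dispatcher class. A minimal sketch of the resulting startup path, using the names from the diff (the start_dispatcher helper and its argument handling are illustrative, not part of the commit):

# Sketch only: start_dispatcher is a hypothetical wrapper around the logic
# shown in main.py above; enable_* and MsgDispatcher are the real names.
from nni.common import enable_multi_thread, enable_multi_phase
from nni.msg_dispatcher import MsgDispatcher

def start_dispatcher(args, tuner, assessor=None):
    # Feature switches are now process-wide flags rather than class choices.
    if args.multi_thread:
        enable_multi_thread()
    if args.multi_phase:
        enable_multi_phase()
    # One dispatcher type serves both modes; it checks multi_phase_enabled()
    # internally when packing parameters and routing results.
    dispatcher = MsgDispatcher(tuner, assessor)
    dispatcher.run()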
@@ -78,7 +78,7 @@ class BatchTuner(Tuner):
         """
         self.values = self.is_valid(search_space)

-    def generate_parameters(self, parameter_id):
+    def generate_parameters(self, parameter_id, **kwargs):
         """Returns a dict of trial (hyper-)parameters, as a serializable object.

         Parameters
@@ -90,7 +90,7 @@ class BatchTuner(Tuner):
             raise nni.NoMoreTrialError('no more parameters now.')
         return self.values[self.count]

-    def receive_trial_result(self, parameter_id, parameters, value):
+    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
         pass

     def import_data(self, data):
...
@@ -69,6 +69,7 @@ def init_logger(logger_file_path, log_level_name='info'):
     sys.stdout = _LoggerFileWrapper(logger_file)

 _multi_thread = False
+_multi_phase = False

 def enable_multi_thread():
     global _multi_thread
@@ -76,3 +77,10 @@ def enable_multi_thread():
 def multi_thread_enabled():
     return _multi_thread
+
+def enable_multi_phase():
+    global _multi_phase
+    _multi_phase = True
+
+def multi_phase_enabled():
+    return _multi_phase
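common.py now exposes the same setter/getter pair for multi-phase that it already had for multi-threading. A short usage sketch of the flag (the report helper is hypothetical; the branching mirrors what msg_dispatcher.py does further down):

# Hypothetical caller showing the intended use of the new flag; the
# conditional keyword argument mirrors the dispatcher change below.
from nni.common import enable_multi_phase, multi_phase_enabled

enable_multi_phase()  # set once at startup, e.g. from a CLI switch

def report(tuner, param_id, params, value, trial_job_id):
    if multi_phase_enabled():
        # Multi-phase tuners receive the trial job id as an extra keyword.
        tuner.receive_trial_result(param_id, params, value, trial_job_id=trial_job_id)
    else:
        # Single-phase tuners keep the original three-argument call.
        tuner.receive_trial_result(param_id, params, value)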
@@ -188,7 +188,7 @@ class EvolutionTuner(Tuner):
                 self.searchspace_json, is_rand, self.random_state)
             self.population.append(Individual(config=config))

-    def generate_parameters(self, parameter_id):
+    def generate_parameters(self, parameter_id, **kwargs):
         """Returns a dict of trial (hyper-)parameters, as a serializable object.

         Parameters
@@ -232,7 +232,7 @@ class EvolutionTuner(Tuner):
         config = split_index(total_config)
         return config

-    def receive_trial_result(self, parameter_id, parameters, value):
+    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
         '''Record the result from a trial

         Parameters
...
@@ -137,7 +137,7 @@ class GridSearchTuner(Tuner):
         '''
         self.expanded_search_space = self.json2parameter(search_space)

-    def generate_parameters(self, parameter_id):
+    def generate_parameters(self, parameter_id, **kwargs):
         self.count += 1
         while (self.count <= len(self.expanded_search_space)-1):
             _params_tuple = convert_dict2tuple(self.expanded_search_space[self.count])
@@ -147,7 +147,7 @@ class GridSearchTuner(Tuner):
             return self.expanded_search_space[self.count]
         raise nni.NoMoreTrialError('no more parameters now.')

-    def receive_trial_result(self, parameter_id, parameters, value):
+    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
         pass

     def import_data(self, data):
...
@@ -248,7 +248,7 @@ class HyperoptTuner(Tuner):
                              verbose=0)
         self.rval.catch_eval_exceptions = False

-    def generate_parameters(self, parameter_id):
+    def generate_parameters(self, parameter_id, **kwargs):
         """
         Returns a set of trial (hyper-)parameters, as a serializable object.
@@ -269,7 +269,7 @@ class HyperoptTuner(Tuner):
         params = split_index(total_params)
         return params

-    def receive_trial_result(self, parameter_id, parameters, value):
+    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
         """
         Record an observation of the objective function
...
@@ -174,7 +174,7 @@ class MetisTuner(Tuner):
         return output

-    def generate_parameters(self, parameter_id):
+    def generate_parameters(self, parameter_id, **kwargs):
         """Generate next parameter for trial

         If the number of trial result is lower than cold start number,
         metis will first random generate some parameters.
@@ -205,7 +205,7 @@ class MetisTuner(Tuner):
         return results

-    def receive_trial_result(self, parameter_id, parameters, value):
+    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
         """Tuner receive result from trial.

         Parameters
...
@@ -18,7 +18,6 @@
 # OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 # ==================================================================================================

-import os
 import logging
 from collections import defaultdict
 import json_tricks
@@ -26,7 +25,7 @@ import json_tricks
 from .protocol import CommandType, send
 from .msg_dispatcher_base import MsgDispatcherBase
 from .assessor import AssessResult
-from .common import multi_thread_enabled
+from .common import multi_thread_enabled, multi_phase_enabled
 from .env_vars import dispatcher_env_vars

 _logger = logging.getLogger(__name__)
@@ -61,13 +60,19 @@ def _create_parameter_id():
     _next_parameter_id += 1
     return _next_parameter_id - 1

-def _pack_parameter(parameter_id, params, customized=False):
+def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
     _trial_params[parameter_id] = params
     ret = {
         'parameter_id': parameter_id,
         'parameter_source': 'customized' if customized else 'algorithm',
         'parameters': params
     }
+    if trial_job_id is not None:
+        ret['trial_job_id'] = trial_job_id
+    if parameter_index is not None:
+        ret['parameter_index'] = parameter_index
+    else:
+        ret['parameter_index'] = 0
     return json_tricks.dumps(ret)

 class MsgDispatcher(MsgDispatcherBase):
@@ -133,8 +138,13 @@ class MsgDispatcher(MsgDispatcherBase):
         elif data['type'] == 'PERIODICAL':
             if self.assessor is not None:
                 self._handle_intermediate_metric_data(data)
-            else:
-                pass
+        elif data['type'] == 'REQUEST_PARAMETER':
+            assert multi_phase_enabled()
+            assert data['trial_job_id'] is not None
+            assert data['parameter_index'] is not None
+            param_id = _create_parameter_id()
+            param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id'])
+            send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'], parameter_index=data['parameter_index']))
         else:
             raise ValueError('Data type not supported: {}'.format(data['type']))
@@ -160,9 +170,15 @@ class MsgDispatcher(MsgDispatcherBase):
         id_ = data['parameter_id']
         value = data['value']
         if id_ in _customized_parameter_ids:
-            self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
+            if multi_phase_enabled():
+                self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value, trial_job_id=data['trial_job_id'])
+            else:
+                self.tuner.receive_customized_trial_result(id_, _trial_params[id_], value)
         else:
-            self.tuner.receive_trial_result(id_, _trial_params[id_], value)
+            if multi_phase_enabled():
+                self.tuner.receive_trial_result(id_, _trial_params[id_], value, trial_job_id=data['trial_job_id'])
+            else:
+                self.tuner.receive_trial_result(id_, _trial_params[id_], value)

     def _handle_intermediate_metric_data(self, data):
         """Call assessor to process intermediate results
...
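For reference, the JSON produced by the extended _pack_parameter when answering a REQUEST_PARAMETER message looks like the following. The field values here are illustrative; json_tricks is the serializer used in the diff:

# Illustrative payload built the way _pack_parameter builds it.
import json_tricks

ret = {
    'parameter_id': 7,
    'parameter_source': 'algorithm',
    'parameters': {'learning_rate': 0.001},
    'trial_job_id': 'abc',    # included only when the caller passes one
    'parameter_index': 1,     # falls back to 0 when parameter_index is None
}
print(json_tricks.dumps(ret))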
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================

import logging
from collections import defaultdict
import json_tricks

from nni.protocol import CommandType, send
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.assessor import AssessResult

_logger = logging.getLogger(__name__)

# Assessor global variables
_trial_history = defaultdict(dict)
'''key: trial job ID; value: intermediate results, mapping from sequence number to data'''

_ended_trials = set()
'''trial_job_id of all ended trials.
We need this because NNI manager may send metrics after reporting a trial ended.
TODO: move this logic to NNI manager
'''

def _sort_history(history):
    ret = []
    for i, _ in enumerate(history):
        if i in history:
            ret.append(history[i])
        else:
            break
    return ret

# Tuner global variables
_next_parameter_id = 0
_trial_params = {}
'''key: trial job ID; value: parameters'''
_customized_parameter_ids = set()

def _create_parameter_id():
    global _next_parameter_id  # pylint: disable=global-statement
    _next_parameter_id += 1
    return _next_parameter_id - 1

def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, parameter_index=None):
    _trial_params[parameter_id] = params
    ret = {
        'parameter_id': parameter_id,
        'parameter_source': 'customized' if customized else 'algorithm',
        'parameters': params
    }
    if trial_job_id is not None:
        ret['trial_job_id'] = trial_job_id
    if parameter_index is not None:
        ret['parameter_index'] = parameter_index
    else:
        ret['parameter_index'] = 0
    return json_tricks.dumps(ret)

class MultiPhaseMsgDispatcher(MsgDispatcherBase):
    def __init__(self, tuner, assessor=None):
        super(MultiPhaseMsgDispatcher, self).__init__()
        self.tuner = tuner
        self.assessor = assessor
        if assessor is None:
            _logger.debug('Assessor is not configured')

    def load_checkpoint(self):
        self.tuner.load_checkpoint()
        if self.assessor is not None:
            self.assessor.load_checkpoint()

    def save_checkpoint(self):
        self.tuner.save_checkpoint()
        if self.assessor is not None:
            self.assessor.save_checkpoint()

    def handle_initialize(self, data):
        '''
        data is search space
        '''
        self.tuner.update_search_space(data)
        send(CommandType.Initialized, '')
        return True

    def handle_request_trial_jobs(self, data):
        # data: number of trial jobs
        ids = [_create_parameter_id() for _ in range(data)]
        params_list = self.tuner.generate_multiple_parameters(ids)
        assert len(ids) == len(params_list)
        for i, _ in enumerate(ids):
            send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
        return True

    def handle_update_search_space(self, data):
        self.tuner.update_search_space(data)
        return True

    def handle_import_data(self, data):
        """Import additional data for tuning.
        data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
        """
        self.tuner.import_data(data)
        return True

    def handle_add_customized_trial(self, data):
        # data: parameters
        id_ = _create_parameter_id()
        _customized_parameter_ids.add(id_)
        send(CommandType.NewTrialJob, _pack_parameter(id_, data, customized=True))
        return True

    def handle_report_metric_data(self, data):
        trial_job_id = data['trial_job_id']
        if data['type'] == 'FINAL':
            id_ = data['parameter_id']
            if id_ in _customized_parameter_ids:
                self.tuner.receive_customized_trial_result(id_, _trial_params[id_], data['value'], trial_job_id)
            else:
                self.tuner.receive_trial_result(id_, _trial_params[id_], data['value'], trial_job_id)
        elif data['type'] == 'PERIODICAL':
            if self.assessor is not None:
                self._handle_intermediate_metric_data(data)
        elif data['type'] == 'REQUEST_PARAMETER':
            assert data['trial_job_id'] is not None
            assert data['parameter_index'] is not None
            param_id = _create_parameter_id()
            param = self.tuner.generate_parameters(param_id, trial_job_id)
            send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'], parameter_index=data['parameter_index']))
        else:
            raise ValueError('Data type not supported: {}'.format(data['type']))
        return True

    def handle_trial_end(self, data):
        trial_job_id = data['trial_job_id']
        _ended_trials.add(trial_job_id)
        if trial_job_id in _trial_history:
            _trial_history.pop(trial_job_id)
            if self.assessor is not None:
                self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
        if self.tuner is not None:
            self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED', trial_job_id)
        return True

    def _handle_intermediate_metric_data(self, data):
        if data['type'] != 'PERIODICAL':
            return True
        if self.assessor is None:
            return True
        trial_job_id = data['trial_job_id']
        if trial_job_id in _ended_trials:
            return True
        history = _trial_history[trial_job_id]
        history[data['sequence']] = data['value']
        ordered_history = _sort_history(history)
        if len(ordered_history) < data['sequence']:  # no user-visible update since last time
            return True
        try:
            result = self.assessor.assess_trial(trial_job_id, ordered_history)
        except Exception:
            _logger.exception('Assessor error')
            return True
        if isinstance(result, bool):
            result = AssessResult.Good if result else AssessResult.Bad
        elif not isinstance(result, AssessResult):
            msg = 'Result of Assessor.assess_trial must be an object of AssessResult, not %s'
            raise RuntimeError(msg % type(result))
        if result is AssessResult.Bad:
            _logger.debug('BAD, kill %s', trial_job_id)
            send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
        else:
            _logger.debug('GOOD')
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================

import logging

from nni.recoverable import Recoverable

_logger = logging.getLogger(__name__)

class MultiPhaseTuner(Recoverable):
    # pylint: disable=no-self-use,unused-argument

    def generate_parameters(self, parameter_id, trial_job_id=None):
        """Returns a set of trial (hyper-)parameters, as a serializable object.
        User code must override either this function or 'generate_multiple_parameters()'.
        parameter_id: identifier of the parameter (int)
        """
        raise NotImplementedError('Tuner: generate_parameters not implemented')

    def generate_multiple_parameters(self, parameter_id_list):
        """Returns multiple sets of trial (hyper-)parameters, as an iterable of serializable objects.
        Calls 'generate_parameters()' once for each parameter id by default.
        User code must override either this function or 'generate_parameters()'.
        parameter_id_list: list of int
        """
        return [self.generate_parameters(parameter_id) for parameter_id in parameter_id_list]

    def receive_trial_result(self, parameter_id, parameters, value, trial_job_id):
        """Invoked when a trial reports its final result. Must override.
        parameter_id: identifier of the parameter (int)
        parameters: object created by 'generate_parameters()'
        value: object reported by trial
        trial_job_id: identifier of the trial (str)
        """
        raise NotImplementedError('Tuner: receive_trial_result not implemented')

    def receive_customized_trial_result(self, parameter_id, parameters, value, trial_job_id):
        """Invoked when a trial added by WebUI reports its final result. Do nothing by default.
        parameter_id: identifier of the parameter (int)
        parameters: object created by user
        value: object reported by trial
        trial_job_id: identifier of the trial (str)
        """
        _logger.info('Customized trial job %s ignored by tuner', parameter_id)

    def trial_end(self, parameter_id, success, trial_job_id):
        """Invoked when a trial is completed or terminated. Do nothing by default.
        parameter_id: identifier of the parameter (int)
        success: True if the trial successfully completed; False if failed or terminated
        trial_job_id: identifier of the trial (str)
        """
        pass

    def update_search_space(self, search_space):
        """Update the search space of tuner. Must override.
        search_space: JSON object
        """
        raise NotImplementedError('Tuner: update_search_space not implemented')

    def import_data(self, data):
        """Import additional data for tuning.
        data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
        """
        pass

    def load_checkpoint(self):
        """Load the checkpoint of tuner from the checkpoint directory.
        """
        checkpoint_path = self.get_checkpoint_path()
        _logger.info('Load checkpoint ignored by tuner, checkpoint path: %s', checkpoint_path)

    def save_checkpoint(self):
        """Save the checkpoint of tuner to the checkpoint directory.
        """
        checkpoint_path = self.get_checkpoint_path()
        _logger.info('Save checkpoint ignored by tuner, checkpoint path: %s', checkpoint_path)

    def _on_exit(self):
        pass

    def _on_error(self):
        pass
@@ -123,7 +123,7 @@ class NetworkMorphismTuner(Tuner):
         """
         self.search_space = search_space

-    def generate_parameters(self, parameter_id):
+    def generate_parameters(self, parameter_id, **kwargs):
         """
         Returns a set of trial neural architecture, as a serializable object.
@@ -152,7 +152,7 @@ class NetworkMorphismTuner(Tuner):
         return json_out

-    def receive_trial_result(self, parameter_id, parameters, value):
+    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
         """ Record an observation of the objective function.

         Parameters
...
@@ -151,7 +151,7 @@ class SMACTuner(Tuner):
         else:
             self.logger.warning('update search space is not supported.')

-    def receive_trial_result(self, parameter_id, parameters, value):
+    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
         """receive_trial_result

         Parameters
@@ -209,7 +209,7 @@ class SMACTuner(Tuner):
             converted_dict[key] = value
         return converted_dict

-    def generate_parameters(self, parameter_id):
+    def generate_parameters(self, parameter_id, **kwargs):
         """generate one instance of hyperparameters

         Parameters
@@ -232,7 +232,7 @@ class SMACTuner(Tuner):
         self.total_data[parameter_id] = challenger
         return self.convert_loguniform_categorical(challenger.get_dictionary())

-    def generate_multiple_parameters(self, parameter_id_list):
+    def generate_multiple_parameters(self, parameter_id_list, **kwargs):
         """generate multiple instances of hyperparameters

         Parameters
...
@@ -30,14 +30,14 @@ _logger = logging.getLogger(__name__)
 class Tuner(Recoverable):
     # pylint: disable=no-self-use,unused-argument

-    def generate_parameters(self, parameter_id):
+    def generate_parameters(self, parameter_id, **kwargs):
         """Returns a set of trial (hyper-)parameters, as a serializable object.
         User code must override either this function or 'generate_multiple_parameters()'.
         parameter_id: int
         """
         raise NotImplementedError('Tuner: generate_parameters not implemented')

-    def generate_multiple_parameters(self, parameter_id_list):
+    def generate_multiple_parameters(self, parameter_id_list, **kwargs):
         """Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
         Calls 'generate_parameters()' once for each parameter id by default.
         User code must override either this function or 'generate_parameters()'.
@@ -49,13 +49,13 @@ class Tuner(Recoverable):
         for parameter_id in parameter_id_list:
             try:
                 _logger.debug("generating param for {}".format(parameter_id))
-                res = self.generate_parameters(parameter_id)
+                res = self.generate_parameters(parameter_id, **kwargs)
             except nni.NoMoreTrialError:
                 return result
             result.append(res)
         return result

-    def receive_trial_result(self, parameter_id, parameters, value):
+    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
         """Invoked when a trial reports its final result. Must override.
         parameter_id: int
         parameters: object created by 'generate_parameters()'
@@ -63,7 +63,7 @@ class Tuner(Recoverable):
         """
         raise NotImplementedError('Tuner: receive_trial_result not implemented')

-    def receive_customized_trial_result(self, parameter_id, parameters, value):
+    def receive_customized_trial_result(self, parameter_id, parameters, value, **kwargs):
         """Invoked when a trial added by WebUI reports its final result. Do nothing by default.
         parameter_id: int
         parameters: object created by user
@@ -71,7 +71,7 @@ class Tuner(Recoverable):
         """
         _logger.info('Customized trial job %s ignored by tuner', parameter_id)

-    def trial_end(self, parameter_id, success):
+    def trial_end(self, parameter_id, success, **kwargs):
         """Invoked when a trial is completed or terminated. Do nothing by default.
         parameter_id: int
         success: True if the trial successfully completed; False if failed or terminated
...
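Under the new base class, a subclass simply accepts **kwargs and may read trial_job_id from it when multi-phase is enabled. A minimal sketch against the signatures above (RandomChoiceTuner and its search-space handling are illustrative, not part of the commit):

# Hypothetical tuner: picks a random value for every choice-type key.
import logging
import random

from nni.tuner import Tuner

_logger = logging.getLogger(__name__)

class RandomChoiceTuner(Tuner):
    def __init__(self):
        self.search_space = {}

    def update_search_space(self, search_space):
        self.search_space = search_space

    def generate_parameters(self, parameter_id, **kwargs):
        # When multi-phase is enabled the dispatcher passes trial_job_id
        # through **kwargs; tuners that ignore it keep working unchanged.
        _logger.debug('trial_job_id: %s', kwargs.get('trial_job_id'))
        return {k: random.choice(v['_value']) for k, v in self.search_space.items()}

    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        pass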
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================

import logging
import random
from io import BytesIO
from unittest import TestCase, main

import nni
import nni.protocol
from nni.protocol import CommandType, send, receive
from nni.multi_phase.multi_phase_tuner import MultiPhaseTuner
from nni.multi_phase.multi_phase_dispatcher import MultiPhaseMsgDispatcher

class NaiveMultiPhaseTuner(MultiPhaseTuner):
    '''
    supports only choices
    '''
    def __init__(self):
        self.search_space = None

    def generate_parameters(self, parameter_id, trial_job_id=None):
        """Returns a set of trial (hyper-)parameters, as a serializable object.
        User code must override either this function or 'generate_multiple_parameters()'.
        parameter_id: int
        """
        generated_parameters = {}
        if self.search_space is None:
            raise AssertionError('Search space not specified')
        for k in self.search_space:
            param = self.search_space[k]
            if not param['_type'] == 'choice':
                raise ValueError('Only choice type is supported')
            param_values = param['_value']
            generated_parameters[k] = param_values[random.randint(0, len(param_values) - 1)]
        logging.getLogger(__name__).debug(generated_parameters)
        return generated_parameters

    def receive_trial_result(self, parameter_id, parameters, value, trial_job_id):
        logging.getLogger(__name__).debug('receive_trial_result: {},{},{},{}'.format(parameter_id, parameters, value, trial_job_id))

    def receive_customized_trial_result(self, parameter_id, parameters, value, trial_job_id):
        pass

    def update_search_space(self, search_space):
        self.search_space = search_space

_in_buf = BytesIO()
_out_buf = BytesIO()

def _reverse_io():
    _in_buf.seek(0)
    _out_buf.seek(0)
    nni.protocol._out_file = _in_buf
    nni.protocol._in_file = _out_buf

def _restore_io():
    _in_buf.seek(0)
    _out_buf.seek(0)
    nni.protocol._in_file = _in_buf
    nni.protocol._out_file = _out_buf

def _test_tuner():
    _reverse_io()  # now we are sending to Tuner's incoming stream
    send(CommandType.UpdateSearchSpace, '{"learning_rate": {"_value": [0.0001, 0.001, 0.002, 0.005, 0.01], "_type": "choice"}, "optimizer": {"_value": ["Adam", "SGD"], "_type": "choice"}}')
    send(CommandType.RequestTrialJobs, '2')
    send(CommandType.ReportMetricData, '{"parameter_id":0,"type":"PERIODICAL","value":10,"trial_job_id":"abc"}')
    send(CommandType.ReportMetricData, '{"parameter_id":1,"type":"FINAL","value":11,"trial_job_id":"abc"}')
    send(CommandType.AddCustomizedTrialJob, '{"param":-1}')
    send(CommandType.ReportMetricData, '{"parameter_id":2,"type":"FINAL","value":22,"trial_job_id":"abc"}')
    send(CommandType.RequestTrialJobs, '1')
    send(CommandType.TrialEnd, '{"trial_job_id":"abc"}')
    _restore_io()

    tuner = NaiveMultiPhaseTuner()
    dispatcher = MultiPhaseMsgDispatcher(tuner)
    dispatcher.run()

    _reverse_io()  # now we are receiving from Tuner's outgoing stream
    command, data = receive()  # this one is customized
    print(command, data)

class MultiPhaseTestCase(TestCase):
    def test_tuner(self):
        _test_tuner()

if __name__ == '__main__':
    main()
\ No newline at end of file
@@ -35,7 +35,7 @@ class NaiveTuner(Tuner):
         self.trial_results = [ ]
         self.search_space = None

-    def generate_parameters(self, parameter_id):
+    def generate_parameters(self, parameter_id, **kwargs):
         # report Tuner's internal states to generated parameters,
         # so we don't need to pause the main loop
         self.param += 2
@@ -45,7 +45,7 @@ class NaiveTuner(Tuner):
             'search_space': self.search_space
         }

-    def receive_trial_result(self, parameter_id, parameters, value):
+    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
         reward = extract_scalar_reward(value)
         self.trial_results.append((parameter_id, parameters['param'], reward, False))
@@ -103,11 +103,9 @@ class TunerTestCase(TestCase):
         command, data = receive()  # this one is customized
         data = json.loads(data)
         self.assertIs(command, CommandType.NewTrialJob)
-        self.assertEqual(data, {
-            'parameter_id': 2,
-            'parameter_source': 'customized',
-            'parameters': { 'param': -1 }
-        })
+        self.assertEqual(data['parameter_id'], 2)
+        self.assertEqual(data['parameter_source'], 'customized')
+        self.assertEqual(data['parameters'], { 'param': -1 })
         self._assert_params(3, 6, [[1,4,11,False], [2,-1,22,True]], {'name':'SS0'})
...
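The exact-dict assertion is loosened here because _pack_parameter now always adds parameter_index (and sometimes trial_job_id) to the payload, so the test checks only the fields it cares about. A small sketch of the new payload shape with made-up values:

# Illustrative payload; parameter_index is the new field that broke the
# old whole-dict equality check.
data = {
    'parameter_id': 2,
    'parameter_source': 'customized',
    'parameters': {'param': -1},
    'parameter_index': 0,  # always present after the _pack_parameter change
}
assert data['parameter_id'] == 2
assert data['parameter_source'] == 'customized'
assert data['parameters'] == {'param': -1}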
import * as React from 'react';
import { Row, Modal } from 'antd';
import ReactEcharts from 'echarts-for-react';
import IntermediateVal from '../public-child/IntermediateVal';
import '../../static/style/compare.scss';
import { TableObj, Intermedia, TooltipForIntermediate } from 'src/static/interface';

// the modal for comparing trials
interface CompareProps {
    compareRows: Array<TableObj>;
    visible: boolean;
    cancelFunc: () => void;
}

class Compare extends React.Component<CompareProps, {}> {

    public _isCompareMount: boolean;

    constructor(props: CompareProps) {
        super(props);
    }

    intermediate = () => {
        const { compareRows } = this.props;
        const trialIntermediate: Array<Intermedia> = [];
        const idsList: Array<string> = [];
        Object.keys(compareRows).map(item => {
            const temp = compareRows[item];
            trialIntermediate.push({
                name: temp.id,
                data: temp.description.intermediate,
                type: 'line',
                hyperPara: temp.description.parameters
            });
            idsList.push(temp.id);
        });
        // sort by descending length so the longest series comes first
        trialIntermediate.sort((a, b) => { return (b.data.length - a.data.length); });
        const legend: Array<string> = [];
        // max length of the intermediate series
        const length = trialIntermediate[0] !== undefined ? trialIntermediate[0].data.length : 0;
        const xAxis: Array<number> = [];
        Object.keys(trialIntermediate).map(item => {
            const temp = trialIntermediate[item];
            legend.push(temp.name);
        });
        for (let i = 1; i <= length; i++) {
            xAxis.push(i);
        }
        const option = {
            tooltip: {
                trigger: 'item',
                enterable: true,
                position: function (point: Array<number>, data: TooltipForIntermediate) {
                    if (data.dataIndex < length / 2) {
                        return [point[0], 80];
                    } else {
                        return [point[0] - 300, 80];
                    }
                },
                formatter: function (data: TooltipForIntermediate) {
                    const trialId = data.seriesName;
                    let obj = {};
                    const temp = trialIntermediate.find(key => key.name === trialId);
                    if (temp !== undefined) {
                        obj = temp.hyperPara;
                    }
                    return '<div class="tooldetailAccuracy">' +
                        '<div>Trial ID: ' + trialId + '</div>' +
                        '<div>Intermediate: ' + data.data + '</div>' +
                        '<div>Parameters: ' +
                        '<pre>' + JSON.stringify(obj, null, 4) + '</pre>' +
                        '</div>' +
                        '</div>';
                }
            },
            grid: {
                left: '5%',
                top: 40,
                containLabel: true
            },
            legend: {
                data: idsList
            },
            xAxis: {
                type: 'category',
                name: 'Step',
                boundaryGap: false,
                data: xAxis
            },
            yAxis: {
                type: 'value',
                name: 'metric'
            },
            series: trialIntermediate
        };
        return (
            <ReactEcharts
                option={option}
                style={{ width: '100%', height: 418, margin: '0 auto' }}
                notMerge={true} // update now
            />
        );
    }

    // render table column ---
    initColumn = () => {
        const { compareRows } = this.props;
        const idList: Array<string> = [];
        const durationList: Array<number> = [];
        const parameterList: Array<object> = [];
        let parameterKeys: Array<string> = [];
        if (compareRows.length !== 0) {
            parameterKeys = Object.keys(compareRows[0].description.parameters);
        }
        Object.keys(compareRows).map(item => {
            const temp = compareRows[item];
            idList.push(temp.id);
            durationList.push(temp.duration);
            parameterList.push(temp.description.parameters);
        });
        return (
            <table className="compare">
                <tbody>
                    <tr>
                        <td />
                        {Object.keys(idList).map(key => {
                            return (
                                <td className="value" key={key}>{idList[key]}</td>
                            );
                        })}
                    </tr>
                    <tr>
                        <td className="column">Default metric</td>
                        {Object.keys(compareRows).map(index => {
                            const temp = compareRows[index];
                            return (
                                <td className="value" key={index}>
                                    <IntermediateVal record={temp} />
                                </td>
                            );
                        })}
                    </tr>
                    <tr>
                        <td className="column">duration</td>
                        {Object.keys(durationList).map(index => {
                            return (
                                <td className="value" key={index}>{durationList[index]}</td>
                            );
                        })}
                    </tr>
                    {
                        Object.keys(parameterKeys).map(index => {
                            return (
                                <tr key={index}>
                                    <td className="column" key={index}>{parameterKeys[index]}</td>
                                    {
                                        Object.keys(parameterList).map(key => {
                                            return (
                                                <td key={key} className="value">
                                                    {parameterList[key][parameterKeys[index]]}
                                                </td>
                                            );
                                        })
                                    }
                                </tr>
                            );
                        })
                    }
                </tbody>
            </table>
        );
    }

    componentDidMount() {
        this._isCompareMount = true;
    }

    componentWillUnmount() {
        this._isCompareMount = false;
    }

    render() {
        const { visible, cancelFunc } = this.props;
        return (
            <Modal
                title="Compare trials"
                visible={visible}
                onCancel={cancelFunc}
                footer={null}
                destroyOnClose={true}
                maskClosable={false}
                width="90%"
            >
                <Row>{this.intermediate()}</Row>
                <Row>{this.initColumn()}</Row>
            </Modal>
        );
    }
}

export default Compare;
@@ -378,7 +378,7 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
         {/* trial table list */}
         <Title1 text="Trial jobs" icon="6.png" />
         <Row className="allList">
-          <Col span={12}>
+          <Col span={10}>
             <span>Show</span>
             <Select
               className="entry"
@@ -392,26 +392,28 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
             </Select>
             <span>entries</span>
           </Col>
-          <Col span={12} className="right">
-            <Row>
-              <Col span={12}>
-                <Button
-                  type="primary"
-                  className="tableButton editStyle"
-                  onClick={this.tableList ? this.tableList.addColumn : this.test}
-                >
-                  Add column
-                </Button>
-              </Col>
-              <Col span={12}>
-                <Input
-                  type="text"
-                  placeholder="Search by id, trial No. or status"
-                  onChange={this.searchTrial}
-                  style={{ width: 230, marginLeft: 6 }}
-                />
-              </Col>
-            </Row>
+          <Col span={14} className="right">
+            <Button
+              type="primary"
+              className="tableButton editStyle"
+              onClick={this.tableList ? this.tableList.addColumn : this.test}
+            >
+              Add column
+            </Button>
+            <Button
+              type="primary"
+              className="tableButton editStyle mediateBtn"
+              // use the child component tableList's compare function, defined in the child component
+              onClick={this.tableList ? this.tableList.compareBtn : this.test}
+            >
+              Compare
+            </Button>
+            <Input
+              type="text"
+              placeholder="Search by id, trial No. or status"
+              onChange={this.searchTrial}
+              style={{ width: 230, marginLeft: 6 }}
+            />
           </Col>
         </Row>
         <TableList
...
 import * as React from 'react';
 import { Row, Col, Button, Switch } from 'antd';
-import { TooltipForIntermediate, TableObj } from '../../static/interface';
+import { TooltipForIntermediate, TableObj, Intermedia } from '../../static/interface';
 import ReactEcharts from 'echarts-for-react';
 require('echarts/lib/component/tooltip');
 require('echarts/lib/component/title');

-interface Intermedia {
-    name: string; // id
-    type: string;
-    data: Array<number | object>; // intermediate data
-    hyperPara: object; // each trial hyperpara value
-}
-
 interface IntermediateState {
     detailSource: Array<TableObj>;
     interSource: object;
...