Unverified Commit 7c4b8c0d authored by liuzhe-lz, committed by GitHub

Make pylint happy (#1649)

Update the Python SDK and nni_annotation to pass pylint rules.
parent 22316800
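The diffs below are what it took to get the SDK through pylint cleanly. As a rough sketch of how the check might be reproduced locally (the rcfile name and package path here are assumptions, not stated on this page):

    # hypothetical invocation; adjust the rcfile and path to the actual repo layout
    python -m pylint --rcfile=.pylintrc src/sdk/pynni/nni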
@@ -69,3 +69,6 @@ build
 *.egg-info
 .vscode
+
+# In case you place source code in ~/nni/
+/experiments
@@ -19,16 +19,14 @@
 # ==================================================================================================

+# pylint: disable=wildcard-import
 from .trial import *
 from .smartparam import *
 from .nas_utils import training_update

 class NoMoreTrialError(Exception):
-    def __init__(self,ErrorInfo):
+    def __init__(self, ErrorInfo):
         super().__init__(self)
-        self.errorinfo=ErrorInfo
+        self.errorinfo = ErrorInfo

     def __str__(self):
-        return self.errorinfo
\ No newline at end of file
+        return self.errorinfo
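The added `# pylint: disable=wildcard-import` suppresses W0401 for the whole module, which is the conventional way to keep a deliberate re-export module like this one quiet. Disables can also be scoped to a single line; a minimal sketch of the two scopes (the `eval` example is illustrative, not from this diff):

    # pylint: disable=wildcard-import,unused-wildcard-import
    # (module-scoped: applies from this point to the end of the file)
    from os.path import *

    def parse(expr):
        # line-scoped: only this statement is exempt from eval-used
        return eval(expr)  # pylint: disable=eval-used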
@@ -27,9 +27,10 @@ import logging
 import json
 import importlib

+from .common import enable_multi_thread, enable_multi_phase
 from .constants import ModuleName, ClassName, ClassArgs, AdvisorModuleName, AdvisorClassName
-from nni.common import enable_multi_thread, enable_multi_phase
-from nni.msg_dispatcher import MsgDispatcher
+from .msg_dispatcher import MsgDispatcher

 logger = logging.getLogger('nni.main')
 logger.debug('START')
@@ -44,7 +45,7 @@ def augment_classargs(input_class_args, classname):
             input_class_args[key] = value
     return input_class_args

-def create_builtin_class_instance(classname, jsonstr_args, is_advisor = False):
+def create_builtin_class_instance(classname, jsonstr_args, is_advisor=False):
     if is_advisor:
         if classname not in AdvisorModuleName or \
             importlib.util.find_spec(AdvisorModuleName[classname]) is None:
@@ -130,55 +131,15 @@ def main():
     if args.advisor_class_name:
         # advisor is enabled and starts to run
-        if args.advisor_class_name in AdvisorModuleName:
-            dispatcher = create_builtin_class_instance(
-                args.advisor_class_name,
-                args.advisor_args, True)
-        else:
-            dispatcher = create_customized_class_instance(
-                args.advisor_directory,
-                args.advisor_class_filename,
-                args.advisor_class_name,
-                args.advisor_args)
-        if dispatcher is None:
-            raise AssertionError('Failed to create Advisor instance')
-        try:
-            dispatcher.run()
-        except Exception as exception:
-            logger.exception(exception)
-            raise
+        _run_advisor(args)
     else:
         # tuner (and assessor) is enabled and starts to run
-        tuner = None
-        assessor = None
-        if args.tuner_class_name in ModuleName:
-            tuner = create_builtin_class_instance(
-                args.tuner_class_name,
-                args.tuner_args)
-        else:
-            tuner = create_customized_class_instance(
-                args.tuner_directory,
-                args.tuner_class_filename,
-                args.tuner_class_name,
-                args.tuner_args)
-        if tuner is None:
-            raise AssertionError('Failed to create Tuner instance')
+        tuner = _create_tuner(args)
         if args.assessor_class_name:
-            if args.assessor_class_name in ModuleName:
-                assessor = create_builtin_class_instance(
-                    args.assessor_class_name,
-                    args.assessor_args)
-            else:
-                assessor = create_customized_class_instance(
-                    args.assessor_directory,
-                    args.assessor_class_filename,
-                    args.assessor_class_name,
-                    args.assessor_args)
-            if assessor is None:
-                raise AssertionError('Failed to create Assessor instance')
+            assessor = _create_assessor(args)
+        else:
+            assessor = None
         dispatcher = MsgDispatcher(tuner, assessor)
         try:
@@ -193,6 +154,59 @@ def main():
             assessor._on_error()
         raise

+def _run_advisor(args):
+    if args.advisor_class_name in AdvisorModuleName:
+        dispatcher = create_builtin_class_instance(
+            args.advisor_class_name,
+            args.advisor_args, True)
+    else:
+        dispatcher = create_customized_class_instance(
+            args.advisor_directory,
+            args.advisor_class_filename,
+            args.advisor_class_name,
+            args.advisor_args)
+    if dispatcher is None:
+        raise AssertionError('Failed to create Advisor instance')
+    try:
+        dispatcher.run()
+    except Exception as exception:
+        logger.exception(exception)
+        raise
+
+def _create_tuner(args):
+    if args.tuner_class_name in ModuleName:
+        tuner = create_builtin_class_instance(
+            args.tuner_class_name,
+            args.tuner_args)
+    else:
+        tuner = create_customized_class_instance(
+            args.tuner_directory,
+            args.tuner_class_filename,
+            args.tuner_class_name,
+            args.tuner_args)
+    if tuner is None:
+        raise AssertionError('Failed to create Tuner instance')
+    return tuner
+
+def _create_assessor(args):
+    if args.assessor_class_name in ModuleName:
+        assessor = create_builtin_class_instance(
+            args.assessor_class_name,
+            args.assessor_args)
+    else:
+        assessor = create_customized_class_instance(
+            args.assessor_directory,
+            args.assessor_class_filename,
+            args.assessor_class_name,
+            args.assessor_args)
+    if assessor is None:
+        raise AssertionError('Failed to create Assessor instance')
+    return assessor
+
 if __name__ == '__main__':
     try:
         main()
......
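Earlier in this file's diff, both `create_builtin_class_instance` paths probe `importlib.util.find_spec` before importing a builtin module. `find_spec` returns None when a module cannot be found, so it works as an availability check that does not execute the module. A minimal self-contained sketch:

    import importlib.util

    def module_available(name):
        # True if `name` is importable; the module itself is not executed.
        return importlib.util.find_spec(name) is not None

    print(module_available('json'))          # True on any standard install
    print(module_available('no_such_mod'))   # False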
@@ -31,7 +31,6 @@ class AssessResult(Enum):
     Bad = False

 class Assessor(Recoverable):
-    # pylint: disable=no-self-use,unused-argument
     def assess_trial(self, trial_job_id, trial_history):
         """Determines whether a trial should be killed. Must override.

@@ -46,21 +45,20 @@ class Assessor(Recoverable):
         trial_job_id: identifier of the trial (str).
         success: True if the trial successfully completed; False if failed or terminated.
         """
-        pass

     def load_checkpoint(self):
         """Load the checkpoint of assessr.
         path: checkpoint directory for assessor
         """
         checkpoin_path = self.get_checkpoint_path()
-        _logger.info('Load checkpoint ignored by assessor, checkpoint path: %s' % checkpoin_path)
+        _logger.info('Load checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path)

     def save_checkpoint(self):
         """Save the checkpoint of assessor.
         path: checkpoint directory for assessor
         """
         checkpoin_path = self.get_checkpoint_path()
-        _logger.info('Save checkpoint ignored by assessor, checkpoint path: %s' % checkpoin_path)
+        _logger.info('Save checkpoint ignored by assessor, checkpoint path: %s', checkpoin_path)

     def _on_exit(self):
         pass
......
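The recurring change in this file and several below, `logger.info(msg % arg)` to `logger.info(msg, arg)`, is what pylint's logging-not-lazy (W1201) and logging-format-interpolation (W1202) checks ask for: when arguments are passed separately, the logging module interpolates them only if the record is actually emitted. A minimal sketch:

    import logging

    logging.basicConfig(level=logging.WARNING)
    logger = logging.getLogger('demo')

    path = '/tmp/checkpoint'
    logger.info('checkpoint path: %s' % path)  # string is built eagerly, then discarded (INFO is disabled)
    logger.info('checkpoint path: %s', path)   # formatting deferred; skipped entirely here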
@@ -100,7 +100,7 @@ class BatchTuner(Tuner):
         data:
             a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
         """
-        if len(self.values) == 0:
+        if not self.values:
             logger.info("Search space has not been initialized, skip this data import")
             return
......
@@ -51,7 +51,7 @@ def create_parameter_id():
     int
         parameter id
     """
-    global _next_parameter_id  # pylint: disable=global-statement
+    global _next_parameter_id
    _next_parameter_id += 1
    return _next_parameter_id - 1

@@ -80,7 +80,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
     return params_id

-class Bracket(object):
+class Bracket:
     """
     A bracket in BOHB, all the information of a bracket is managed by
     an instance of this class.

@@ -98,7 +98,7 @@ class Bracket(object):
     max_budget : float
         The largest budget to consider. Needs to be larger than min_budget!
         The budgets will be geometrically distributed
-        :math:`a^2 + b^2 = c^2 \sim \eta^k` for :math:`k\in [0, 1, ... , num\_subsets - 1]`.
+        :math:`a^2 + b^2 = c^2 \\sim \\eta^k` for :math:`k\\in [0, 1, ... , num\\_subsets - 1]`.
     optimize_mode: str
         optimize mode, 'maximize' or 'minimize'
     """
@@ -169,7 +169,7 @@ class Bracket(object):
         If we have generated new trials after this trial end, we will return a new trial parameters.
         Otherwise, we will return None.
         """
-        global _KEY  # pylint: disable=global-statement
+        global _KEY
         self.num_finished_configs[i] += 1
         logger.debug('bracket id: %d, round: %d %d, finished: %d, all: %d',
                      self.s, self.i, i, self.num_finished_configs[i], self.num_configs_to_run[i])

@@ -377,8 +377,10 @@ class BOHB(MsgDispatcherBase):
         if self.curr_s < 0:
             logger.info("s < 0, Finish this round of Hyperband in BOHB. Generate new round")
             self.curr_s = self.s_max
-        self.brackets[self.curr_s] = Bracket(s=self.curr_s, s_max=self.s_max, eta=self.eta,
-                                             max_budget=self.max_budget, optimize_mode=self.optimize_mode)
+        self.brackets[self.curr_s] = Bracket(
+            s=self.curr_s, s_max=self.s_max, eta=self.eta,
+            max_budget=self.max_budget, optimize_mode=self.optimize_mode
+        )
         next_n, next_r = self.brackets[self.curr_s].get_n_r()
         logger.debug(
             'new SuccessiveHalving iteration, next_n=%d, next_r=%d', next_n, next_r)

@@ -599,7 +601,7 @@ class BOHB(MsgDispatcherBase):
         logger.debug('bracket id = %s, metrics value = %s, type = %s', s, value, data['type'])
         s = int(s)
         # add <trial_job_id, parameter_id> to self.job_id_para_id_map here,
         # because when the first parameter_id is created, trial_job_id is not known yet.
         if data['trial_job_id'] in self.job_id_para_id_map:
             assert self.job_id_para_id_map[data['trial_job_id']] == data['parameter_id']

@@ -643,14 +645,14 @@ class BOHB(MsgDispatcherBase):
         """
         _completed_num = 0
         for trial_info in data:
-            logger.info("Importing data, current processing progress %s / %s" %(_completed_num, len(data)))
+            logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
             _completed_num += 1
             assert "parameter" in trial_info
             _params = trial_info["parameter"]
             assert "value" in trial_info
             _value = trial_info['value']
             if not _value:
-                logger.info("Useless trial data, value is %s, skip this trial data." %_value)
+                logger.info("Useless trial data, value is %s, skip this trial data.", _value)
                 continue
             budget_exist_flag = False
             barely_params = dict()

@@ -662,7 +664,7 @@ class BOHB(MsgDispatcherBase):
                     barely_params[keys] = _params[keys]
             if not budget_exist_flag:
                 _budget = self.max_budget
-                logger.info("Set \"TRIAL_BUDGET\" value to %s (max budget)" %self.max_budget)
+                logger.info("Set \"TRIAL_BUDGET\" value to %s (max budget)", self.max_budget)
             if self.optimize_mode is OptimizeMode.Maximize:
                 reward = -_value
             else:
......
@@ -28,7 +28,6 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 import logging
-import traceback

 import ConfigSpace
 import ConfigSpace.hyperparameters

@@ -39,7 +38,7 @@ import statsmodels.api as sm

 logger = logging.getLogger('BOHB_Advisor')

-class CG_BOHB(object):
+class CG_BOHB:
     def __init__(self, configspace, min_points_in_model=None,
                  top_n_percent=15, num_samples=64, random_fraction=1/3,
                  bandwidth_factor=3, min_bandwidth=1e-3):

@@ -77,8 +76,8 @@ class CG_BOHB(object):
             self.min_points_in_model = len(self.configspace.get_hyperparameters())+1
         if self.min_points_in_model < len(self.configspace.get_hyperparameters())+1:
-            logger.warning('Invalid min_points_in_model value. Setting it to %i'%(len(self.configspace.get_hyperparameters())+1))
-            self.min_points_in_model =len(self.configspace.get_hyperparameters())+1
+            logger.warning('Invalid min_points_in_model value. Setting it to %i', len(self.configspace.get_hyperparameters()) + 1)
+            self.min_points_in_model = len(self.configspace.get_hyperparameters()) + 1

         self.num_samples = num_samples
         self.random_fraction = random_fraction

@@ -107,9 +106,9 @@ class CG_BOHB(object):
         self.kde_models = dict()

     def largest_budget_with_model(self):
-        if len(self.kde_models) == 0:
-            return(-float('inf'))
-        return(max(self.kde_models.keys()))
+        if not self.kde_models:
+            return -float('inf')
+        return max(self.kde_models.keys())

     def sample_from_largest_budget(self, info_dict):
         """We opted for a single multidimensional KDE compared to the

@@ -162,11 +161,11 @@ class CG_BOHB(object):
                 val = minimize_me(vector)

                 if not np.isfinite(val):
-                    logger.warning('sampled vector: %s has EI value %s'%(vector, val))
-                    logger.warning("data in the KDEs:\n%s\n%s"%(kde_good.data, kde_bad.data))
-                    logger.warning("bandwidth of the KDEs:\n%s\n%s"%(kde_good.bw, kde_bad.bw))
-                    logger.warning("l(x) = %s"%(l(vector)))
-                    logger.warning("g(x) = %s"%(g(vector)))
+                    logger.warning('sampled vector: %s has EI value %s', vector, val)
+                    logger.warning("data in the KDEs:\n%s\n%s", kde_good.data, kde_bad.data)
+                    logger.warning("bandwidth of the KDEs:\n%s\n%s", kde_good.bw, kde_bad.bw)
+                    logger.warning("l(x) = %s", l(vector))
+                    logger.warning("g(x) = %s", g(vector))

                     # right now, this happens because a KDE does not contain all values for a categorical parameter
                     # this cannot be fixed with the statsmodels KDE, so for now, we are just going to evaluate this one

@@ -181,19 +180,15 @@ class CG_BOHB(object):
                         best_vector = vector

         if best_vector is None:
-            logger.debug("Sampling based optimization with %i samples failed -> using random configuration"%self.num_samples)
+            logger.debug("Sampling based optimization with %i samples failed -> using random configuration", self.num_samples)
             sample = self.configspace.sample_configuration().get_dictionary()
             info_dict['model_based_pick'] = False
         else:
-            logger.debug('best_vector: {}, {}, {}, {}'.format(best_vector, best, l(best_vector), g(best_vector)))
-            for i, hp_value in enumerate(best_vector):
-                if isinstance(
-                    self.configspace.get_hyperparameter(
-                        self.configspace.get_hyperparameter_by_idx(i)
-                    ),
-                    ConfigSpace.hyperparameters.CategoricalHyperparameter
-                ):
+            logger.debug('best_vector: %s, %s, %s, %s', best_vector, best, l(best_vector), g(best_vector))
+            for i, _ in enumerate(best_vector):
+                hp = self.configspace.get_hyperparameter(self.configspace.get_hyperparameter_by_idx(i))
+                if isinstance(hp, ConfigSpace.hyperparameters.CategoricalHyperparameter):
                     best_vector[i] = int(np.rint(best_vector[i]))

             sample = ConfigSpace.Configuration(self.configspace, vector=best_vector).get_dictionary()

@@ -224,12 +219,12 @@ class CG_BOHB(object):
         # If no model is available, sample from prior
         # also mix in a fraction of random configs
-        if len(self.kde_models.keys()) == 0 or np.random.rand() < self.random_fraction:
+        if not self.kde_models.keys() or np.random.rand() < self.random_fraction:
             sample = self.configspace.sample_configuration()
             info_dict['model_based_pick'] = False

         if sample is None:
-            sample, info_dict= self.sample_from_largest_budget(info_dict)
+            sample, info_dict = self.sample_from_largest_budget(info_dict)

         sample = ConfigSpace.util.deactivate_inactive_hyperparameters(
             configuration_space=self.configspace,

@@ -245,10 +240,10 @@ class CG_BOHB(object):
         for i in range(array.shape[0]):
             datum = np.copy(array[i])
             nan_indices = np.argwhere(np.isnan(datum)).flatten()
-            while(np.any(nan_indices)):
+            while np.any(nan_indices):
                 nan_idx = nan_indices[0]
-                valid_indices = np.argwhere(np.isfinite(array[:,nan_idx])).flatten()
-                if len(valid_indices) > 0:
+                valid_indices = np.argwhere(np.isfinite(array[:, nan_idx])).flatten()
+                if valid_indices:
                     # pick one of them at random and overwrite all NaN values
                     row_idx = np.random.choice(valid_indices)
                     datum[nan_indices] = array[row_idx, nan_indices]

@@ -260,8 +255,8 @@ class CG_BOHB(object):
                 else:
                     datum[nan_idx] = np.random.randint(t)
                 nan_indices = np.argwhere(np.isnan(datum)).flatten()
-            return_array[i,:] = datum
-        return(return_array)
+            return_array[i, :] = datum
+        return return_array

     def new_result(self, loss, budget, parameters, update_model=True):
         """

@@ -305,7 +300,7 @@ class CG_BOHB(object):
         # a) if not enough points are available
         if len(self.configs[budget]) <= self.min_points_in_model - 1:
-            logger.debug("Only %i run(s) for budget %f available, need more than %s \
-                -> can't build model!"%(len(self.configs[budget]), budget, self.min_points_in_model+1))
+            logger.debug("Only %i run(s) for budget %f available, need more than %s \
+                -> can't build model!", len(self.configs[budget]), budget, self.min_points_in_model+1)
             return

         # b) during warnm starting when we feed previous results in and only update once
         if not update_model:

@@ -345,5 +340,5 @@ class CG_BOHB(object):
         }

         # update probs for the categorical parameters for later sampling
-        logger.debug('done building a new model for budget %f based on %i/%i split\nBest loss for this budget:%f\n'
-                     %(budget, n_good, n_bad, np.min(train_losses)))
+        logger.debug('done building a new model for budget %f based on %i/%i split\nBest loss for this budget:%f\n',
+                     budget, n_good, n_bad, np.min(train_losses))
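One judgment call in the `impute_conditional_data` hunk above deserves care: `if len(valid_indices) > 0:` was rewritten as `if valid_indices:` to satisfy pylint's len-as-condition check (C1801), but `valid_indices` is a NumPy array, and truth-testing an array with more than one element raises a ValueError at runtime. `len(...) > 0` or `.size` is the safe spelling for arrays; a runnable demonstration:

    import numpy as np

    valid_indices = np.array([0, 2, 3])   # typical output of np.argwhere(...).flatten()
    print(valid_indices.size > 0)         # True: safe emptiness test

    try:
        if valid_indices:                 # ambiguous for multi-element arrays
            pass
    except ValueError as err:
        print(err)                        # "The truth value of an array with more than one element is ambiguous..."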
@@ -41,8 +41,8 @@ class _LoggerFileWrapper(TextIOBase):
     def write(self, s):
         if s != '\n':
-            time = datetime.now().strftime(_time_format)
-            self.file.write('[{}] PRINT '.format(time) + s + '\n')
+            cur_time = datetime.now().strftime(_time_format)
+            self.file.write('[{}] PRINT '.format(cur_time) + s + '\n')
             self.file.flush()
         return len(s)
......
@@ -29,7 +29,7 @@ class LevelPruner(Pruner):

 class AGP_Pruner(Pruner):
     """An automated gradual pruning algorithm that prunes the smallest magnitude
     weights to achieve a preset level of network sparsity.

     Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the

@@ -92,5 +92,5 @@ class AGP_Pruner(Pruner):
     def update_epoch(self, epoch, sess):
         sess.run(self.assign_handler)
         sess.run(tf.assign(self.now_epoch, int(epoch)))
-        for k in self.if_init_list.keys():
+        for k in self.if_init_list:
             self.if_init_list[k] = True
@@ -2,7 +2,7 @@ import logging
 import tensorflow as tf
 from .compressor import Quantizer

-__all__ = [ 'NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer' ]
+__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer']

 _logger = logging.getLogger(__name__)

@@ -12,7 +12,7 @@ class NaiveQuantizer(Quantizer):
     """
     def __init__(self, config_list):
         super().__init__(config_list)
-        self.layer_scale = { }
+        self.layer_scale = {}

     def quantize_weight(self, weight, config, op_name, **kwargs):
         new_scale = tf.reduce_max(tf.abs(weight)) / 127

@@ -33,17 +33,17 @@ class QAT_Quantizer(Quantizer):
         - q_bits
         """
         super().__init__(config_list)

     def quantize_weight(self, weight, config, **kwargs):
         a = tf.stop_gradient(tf.reduce_min(weight))
         b = tf.stop_gradient(tf.reduce_max(weight))
         n = tf.cast(2 ** config['q_bits'], tf.float32)
         scale = b-a/(n-1)

         # use gradient_override_map to change round to idetity for gradient
         with tf.get_default_graph().gradient_override_map({'Round': 'Identity'}):
             qw = tf.round((weight-a)/scale)*scale +a
         return qw

@@ -58,7 +58,7 @@ class DoReFaQuantizer(Quantizer):
         - q_bits
         """
         super().__init__(config_list)

     def quantize_weight(self, weight, config, **kwargs):
         a = tf.math.tanh(weight)
         b = a/(2*tf.reduce_max(tf.abs(weight))) + 0.5
......
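A context line in the QAT_Quantizer hunk above, `scale = b-a/(n-1)`, is untouched by this commit but worth a second look: Python parses it as `b - (a / (n - 1))`, whereas a min-max quantization step size would normally be `(b - a) / (n - 1)`. Whether that is intended is not something this diff settles; the precedence itself is easy to check:

    b, a, n = 10.0, 2.0, 5.0
    print(b - a / (n - 1))    # 9.5, what the context line computes
    print((b - a) / (n - 1))  # 2.0, the conventional min-max step size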
-import tensorflow as tf
 import logging
+import tensorflow as tf

 from . import default_layers

 _logger = logging.getLogger(__name__)

@@ -51,17 +51,14 @@ class Compressor:
         Compressors can optionally overload this method to do model-specific initialization.
         It is guaranteed that only one model will be bound to each compressor instance.
         """
-        pass

     def update_epoch(self, epoch, sess):
         """If user want to update mask every epoch, user can override this method
         """
-        pass

     def step(self, sess):
         """If user want to update mask every step, user can override this method
         """
-        pass

     def _instrument_layer(self, layer, config):

@@ -84,10 +81,9 @@ class Compressor:

 class Pruner(Compressor):
-    """Abstract base TensorFlow pruner"""
-
-    def __init__(self, config_list):
-        super().__init__(config_list)
+    """
+    Abstract base TensorFlow pruner
+    """

     def calc_mask(self, weight, config, op, op_type, op_name):
         """Pruners should overload this method to provide mask for weight tensors.

@@ -105,7 +101,7 @@ class Pruner(Compressor):
         # not sure what will happen if the weight is calculated from other operations
         weight_index = _detect_weight_index(layer)
         if weight_index is None:
-            _logger.warning('Failed to detect weight for layer {}'.format(layer.name))
+            _logger.warning('Failed to detect weight for layer %s', layer.name)
             return
         weight_op = layer.op.inputs[weight_index].op
         weight = weight_op.inputs[0]

@@ -115,10 +111,9 @@ class Compressor:

 class Quantizer(Compressor):
-    """Abstract base TensorFlow quantizer"""
-
-    def __init__(self, config_list):
-        super().__init__(config_list)
+    """
+    Abstract base TensorFlow quantizer
+    """

     def quantize_weight(self, weight, config, op, op_type, op_name):
         raise NotImplementedError("Quantizer must overload quantize_weight()")

@@ -126,7 +121,7 @@ class Quantizer(Compressor):
     def _instrument_layer(self, layer, config):
         weight_index = _detect_weight_index(layer)
         if weight_index is None:
-            _logger.warning('Failed to detect weight for layer {}'.format(layer.name))
+            _logger.warning('Failed to detect weight for layer %s', layer.name)
             return
         weight_op = layer.op.inputs[weight_index].op
         weight = weight_op.inputs[0]

@@ -138,7 +133,7 @@ def _detect_weight_index(layer):
     index = default_layers.op_weight_index.get(layer.type)
     if index is not None:
         return index
-    weight_indices = [ i for i, op in enumerate(layer.op.inputs) if op.name.endswith('Variable/read') ]
+    weight_indices = [i for i, op in enumerate(layer.op.inputs) if op.name.endswith('Variable/read')]
     if len(weight_indices) == 1:
         return weight_indices[0]
     return None
@@ -36,7 +36,7 @@ class LevelPruner(Pruner):

 class AGP_Pruner(Pruner):
     """An automated gradual pruning algorithm that prunes the smallest magnitude
     weights to achieve a preset level of network sparsity.

     Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the

@@ -102,5 +102,5 @@ class AGP_Pruner(Pruner):
     def update_epoch(self, epoch):
         if epoch > 0:
             self.now_epoch = epoch
-            for k in self.if_init_list.keys():
+            for k in self.if_init_list:
                 self.if_init_list[k] = True
@@ -2,7 +2,7 @@ import logging
 import torch
 from .compressor import Quantizer

-__all__ = [ 'NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer' ]
+__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer']

 logger = logging.getLogger(__name__)

@@ -62,7 +62,7 @@ class DoReFaQuantizer(Quantizer):
     def quantize_weight(self, weight, config, **kwargs):
         out = weight.tanh()
-        out = out /( 2 * out.abs().max()) + 0.5
+        out = out / (2 * out.abs().max()) + 0.5
         out = self.quantize(out, config['q_bits'])
         out = 2 * out -1
         return out
......
-import torch
 import logging
+import torch

 from . import default_layers

 _logger = logging.getLogger(__name__)

@@ -43,17 +43,14 @@ class Compressor:
         Users can optionally overload this method to do model-specific initialization.
         It is guaranteed that only one model will be bound to each compressor instance.
         """
-        pass

     def update_epoch(self, epoch):
         """if user want to update model every epoch, user can override this method
         """
-        pass

     def step(self):
         """if user want to update model every step, user can override this method
         """
-        pass

     def _instrument_layer(self, layer, config):
         raise NotImplementedError()

@@ -75,10 +72,9 @@ class Compressor:

 class Pruner(Compressor):
-    """Abstract base PyTorch pruner"""
-
-    def __init__(self, config_list):
-        super().__init__(config_list)
+    """
+    Abstract base PyTorch pruner
+    """

     def calc_mask(self, weight, config, op, op_type, op_name):
         """Pruners should overload this method to provide mask for weight tensors.

@@ -93,17 +89,17 @@ class Pruner(Compressor):
         # create a wrapper forward function to replace the original one
         assert layer._forward is None, 'Each model can only be compressed once'
         if not _check_weight(layer.module):
-            _logger.warning('Module {} does not have parameter "weight"'.format(layer.name))
+            _logger.warning('Module %s does not have parameter "weight"', layer.name)
             return
         layer._forward = layer.module.forward

-        def new_forward(*input):
+        def new_forward(*inputs):
             # apply mask to weight
             old_weight = layer.module.weight.data
             mask = self.calc_mask(old_weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
             layer.module.weight.data = old_weight.mul(mask)
             # calculate forward
-            ret = layer._forward(*input)
+            ret = layer._forward(*inputs)
             # recover original weight
             layer.module.weight.data = old_weight
             return ret

@@ -112,10 +108,9 @@ class Pruner(Compressor):

 class Quantizer(Compressor):
-    """Base quantizer for pytorch quantizer"""
-
-    def __init__(self, config_list):
-        super().__init__(config_list)
+    """
+    Base quantizer for pytorch quantizer
+    """

     def __call__(self, model):
         self.compress(model)

@@ -130,15 +125,15 @@ class Quantizer(Compressor):
     def _instrument_layer(self, layer, config):
         assert layer._forward is None, 'Each model can only be compressed once'
         if not _check_weight(layer.module):
-            _logger.warning('Module {} does not have parameter "weight"'.format(layer.name))
+            _logger.warning('Module %s does not have parameter "weight"', layer.name)
             return
         layer._forward = layer.module.forward

-        def new_forward(*input):
+        def new_forward(*inputs):
             weight = layer.module.weight.data
             new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
             layer.module.weight.data = new_weight
-            return layer._forward(*input)
+            return layer._forward(*inputs)

         layer.module.forward = new_forward
......
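Renaming `*input` to `*inputs` in both wrapper closures above clears pylint's redefined-builtin warning (W0622): a parameter named `input` shadows the builtin of the same name inside the function. A small sketch of the pattern:

    def forward_bad(*input):    # W0622: `input` shadows the builtin
        return len(input)

    def forward_good(*inputs):  # renamed star-arg, as in the diff above
        return len(inputs)

    print(forward_bad(1, 2), forward_good(1, 2))  # 2 2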
@@ -71,4 +71,4 @@ AdvisorModuleName = {
 AdvisorClassName = {
     'Hyperband': 'Hyperband',
     'BOHB': 'BOHB'
-}
\ No newline at end of file
+}
@@ -50,7 +50,7 @@ class CurvefittingAssessor(Assessor):
             self.higher_better = False
         else:
             self.higher_better = True
-            logger.warning('unrecognized optimize_mode', optimize_mode)
+            logger.warning('unrecognized optimize_mode %s', optimize_mode)
         # Start forecasting when historical data reaches start step
         self.start_step = start_step
         # Record the compared threshold

@@ -81,9 +81,9 @@ class CurvefittingAssessor(Assessor):
             else:
                 self.set_best_performance = True
                 self.completed_best_performance = self.trial_history[-1]
-                logger.info('Updated complted best performance, trial job id:', trial_job_id)
+                logger.info('Updated complted best performance, trial job id: %s', trial_job_id)
         else:
-            logger.info('No need to update, trial job id: ', trial_job_id)
+            logger.info('No need to update, trial job id: %s', trial_job_id)

     def assess_trial(self, trial_job_id, trial_history):
         """assess whether a trial should be early stop by curve fitting algorithm

@@ -105,7 +105,7 @@ class CurvefittingAssessor(Assessor):
         Exception
             unrecognize exception in curvefitting_assessor
         """
-        self.trial_job_id = trial_job_id
+        trial_job_id = trial_job_id
         self.trial_history = trial_history
         if not self.set_best_performance:
             return AssessResult.Good

@@ -122,7 +122,7 @@ class CurvefittingAssessor(Assessor):
             # Predict the final result
             curvemodel = CurveModel(self.target_pos)
             predict_y = curvemodel.predict(trial_history)
-            logger.info('Prediction done. Trial job id = ', trial_job_id, '. Predict value = ', predict_y)
+            logger.info('Prediction done. Trial job id = %s. Predict value = %s', trial_job_id, predict_y)
             if predict_y is None:
                 logger.info('wait for more information to predict precisely')
                 return AssessResult.Good

@@ -130,7 +130,10 @@ class CurvefittingAssessor(Assessor):
             end_time = datetime.datetime.now()
             if (end_time - start_time).seconds > 60:
-                logger.warning('Curve Fitting Assessor Runtime Exceeds 60s, Trial Id = ', self.trial_job_id, 'Trial History = ', self.trial_history)
+                logger.warning(
+                    'Curve Fitting Assessor Runtime Exceeds 60s, Trial Id = %s Trial History = %s',
+                    trial_job_id, self.trial_history
+                )

             if self.higher_better:
                 if predict_y > standard_performance:

@@ -142,4 +145,4 @@ class CurvefittingAssessor(Assessor):
                 return AssessResult.Bad
         except Exception as exception:
-            logger.exception('unrecognize exception in curvefitting_assessor', exception)
+            logger.exception('unrecognize exception in curvefitting_assessor %s', exception)
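The curve-fitting changes above fix latent runtime errors, not just style: a call like `logger.warning('unrecognized optimize_mode', optimize_mode)` passes an argument with no `%s` placeholder, so formatting fails when the record is emitted, and the logging module swallows the exception and prints a "--- Logging error ---" traceback to stderr instead of the message. A runnable demonstration:

    import logging

    logging.basicConfig()
    logger = logging.getLogger('demo')

    logger.warning('unrecognized optimize_mode', 'maximize')     # no placeholder: "--- Logging error ---" on stderr
    logger.warning('unrecognized optimize_mode %s', 'maximize')  # the intended message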
@@ -18,7 +18,7 @@
 import logging
 import numpy as np
 from scipy import optimize
-from .curvefunctions import *
+from .curvefunctions import *  # pylint: disable=wildcard-import,unused-wildcard-import

 # Number of curve functions we prepared, more details can be found in "curvefunctions.py"
 NUM_OF_FUNCTIONS = 12

@@ -33,7 +33,7 @@ LEAST_FITTED_FUNCTION = 4

 logger = logging.getLogger('curvefitting_Assessor')

-class CurveModel(object):
+class CurveModel:
     """Build a Curve Model to predict the performance

     Algorithm: https://github.com/Microsoft/nni/blob/master/src/sdk/pynni/nni/curvefitting_assessor/README.md

@@ -83,7 +83,7 @@ class CurveModel(object):
             # Ignore exceptions caused by numerical calculations
             pass
         except Exception as exception:
-            logger.critical("Exceptions in fit_theta:", exception)
+            logger.critical("Exceptions in fit_theta: %s", exception)

     def filter_curve(self):
         """filter the poor performing curve

@@ -113,7 +113,7 @@ class CurveModel(object):
             if y < median + epsilon and y > median - epsilon:
                 self.effective_model.append(model)
         self.effective_model_num = len(self.effective_model)
-        logger.info('List of effective model: ', self.effective_model)
+        logger.info('List of effective model: %s', self.effective_model)

     def predict_y(self, model, pos):
         """return the predict y of 'model' when epoch = pos

@@ -303,7 +303,7 @@ class CurveModel(object):
         """
         init_weight = np.ones((self.effective_model_num), dtype=np.float) / self.effective_model_num
         self.weight_samples = np.broadcast_to(init_weight, (NUM_OF_INSTANCE, self.effective_model_num))
-        for i in range(NUM_OF_SIMULATION_TIME):
+        for _ in range(NUM_OF_SIMULATION_TIME):
             # sample new value from Q(i, j)
             new_values = np.random.randn(NUM_OF_INSTANCE, self.effective_model_num) * STEP_SIZE + self.weight_samples
             new_values = self.normalize_weight(new_values)
......
@@ -15,6 +15,7 @@
 # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

-import numpy as np
 import unittest
......
@@ -40,7 +40,7 @@ def json2space(x, oldy=None, name=NodeType.ROOT):
             _type = x[NodeType.TYPE]
             name = name + '-' + _type
             if _type == 'choice':
-                if oldy != None:
+                if oldy is not None:
                     _index = oldy[NodeType.INDEX]
                     y += json2space(x[NodeType.VALUE][_index],
                                     oldy[NodeType.VALUE], name=name+'[%d]' % _index)

@@ -49,15 +49,13 @@ def json2space(x, oldy=None, name=NodeType.ROOT):
             y.append(name)
         else:
             for key in x.keys():
-                y += json2space(x[key], (oldy[key] if oldy !=
-                                         None else None), name+"[%s]" % str(key))
+                y += json2space(x[key], oldy[key] if oldy else None, name+"[%s]" % str(key))
     elif isinstance(x, list):
         for i, x_i in enumerate(x):
             if isinstance(x_i, dict):
                 if NodeType.NAME not in x_i.keys():
                     raise RuntimeError('\'_name\' key is not found in this nested search space.')
-                y += json2space(x_i, (oldy[i] if oldy !=
-                                      None else None), name+"[%d]" % i)
+                y += json2space(x_i, oldy[i] if oldy else None, name + "[%d]" % i)
     return y

 def json2parameter(x, is_rand, random_state, oldy=None, Rand=False, name=NodeType.ROOT):

@@ -74,36 +72,49 @@ def json2parameter(x, is_rand, random_state, oldy=None, Rand=False, name=NodeTyp
                     _index = random_state.randint(len(_value))
                     y = {
                         NodeType.INDEX: _index,
-                        NodeType.VALUE: json2parameter(x[NodeType.VALUE][_index],
-                                                       is_rand,
-                                                       random_state,
-                                                       None,
-                                                       Rand,
-                                                       name=name+"[%d]" % _index)
+                        NodeType.VALUE: json2parameter(
+                            x[NodeType.VALUE][_index],
+                            is_rand,
+                            random_state,
+                            None,
+                            Rand,
+                            name=name+"[%d]" % _index
+                        )
                     }
                 else:
-                    y = eval('parameter_expressions.' +
-                             _type)(*(_value + [random_state]))
+                    y = getattr(parameter_expressions, _type)(*(_value + [random_state]))
             else:
                 y = copy.deepcopy(oldy)
         else:
            y = dict()
            for key in x.keys():
-                y[key] = json2parameter(x[key], is_rand, random_state, oldy[key]
-                                        if oldy != None else None, Rand, name + "[%s]" % str(key))
+                y[key] = json2parameter(
+                    x[key],
+                    is_rand,
+                    random_state,
+                    oldy[key] if oldy else None,
+                    Rand,
+                    name + "[%s]" % str(key)
+                )
     elif isinstance(x, list):
         y = list()
         for i, x_i in enumerate(x):
             if isinstance(x_i, dict):
                 if NodeType.NAME not in x_i.keys():
                     raise RuntimeError('\'_name\' key is not found in this nested search space.')
-                y.append(json2parameter(x_i, is_rand, random_state, oldy[i]
-                                        if oldy != None else None, Rand, name + "[%d]" % i))
+                y.append(json2parameter(
+                    x_i,
+                    is_rand,
+                    random_state,
+                    oldy[i] if oldy else None,
+                    Rand,
+                    name + "[%d]" % i
+                ))
     else:
         y = copy.deepcopy(x)
     return y

-class Individual(object):
+class Individual:
     """
     Indicidual class to store the indv info.
     """
......
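The `oldy != None` to `oldy is not None` rewrites above follow pylint's singleton-comparison check (C0121): `None` is a singleton, so identity is the right test. Note that some rewritten call sites went a step further, from `oldy[key] if oldy != None else None` to `oldy[key] if oldy else None`; the two differ when `oldy` is an empty container, which is presumably acceptable here. A sketch of the distinction:

    oldy = {}
    print(oldy is not None)   # True: identity test, {} is a real object
    print(bool(oldy))         # False: truthiness, empty containers are falsy

    # So `oldy[key] if oldy else None` returns None for {} as well as for None,
    # while `... if oldy is not None else None` would try {}[key] and raise KeyError.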
@@ -151,16 +151,14 @@ class GPTuner(Tuner):
         """
         _completed_num = 0
         for trial_info in data:
-            logger.info("Importing data, current processing progress %s / %s" %
-                        (_completed_num, len(data)))
+            logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
             _completed_num += 1
             assert "parameter" in trial_info
             _params = trial_info["parameter"]
             assert "value" in trial_info
             _value = trial_info['value']
             if not _value:
-                logger.info(
-                    "Useless trial data, value is %s, skip this trial data." % _value)
+                logger.info("Useless trial data, value is %s, skip this trial data.", _value)
                 continue
             self.supplement_data_num += 1
             _parameter_id = '_'.join(
......