Unverified Commit 8c203f30, authored by SparkSnail, committed by GitHub

Merge pull request #211 from microsoft/master

merge master
parents 7c1ab114 483232c8
@@ -41,8 +41,8 @@ class _LoggerFileWrapper(TextIOBase):
     def write(self, s):
         if s != '\n':
-            time = datetime.now().strftime(_time_format)
-            self.file.write('[{}] PRINT '.format(time) + s + '\n')
+            cur_time = datetime.now().strftime(_time_format)
+            self.file.write('[{}] PRINT '.format(cur_time) + s + '\n')
             self.file.flush()
         return len(s)
...
@@ -92,5 +92,5 @@ class AGP_Pruner(Pruner):
     def update_epoch(self, epoch, sess):
         sess.run(self.assign_handler)
         sess.run(tf.assign(self.now_epoch, int(epoch)))
-        for k in self.if_init_list.keys():
+        for k in self.if_init_list:
             self.if_init_list[k] = True
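Note: this `.keys()` cleanup recurs throughout the PR. Iterating a dict directly yields its keys, so the extra call is pure noise (pylint's consider-iterating-dictionary check). A standalone illustration, with made-up dict contents:

    if_init_list = {'conv1': False, 'conv2': False}

    # `for k in d` is equivalent to `for k in d.keys()`, minus one attribute lookup.
    # Mutating the values (not the keys) during iteration is safe.
    for k in if_init_list:
        if_init_list[k] = True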
@@ -2,7 +2,7 @@ import logging
 import tensorflow as tf
 from .compressor import Quantizer
-__all__ = [ 'NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer' ]
+__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer']
 _logger = logging.getLogger(__name__)
@@ -12,7 +12,7 @@ class NaiveQuantizer(Quantizer):
     """
     def __init__(self, config_list):
         super().__init__(config_list)
-        self.layer_scale = { }
+        self.layer_scale = {}
     def quantize_weight(self, weight, config, op_name, **kwargs):
         new_scale = tf.reduce_max(tf.abs(weight)) / 127
...
-import tensorflow as tf
 import logging
+import tensorflow as tf
 from . import default_layers
 _logger = logging.getLogger(__name__)
@@ -51,17 +51,14 @@ class Compressor:
         Compressors can optionally overload this method to do model-specific initialization.
         It is guaranteed that only one model will be bound to each compressor instance.
         """
-        pass
     def update_epoch(self, epoch, sess):
         """If user want to update mask every epoch, user can override this method
         """
-        pass
     def step(self, sess):
         """If user want to update mask every step, user can override this method
         """
-        pass
     def _instrument_layer(self, layer, config):
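Note: the dropped `pass` statements are redundant, since a docstring is itself an expression statement and forms a complete method body (pylint's unnecessary-pass check). A minimal sketch:

    class Compressor:
        def update_epoch(self, epoch, sess):
            """Override in a subclass to update masks each epoch; the
            docstring alone is a syntactically complete body."""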
@@ -84,10 +81,9 @@ class Compressor:
 class Pruner(Compressor):
-    """Abstract base TensorFlow pruner"""
-    def __init__(self, config_list):
-        super().__init__(config_list)
+    """
+    Abstract base TensorFlow pruner
+    """
     def calc_mask(self, weight, config, op, op_type, op_name):
         """Pruners should overload this method to provide mask for weight tensors.
@@ -105,7 +101,7 @@ class Pruner(Compressor):
         # not sure what will happen if the weight is calculated from other operations
         weight_index = _detect_weight_index(layer)
         if weight_index is None:
-            _logger.warning('Failed to detect weight for layer {}'.format(layer.name))
+            _logger.warning('Failed to detect weight for layer %s', layer.name)
             return
         weight_op = layer.op.inputs[weight_index].op
         weight = weight_op.inputs[0]
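Note: the `.format()`-to-`%s` rewrites in this and the following hunks adopt logging's lazy formatting: arguments are interpolated only if the record is actually emitted, and linters can verify the placeholders against the arguments. A minimal comparison (the layer name is made up):

    import logging

    logger = logging.getLogger(__name__)
    name = 'conv1'

    # Eager: the string is always built, even if WARNING is filtered out.
    logger.warning('Failed to detect weight for layer {}'.format(name))
    # Lazy: %-interpolation is deferred until the record is emitted.
    logger.warning('Failed to detect weight for layer %s', name)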
@@ -115,10 +111,9 @@ class Pruner(Compressor):
 class Quantizer(Compressor):
-    """Abstract base TensorFlow quantizer"""
-    def __init__(self, config_list):
-        super().__init__(config_list)
+    """
+    Abstract base TensorFlow quantizer
+    """
     def quantize_weight(self, weight, config, op, op_type, op_name):
         raise NotImplementedError("Quantizer must overload quantize_weight()")
@@ -126,7 +121,7 @@ class Quantizer(Compressor):
     def _instrument_layer(self, layer, config):
         weight_index = _detect_weight_index(layer)
         if weight_index is None:
-            _logger.warning('Failed to detect weight for layer {}'.format(layer.name))
+            _logger.warning('Failed to detect weight for layer %s', layer.name)
             return
         weight_op = layer.op.inputs[weight_index].op
         weight = weight_op.inputs[0]
@@ -138,7 +133,7 @@ def _detect_weight_index(layer):
     index = default_layers.op_weight_index.get(layer.type)
     if index is not None:
         return index
-    weight_indices = [ i for i, op in enumerate(layer.op.inputs) if op.name.endswith('Variable/read') ]
+    weight_indices = [i for i, op in enumerate(layer.op.inputs) if op.name.endswith('Variable/read')]
    if len(weight_indices) == 1:
        return weight_indices[0]
    return None
@@ -102,5 +102,5 @@ class AGP_Pruner(Pruner):
     def update_epoch(self, epoch):
         if epoch > 0:
             self.now_epoch = epoch
-        for k in self.if_init_list.keys():
+        for k in self.if_init_list:
             self.if_init_list[k] = True
...
@@ -2,7 +2,7 @@ import logging
 import torch
 from .compressor import Quantizer
-__all__ = [ 'NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer' ]
+__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer']
 logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ class DoReFaQuantizer(Quantizer):
     def quantize_weight(self, weight, config, **kwargs):
         out = weight.tanh()
-        out = out /( 2 * out.abs().max()) + 0.5
+        out = out / (2 * out.abs().max()) + 0.5
         out = self.quantize(out, config['q_bits'])
         out = 2 * out -1
         return out
...
-import torch
 import logging
+import torch
 from . import default_layers
 _logger = logging.getLogger(__name__)
@@ -43,17 +43,14 @@ class Compressor:
         Users can optionally overload this method to do model-specific initialization.
         It is guaranteed that only one model will be bound to each compressor instance.
         """
-        pass
     def update_epoch(self, epoch):
         """if user want to update model every epoch, user can override this method
         """
-        pass
     def step(self):
         """if user want to update model every step, user can override this method
         """
-        pass
     def _instrument_layer(self, layer, config):
         raise NotImplementedError()
@@ -61,10 +58,8 @@ class Compressor:
     def _select_config(self, layer):
         ret = None
         for config in self._config_list:
-            op_types = config.get('op_types')
-            if op_types == 'default':
-                op_types = default_layers.weighted_modules
-            if op_types and layer.type not in op_types:
+            config['op_types'] = self._expand_config_op_types(config)
+            if layer.type not in config['op_types']:
                 continue
             if config.get('op_names') and layer.name not in config['op_names']:
                 continue
@@ -73,12 +68,21 @@ class Compressor:
             return None
         return ret
+    def _expand_config_op_types(self, config):
+        if config is None:
+            return []
+        expanded_op_types = []
+        for op_type in config.get('op_types', []):
+            if op_type == 'default':
+                expanded_op_types.extend(default_layers.weighted_modules)
+            else:
+                expanded_op_types.append(op_type)
+        return expanded_op_types
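Note: the new helper expands the 'default' sentinel into the concrete list of weighted module types before matching, which is what lets `_select_config` above do a plain membership test. A sketch of the expansion in isolation (the module names are illustrative stand-ins for `default_layers.weighted_modules`):

    weighted_modules = ['Conv2d', 'Linear']  # illustrative subset

    def expand_op_types(config):
        expanded = []
        for op_type in config.get('op_types', []):
            if op_type == 'default':
                expanded.extend(weighted_modules)
            else:
                expanded.append(op_type)
        return expanded

    print(expand_op_types({'op_types': ['default', 'BatchNorm2d']}))
    # ['Conv2d', 'Linear', 'BatchNorm2d']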
 class Pruner(Compressor):
-    """Abstract base PyTorch pruner"""
-    def __init__(self, config_list):
-        super().__init__(config_list)
+    """
+    Abstract base PyTorch pruner
+    """
     def calc_mask(self, weight, config, op, op_type, op_name):
         """Pruners should overload this method to provide mask for weight tensors.
@@ -93,17 +97,17 @@ class Pruner(Compressor):
         # create a wrapper forward function to replace the original one
         assert layer._forward is None, 'Each model can only be compressed once'
         if not _check_weight(layer.module):
-            _logger.warning('Module {} does not have parameter "weight"'.format(layer.name))
+            _logger.warning('Module %s does not have parameter "weight"', layer.name)
             return
         layer._forward = layer.module.forward
-        def new_forward(*input):
+        def new_forward(*inputs):
             # apply mask to weight
             old_weight = layer.module.weight.data
             mask = self.calc_mask(old_weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
             layer.module.weight.data = old_weight.mul(mask)
             # calculate forward
-            ret = layer._forward(*input)
+            ret = layer._forward(*inputs)
             # recover original weight
             layer.module.weight.data = old_weight
             return ret
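Note: besides no longer shadowing the built-in `input`, this hunk shows the wrap-mask-restore pattern the pruner relies on. A self-contained sketch of the same idea on a bare module, with a random illustrative mask in place of `calc_mask`:

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 2)
    original_forward = layer.forward

    def masked_forward(*inputs):
        # mask out roughly half the weights for the duration of this call
        old_weight = layer.weight.data
        mask = (torch.rand_like(old_weight) > 0.5).float()  # NNI computes this via calc_mask
        layer.weight.data = old_weight.mul(mask)
        out = original_forward(*inputs)
        layer.weight.data = old_weight  # restore so the optimizer sees the real weights
        return out

    layer.forward = masked_forward
    print(layer(torch.randn(1, 4)))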
@@ -112,14 +116,9 @@ class Pruner(Compressor):
 class Quantizer(Compressor):
-    """Base quantizer for pytorch quantizer"""
-    def __init__(self, config_list):
-        super().__init__(config_list)
-    def __call__(self, model):
-        self.compress(model)
-        return model
+    """
+    Base quantizer for pytorch quantizer
+    """
     def quantize_weight(self, weight, config, op, op_type, op_name):
         """user should know where dequantize goes and implement it in quantize method
@@ -130,15 +129,15 @@ class Quantizer(Compressor):
     def _instrument_layer(self, layer, config):
         assert layer._forward is None, 'Each model can only be compressed once'
         if not _check_weight(layer.module):
-            _logger.warning('Module {} does not have parameter "weight"'.format(layer.name))
+            _logger.warning('Module %s does not have parameter "weight"', layer.name)
             return
         layer._forward = layer.module.forward
-        def new_forward(*input):
+        def new_forward(*inputs):
             weight = layer.module.weight.data
             new_weight = self.quantize_weight(weight, config, op=layer.module, op_type=layer.type, op_name=layer.name)
             layer.module.weight.data = new_weight
-            return layer._forward(*input)
+            return layer._forward(*inputs)
         layer.module.forward = new_forward
...
@@ -50,7 +50,7 @@ class CurvefittingAssessor(Assessor):
             self.higher_better = False
         else:
             self.higher_better = True
-            logger.warning('unrecognized optimize_mode', optimize_mode)
+            logger.warning('unrecognized optimize_mode %s', optimize_mode)
         # Start forecasting when historical data reaches start step
         self.start_step = start_step
         # Record the compared threshold
@@ -81,9 +81,9 @@ class CurvefittingAssessor(Assessor):
             else:
                 self.set_best_performance = True
                 self.completed_best_performance = self.trial_history[-1]
-                logger.info('Updated complted best performance, trial job id:', trial_job_id)
+                logger.info('Updated complted best performance, trial job id: %s', trial_job_id)
         else:
-            logger.info('No need to update, trial job id: ', trial_job_id)
+            logger.info('No need to update, trial job id: %s', trial_job_id)
     def assess_trial(self, trial_job_id, trial_history):
         """assess whether a trial should be early stop by curve fitting algorithm
@@ -105,7 +105,7 @@ class CurvefittingAssessor(Assessor):
         Exception
             unrecognize exception in curvefitting_assessor
         """
-        self.trial_job_id = trial_job_id
+        trial_job_id = trial_job_id
         self.trial_history = trial_history
         if not self.set_best_performance:
             return AssessResult.Good
@@ -122,7 +122,7 @@ class CurvefittingAssessor(Assessor):
             # Predict the final result
             curvemodel = CurveModel(self.target_pos)
             predict_y = curvemodel.predict(trial_history)
-            logger.info('Prediction done. Trial job id = ', trial_job_id, '. Predict value = ', predict_y)
+            logger.info('Prediction done. Trial job id = %s. Predict value = %s', trial_job_id, predict_y)
             if predict_y is None:
                 logger.info('wait for more information to predict precisely')
                 return AssessResult.Good
@@ -130,7 +130,10 @@ class CurvefittingAssessor(Assessor):
             end_time = datetime.datetime.now()
             if (end_time - start_time).seconds > 60:
-                logger.warning('Curve Fitting Assessor Runtime Exceeds 60s, Trial Id = ', self.trial_job_id, 'Trial History = ', self.trial_history)
+                logger.warning(
+                    'Curve Fitting Assessor Runtime Exceeds 60s, Trial Id = %s Trial History = %s',
+                    trial_job_id, self.trial_history
+                )
             if self.higher_better:
                 if predict_y > standard_performance:
@@ -142,4 +145,4 @@ class CurvefittingAssessor(Assessor):
                 return AssessResult.Bad
         except Exception as exception:
-            logger.exception('unrecognize exception in curvefitting_assessor', exception)
+            logger.exception('unrecognize exception in curvefitting_assessor %s', exception)
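Note: the assessor fixes in this span are behavioral, not cosmetic. Stdlib logging passes extra positional arguments through %-formatting, so a message with no placeholder makes the handler raise internally and the value is never printed. A quick demonstration:

    import logging

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger('demo')

    log.info('optimize_mode is %s', 'maximize')   # OK: placeholder matches the argument
    log.info('optimize_mode is', 'maximize')      # no placeholder: logging prints
    # "--- Logging error ---" with "TypeError: not all arguments converted
    # during string formatting" instead of the intended message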
@@ -18,7 +18,7 @@
 import logging
 import numpy as np
 from scipy import optimize
-from .curvefunctions import *
+from .curvefunctions import *  # pylint: disable=wildcard-import,unused-wildcard-import
 # Number of curve functions we prepared, more details can be found in "curvefunctions.py"
 NUM_OF_FUNCTIONS = 12
@@ -33,7 +33,7 @@ LEAST_FITTED_FUNCTION = 4
 logger = logging.getLogger('curvefitting_Assessor')
-class CurveModel(object):
+class CurveModel:
     """Build a Curve Model to predict the performance
     Algorithm: https://github.com/Microsoft/nni/blob/master/src/sdk/pynni/nni/curvefitting_assessor/README.md
@@ -83,7 +83,7 @@ class CurveModel(object):
                 # Ignore exceptions caused by numerical calculations
                 pass
             except Exception as exception:
-                logger.critical("Exceptions in fit_theta:", exception)
+                logger.critical("Exceptions in fit_theta: %s", exception)
     def filter_curve(self):
         """filter the poor performing curve
@@ -113,7 +113,7 @@ class CurveModel(object):
             if y < median + epsilon and y > median - epsilon:
                 self.effective_model.append(model)
         self.effective_model_num = len(self.effective_model)
-        logger.info('List of effective model: ', self.effective_model)
+        logger.info('List of effective model: %s', self.effective_model)
     def predict_y(self, model, pos):
         """return the predict y of 'model' when epoch = pos
@@ -303,7 +303,7 @@ class CurveModel(object):
         """
         init_weight = np.ones((self.effective_model_num), dtype=np.float) / self.effective_model_num
         self.weight_samples = np.broadcast_to(init_weight, (NUM_OF_INSTANCE, self.effective_model_num))
-        for i in range(NUM_OF_SIMULATION_TIME):
+        for _ in range(NUM_OF_SIMULATION_TIME):
             # sample new value from Q(i, j)
             new_values = np.random.randn(NUM_OF_INSTANCE, self.effective_model_num) * STEP_SIZE + self.weight_samples
             new_values = self.normalize_weight(new_values)
...
@@ -15,6 +15,7 @@
 # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 import numpy as np
 import unittest
...
@@ -40,7 +40,7 @@ def json2space(x, oldy=None, name=NodeType.ROOT):
         _type = x[NodeType.TYPE]
         name = name + '-' + _type
         if _type == 'choice':
-            if oldy != None:
+            if oldy is not None:
                 _index = oldy[NodeType.INDEX]
                 y += json2space(x[NodeType.VALUE][_index],
                                 oldy[NodeType.VALUE], name=name+'[%d]' % _index)
@@ -49,15 +49,13 @@ def json2space(x, oldy=None, name=NodeType.ROOT):
         y.append(name)
     else:
         for key in x.keys():
-            y += json2space(x[key], (oldy[key] if oldy !=
-                            None else None), name+"[%s]" % str(key))
+            y += json2space(x[key], oldy[key] if oldy else None, name+"[%s]" % str(key))
 elif isinstance(x, list):
     for i, x_i in enumerate(x):
         if isinstance(x_i, dict):
             if NodeType.NAME not in x_i.keys():
                 raise RuntimeError('\'_name\' key is not found in this nested search space.')
-        y += json2space(x_i, (oldy[i] if oldy !=
-                        None else None), name+"[%d]" % i)
+        y += json2space(x_i, oldy[i] if oldy else None, name + "[%d]" % i)
 return y
 def json2parameter(x, is_rand, random_state, oldy=None, Rand=False, name=NodeType.ROOT):
@@ -74,36 +72,49 @@ def json2parameter(x, is_rand, random_state, oldy=None, Rand=False, name=NodeType.ROOT):
                 _index = random_state.randint(len(_value))
                 y = {
                     NodeType.INDEX: _index,
-                    NodeType.VALUE: json2parameter(x[NodeType.VALUE][_index],
-                                                   is_rand,
-                                                   random_state,
-                                                   None,
-                                                   Rand,
-                                                   name=name+"[%d]" % _index)
+                    NodeType.VALUE: json2parameter(
+                        x[NodeType.VALUE][_index],
+                        is_rand,
+                        random_state,
+                        None,
+                        Rand,
+                        name=name+"[%d]" % _index
+                    )
                 }
             else:
-                y = eval('parameter_expressions.' +
-                         _type)(*(_value + [random_state]))
+                y = getattr(parameter_expressions, _type)(*(_value + [random_state]))
         else:
             y = copy.deepcopy(oldy)
     else:
         y = dict()
         for key in x.keys():
-            y[key] = json2parameter(x[key], is_rand, random_state, oldy[key]
-                                    if oldy != None else None, Rand, name + "[%s]" % str(key))
+            y[key] = json2parameter(
+                x[key],
+                is_rand,
+                random_state,
+                oldy[key] if oldy else None,
+                Rand,
+                name + "[%s]" % str(key)
+            )
 elif isinstance(x, list):
     y = list()
     for i, x_i in enumerate(x):
         if isinstance(x_i, dict):
             if NodeType.NAME not in x_i.keys():
                 raise RuntimeError('\'_name\' key is not found in this nested search space.')
-        y.append(json2parameter(x_i, is_rand, random_state, oldy[i]
-                                if oldy != None else None, Rand, name + "[%d]" % i))
+        y.append(json2parameter(
+            x_i,
+            is_rand,
+            random_state,
+            oldy[i] if oldy else None,
+            Rand,
+            name + "[%d]" % i
+        ))
 else:
     y = copy.deepcopy(x)
 return y
-class Individual(object):
+class Individual:
     """
     Indicidual class to store the indv info.
     """
...
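Note: both this tuner and hyperband below replace `eval('parameter_expressions.' + _type)` with `getattr`. The lookup is the same, but nothing is parsed and executed as source code, so a malformed `_type` fails cleanly with AttributeError instead of running whatever was injected. A standalone sketch, using the `random` module as a stand-in for `nni.parameter_expressions`:

    import random as parameter_expressions  # stand-in module for nni.parameter_expressions

    _type = 'uniform'
    # eval compiles and executes arbitrary source text built from _type.
    chosen = eval('parameter_expressions.' + _type)(0, 10)
    # getattr performs the same attribute lookup without executing code;
    # an unknown _type raises AttributeError.
    chosen = getattr(parameter_expressions, _type)(0, 10)
    print(chosen)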
@@ -151,16 +151,14 @@ class GPTuner(Tuner):
         """
         _completed_num = 0
         for trial_info in data:
-            logger.info("Importing data, current processing progress %s / %s" %
-                        (_completed_num, len(data)))
+            logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
             _completed_num += 1
             assert "parameter" in trial_info
             _params = trial_info["parameter"]
             assert "value" in trial_info
             _value = trial_info['value']
             if not _value:
-                logger.info(
-                    "Useless trial data, value is %s, skip this trial data." % _value)
+                logger.info("Useless trial data, value is %s, skip this trial data.", _value)
                 continue
             self.supplement_data_num += 1
             _parameter_id = '_'.join(
...
@@ -139,7 +139,7 @@ class TargetSpace():
         except AssertionError:
             raise ValueError(
                 "Size of array ({}) is different than the ".format(len(x)) +
-                "expected number of parameters ({}).".format(self.dim())
+                "expected number of parameters ({}).".format(self.dim)
             )
         params = {}
...
@@ -37,8 +37,8 @@ def _match_val_type(vals, bounds):
         _type = bound['_type']
         if _type == "choice":
             # Find the closest integer in the array, vals_bounds
-            vals_new.append(
-                min(bound['_value'], key=lambda x: abs(x - vals[i])))
+            # pylint: disable=cell-var-from-loop
+            vals_new.append(min(bound['_value'], key=lambda x: abs(x - vals[i])))
         elif _type in ['quniform', 'randint']:
             vals_new.append(np.around(vals[i]))
         else:
...
@@ -23,8 +23,8 @@ gridsearch_tuner.py including:
 '''
 import copy
-import numpy as np
 import logging
+import numpy as np
 import nni
 from nni.tuner import Tuner
@@ -44,7 +44,8 @@ class GridSearchTuner(Tuner):
     Type 'choice' will select one of the options. Note that it can also be nested.
     Type 'quniform' will receive three values [low, high, q], where [low, high] specifies a range and 'q' specifies the interval
-    It will be sampled in a way that the first sampled value is 'low', and each of the following values is 'interval' larger than the value in front of it.
+    It will be sampled in a way that the first sampled value is 'low',
+    and each of the following values is 'interval' larger than the value in front of it.
     Type 'randint' gives all possible intergers in range[low, high). Note that 'high' is not included.
     '''
@@ -132,7 +133,7 @@ class GridSearchTuner(Tuner):
     def generate_parameters(self, parameter_id, **kwargs):
         self.count += 1
-        while (self.count <= len(self.expanded_search_space)-1):
+        while self.count <= len(self.expanded_search_space) - 1:
             _params_tuple = convert_dict2tuple(self.expanded_search_space[self.count])
             if _params_tuple in self.supplement_data:
                 self.count += 1
@@ -153,14 +154,14 @@ class GridSearchTuner(Tuner):
         """
         _completed_num = 0
         for trial_info in data:
-            logger.info("Importing data, current processing progress %s / %s" %(_completed_num, len(data)))
+            logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
             _completed_num += 1
             assert "parameter" in trial_info
             _params = trial_info["parameter"]
             assert "value" in trial_info
             _value = trial_info['value']
             if not _value:
-                logger.info("Useless trial data, value is %s, skip this trial data." %_value)
+                logger.info("Useless trial data, value is %s, skip this trial data.", _value)
                 continue
             _params_tuple = convert_dict2tuple(_params)
             self.supplement_data[_params_tuple] = True
...
@@ -32,7 +32,7 @@ from nni.common import multi_phase_enabled
 from nni.msg_dispatcher_base import MsgDispatcherBase
 from nni.protocol import CommandType, send
 from nni.utils import NodeType, OptimizeMode, MetricType, extract_scalar_reward
-import nni.parameter_expressions as parameter_expressions
+from nni import parameter_expressions
 _logger = logging.getLogger(__name__)
@@ -49,7 +49,7 @@ def create_parameter_id():
     int
         parameter id
     """
-    global _next_parameter_id  # pylint: disable=global-statement
+    global _next_parameter_id
     _next_parameter_id += 1
     return _next_parameter_id - 1
@@ -102,8 +102,7 @@ def json2parameter(ss_spec, random_state):
             _index = random_state.randint(len(_value))
             chosen_params = json2parameter(ss_spec[NodeType.VALUE][_index], random_state)
         else:
-            chosen_params = eval('parameter_expressions.' +  # pylint: disable=eval-used
-                                 _type)(*(_value + [random_state]))
+            chosen_params = getattr(parameter_expressions, _type)(*(_value + [random_state]))
     else:
         chosen_params = dict()
         for key in ss_spec.keys():
@@ -140,8 +139,8 @@ class Bracket():
         self.bracket_id = s
         self.s_max = s_max
         self.eta = eta
-        self.n = math.ceil((s_max + 1) * (eta ** s) / (s + 1) - _epsilon)  # pylint: disable=invalid-name
-        self.r = R / eta ** s  # pylint: disable=invalid-name
+        self.n = math.ceil((s_max + 1) * (eta ** s) / (s + 1) - _epsilon)
+        self.r = R / eta ** s
         self.i = 0
         self.hyper_configs = []         # [ {id: params}, {}, ... ]
         self.configs_perf = []          # [ {id: [seq, acc]}, {}, ... ]
@@ -197,7 +196,7 @@ class Bracket():
         i: int
             the ith round
         """
-        global _KEY  # pylint: disable=global-statement
+        global _KEY
         self.num_finished_configs[i] += 1
         _logger.debug('bracket id: %d, round: %d %d, finished: %d, all: %d', self.bracket_id, self.i, i,
                       self.num_finished_configs[i], self.num_configs_to_run[i])
@@ -226,7 +225,7 @@ class Bracket():
             return [[key, value] for key, value in hyper_configs.items()]
         return None
-    def get_hyperparameter_configurations(self, num, r, searchspace_json, random_state):  # pylint: disable=invalid-name
+    def get_hyperparameter_configurations(self, num, r, searchspace_json, random_state):
         """Randomly generate num hyperparameter configurations from search space
         Parameters
@@ -239,7 +238,7 @@ class Bracket():
         list
             a list of hyperparameter configurations. Format: [[key1, value1], [key2, value2], ...]
         """
-        global _KEY  # pylint: disable=global-statement
+        global _KEY
         assert self.i == 0
         hyperparameter_configs = dict()
         for _ in range(num):
@@ -285,7 +284,7 @@ class Hyperband(MsgDispatcherBase):
     def __init__(self, R=60, eta=3, optimize_mode='maximize'):
         """B = (s_max + 1)R"""
         super(Hyperband, self).__init__()
-        self.R = R  # pylint: disable=invalid-name
+        self.R = R
         self.eta = eta
         self.brackets = dict()          # dict of Bracket
         self.generated_hyper_configs = []        # all the configs waiting for run
...
@@ -51,13 +51,13 @@ def json2space(in_x, name=NodeType.ROOT):
             name = name + '-' + _type
             _value = json2space(in_x[NodeType.VALUE], name=name)
             if _type == 'choice':
-                out_y = eval('hp.hp.choice')(name, _value)
+                out_y = hp.hp.choice(name, _value)
             elif _type == 'randint':
                 out_y = hp.hp.randint(name, _value[1] - _value[0])
             else:
                 if _type in ['loguniform', 'qloguniform']:
                     _value[:2] = np.log(_value[:2])
-                out_y = eval('hp.hp.' + _type)(name, *_value)
+                out_y = getattr(hp.hp, _type)(name, *_value)
         else:
             out_y = dict()
             for key in in_x.keys():
@@ -191,6 +191,7 @@ def _add_index(in_x, parameter):
             return {NodeType.INDEX: pos, NodeType.VALUE: item}
     else:
         return parameter
+    return None  # note: this is not written by original author, feel free to modify if you think it's incorrect
 class HyperoptTuner(Tuner):
@@ -409,8 +410,8 @@ class HyperoptTuner(Tuner):
         misc_by_id = {m['tid']: m for m in miscs}
         for m in miscs:
-            m['idxs'] = dict([(key, []) for key in idxs])
-            m['vals'] = dict([(key, []) for key in idxs])
+            m['idxs'] = {key: [] for key in idxs}
+            m['vals'] = {key: [] for key in idxs}
             for key in idxs:
                 assert len(idxs[key]) == len(vals[key])
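Note: the `dict([...])`-to-comprehension change is clearer and skips building an intermediate list of tuples. A quick standalone comparison with made-up data:

    idxs = {'x': [1, 2], 'y': [3]}

    # old: build a list of (key, value) pairs, then hand it to dict()
    old_style = dict([(key, []) for key in idxs])
    # new: a dict comprehension constructs the mapping directly
    new_style = {key: [] for key in idxs}

    assert old_style == new_style == {'x': [], 'y': []}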
@@ -433,7 +434,7 @@ class HyperoptTuner(Tuner):
         total_params : dict
             parameter suggestion
         """
-        if self.parallel and len(self.total_data)>20 and len(self.running_data) and self.optimal_y is not None:
+        if self.parallel and len(self.total_data) > 20 and self.running_data and self.optimal_y is not None:
             self.CL_rval = copy.deepcopy(self.rval)
             if self.constant_liar_type == 'mean':
                 _constant_liar_y = self.optimal_y[0] / self.optimal_y[1]
@@ -481,8 +482,8 @@ class HyperoptTuner(Tuner):
         """
         _completed_num = 0
         for trial_info in data:
-            logger.info("Importing data, current processing progress %s / %s" %
-                        (_completed_num, len(data)))
+            logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
             _completed_num += 1
             if self.algorithm_name == 'random_search':
                 return
@@ -491,9 +491,7 @@ class HyperoptTuner(Tuner):
             assert "value" in trial_info
             _value = trial_info['value']
             if not _value:
-                logger.info(
-                    "Useless trial data, value is %s, skip this trial data." %
-                    _value)
+                logger.info("Useless trial data, value is %s, skip this trial data.", _value)
                 continue
             self.supplement_data_num += 1
             _parameter_id = '_'.join(
...
@@ -42,7 +42,7 @@ class MedianstopAssessor(Assessor):
             self.high_better = False
         else:
             self.high_better = True
-            logger.warning('unrecognized optimize_mode', optimize_mode)
+            logger.warning('unrecognized optimize_mode %s', optimize_mode)
     def _update_data(self, trial_job_id, trial_history):
         """update data
@@ -121,10 +121,10 @@ class MedianstopAssessor(Assessor):
             best_history = min(trial_history)
         avg_array = []
-        for id in self.completed_avg_history:
-            if len(self.completed_avg_history[id]) >= curr_step:
-                avg_array.append(self.completed_avg_history[id][curr_step - 1])
-        if len(avg_array) > 0:
+        for id_ in self.completed_avg_history:
+            if len(self.completed_avg_history[id_]) >= curr_step:
+                avg_array.append(self.completed_avg_history[id_][curr_step - 1])
+        if avg_array:
             avg_array.sort()
             if self.high_better:
                 median = avg_array[(len(avg_array)-1) // 2]
...
@@ -22,7 +22,6 @@ import random
 from .medianstop_assessor import MedianstopAssessor
 from nni.assessor import AssessResult
-
 logger = logging.getLogger('nni.contrib.medianstop_assessor')
 logger.debug('START')
...