Unverified Commit 8c203f30 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #211 from microsoft/master

merge master
parents 7c1ab114 483232c8
...@@ -27,7 +27,6 @@ _logger = logging.getLogger(__name__) ...@@ -27,7 +27,6 @@ _logger = logging.getLogger(__name__)
class Tuner(Recoverable): class Tuner(Recoverable):
# pylint: disable=no-self-use,unused-argument
def generate_parameters(self, parameter_id, **kwargs): def generate_parameters(self, parameter_id, **kwargs):
"""Returns a set of trial (hyper-)parameters, as a serializable object. """Returns a set of trial (hyper-)parameters, as a serializable object.
...@@ -47,7 +46,7 @@ class Tuner(Recoverable): ...@@ -47,7 +46,7 @@ class Tuner(Recoverable):
result = [] result = []
for parameter_id in parameter_id_list: for parameter_id in parameter_id_list:
try: try:
_logger.debug("generating param for {}".format(parameter_id)) _logger.debug("generating param for %s", parameter_id)
res = self.generate_parameters(parameter_id, **kwargs) res = self.generate_parameters(parameter_id, **kwargs)
except nni.NoMoreTrialError: except nni.NoMoreTrialError:
return result return result
...@@ -71,6 +70,8 @@ class Tuner(Recoverable): ...@@ -71,6 +70,8 @@ class Tuner(Recoverable):
By default `receive_trial_result()` will only receive results of algorithm-generated hyper-parameters. By default `receive_trial_result()` will only receive results of algorithm-generated hyper-parameters.
If tuners want to receive those of customized parameters as well, they can call this function in `__init__()`. If tuners want to receive those of customized parameters as well, they can call this function in `__init__()`.
""" """
# pylint: disable=attribute-defined-outside-init
# FIXME: because tuner is designed as interface, this API should not be here
self._accept_customized = accept self._accept_customized = accept
def trial_end(self, parameter_id, success, **kwargs): def trial_end(self, parameter_id, success, **kwargs):
...@@ -78,7 +79,6 @@ class Tuner(Recoverable): ...@@ -78,7 +79,6 @@ class Tuner(Recoverable):
parameter_id: int parameter_id: int
success: True if the trial successfully completed; False if failed or terminated success: True if the trial successfully completed; False if failed or terminated
""" """
pass
def update_search_space(self, search_space): def update_search_space(self, search_space):
"""Update the search space of tuner. Must override. """Update the search space of tuner. Must override.
...@@ -91,20 +91,19 @@ class Tuner(Recoverable): ...@@ -91,20 +91,19 @@ class Tuner(Recoverable):
path: checkpoint directory for tuner path: checkpoint directory for tuner
""" """
checkpoin_path = self.get_checkpoint_path() checkpoin_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by tuner, checkpoint path: %s' % checkpoin_path) _logger.info('Load checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path)
def save_checkpoint(self): def save_checkpoint(self):
"""Save the checkpoint of tuner. """Save the checkpoint of tuner.
path: checkpoint directory for tuner path: checkpoint directory for tuner
""" """
checkpoin_path = self.get_checkpoint_path() checkpoin_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by tuner, checkpoint path: %s' % checkpoin_path) _logger.info('Save checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path)
def import_data(self, data): def import_data(self, data):
"""Import additional data for tuning """Import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value' data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
""" """
pass
def _on_exit(self): def _on_exit(self):
pass pass
......
...@@ -84,12 +84,13 @@ def extract_scalar_reward(value, scalar_key='default'): ...@@ -84,12 +84,13 @@ def extract_scalar_reward(value, scalar_key='default'):
Incorrect final result: the final result should be float/int, Incorrect final result: the final result should be float/int,
or a dict which has a key named "default" whose value is float/int. or a dict which has a key named "default" whose value is float/int.
""" """
if isinstance(value, float) or isinstance(value, int): if isinstance(value, (float, int)):
reward = value reward = value
elif isinstance(value, dict) and scalar_key in value and isinstance(value[scalar_key], (float, int)): elif isinstance(value, dict) and scalar_key in value and isinstance(value[scalar_key], (float, int)):
reward = value[scalar_key] reward = value[scalar_key]
else: else:
raise RuntimeError('Incorrect final result: the final result should be float/int, or a dict which has a key named "default" whose value is float/int.') raise RuntimeError('Incorrect final result: the final result should be float/int, ' \
'or a dict which has a key named "default" whose value is float/int.')
return reward return reward
...@@ -101,7 +102,6 @@ def convert_dict2tuple(value): ...@@ -101,7 +102,6 @@ def convert_dict2tuple(value):
for _keys in value: for _keys in value:
value[_keys] = convert_dict2tuple(value[_keys]) value[_keys] = convert_dict2tuple(value[_keys])
return tuple(sorted(value.items())) return tuple(sorted(value.items()))
else:
return value return value
......
# This pylintrc file is a little more strick than the one in root of code directory
# SDK source MUST pass lint rules in top level directory, and SHOULD pass rules here
[SETTINGS]
max-line-length=140
disable =
missing-docstring,
invalid-name, # C0103
no-member, # E1101: sometimes pylint cannot detect members correctly due to a bug
c-extension-no-member, # I1101
no-self-use, # R0201: many functions in this SDK are designed for override
duplicate-code, # R0801
too-many-instance-attributes, # R0902
too-few-public-methods, # R0903
too-many-public-methods, # R0904
too-many-return-statements, # R0911
too-many-branches, # R0912
too-many-arguments, # R0913
too-many-locals, # R0914
too-many-statements, # R0915
too-many-nested-blocks, # R1702
no-else-return, # R1705
chained-comparison, # R1716
no-else-raise, # R1720
protected-access, # W0212: underscore variables may be protected by whole SDK instead of single module
arguments-differ, # W0221: pylint cannot handle *args and **kwargs
super-init-not-called, # W0231: some interface classes do not expect users to call init
useless-super-delegation, # W0235: derived init may have different docstring
global-statement, # W0603: globals are useful to hide SDK internal states from user
unused-argument, # W0613: many functions in this SDK are designed for override
broad-except, # W0703: the SDK commonly catch exceptions to report error
fixme, # W0511
ignore-patterns=test.*.py
# List of members which are set dynamically and missed by pylint inference
generated-members=numpy.*,torch.*
...@@ -100,21 +100,21 @@ class TorchMnist(torch.nn.Module): ...@@ -100,21 +100,21 @@ class TorchMnist(torch.nn.Module):
class CompressorTestCase(TestCase): class CompressorTestCase(TestCase):
def test_tf_pruner(self): def test_tf_pruner(self):
model = TfMnist() model = TfMnist()
configure_list = [{'sparsity': 0.8, 'op_types': 'default'}] configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
tf_compressor.LevelPruner(configure_list).compress_default_graph() tf_compressor.LevelPruner(configure_list).compress_default_graph()
def test_tf_quantizer(self): def test_tf_quantizer(self):
model = TfMnist() model = TfMnist()
tf_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress_default_graph() tf_compressor.NaiveQuantizer([{'op_types': ['default']}]).compress_default_graph()
def test_torch_pruner(self): def test_torch_pruner(self):
model = TorchMnist() model = TorchMnist()
configure_list = [{'sparsity': 0.8, 'op_types': 'default'}] configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
torch_compressor.LevelPruner(configure_list).compress(model) torch_compressor.LevelPruner(configure_list).compress(model)
def test_torch_quantizer(self): def test_torch_quantizer(self):
model = TorchMnist() model = TorchMnist()
torch_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress(model) torch_compressor.NaiveQuantizer([{'op_types': ['default']}]).compress(model)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -33,7 +33,7 @@ class App extends React.Component<{}, AppState> { ...@@ -33,7 +33,7 @@ class App extends React.Component<{}, AppState> {
} }
changeInterval = (interval: number) => { changeInterval = (interval: number) => {
this.setState({ interval: interval }); this.setState({ interval });
if (this.timerId === null && interval !== 0) { if (this.timerId === null && interval !== 0) {
window.setTimeout(this.refresh); window.setTimeout(this.refresh);
} else if (this.timerId !== null && interval === 0) { } else if (this.timerId !== null && interval === 0) {
......
...@@ -154,4 +154,3 @@ def _generate_specific_file(src_path, dst_path, exp_id, trial_id, module): ...@@ -154,4 +154,3 @@ def _generate_specific_file(src_path, dst_path, exp_id, trial_id, module):
raise RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args)) raise RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args))
else: else:
raise RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc)) raise RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc))
...@@ -20,9 +20,10 @@ ...@@ -20,9 +20,10 @@
import ast import ast
import astor
import numbers import numbers
import astor
# pylint: disable=unidiomatic-typecheck # pylint: disable=unidiomatic-typecheck
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment