Unverified commit 8c203f30 authored by SparkSnail, committed by GitHub

Merge pull request #211 from microsoft/master

merge master
parents 7c1ab114 483232c8
@@ -27,7 +27,6 @@ _logger = logging.getLogger(__name__)
class Tuner(Recoverable):
# pylint: disable=no-self-use,unused-argument
def generate_parameters(self, parameter_id, **kwargs):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
@@ -47,7 +46,7 @@ class Tuner(Recoverable):
result = []
for parameter_id in parameter_id_list:
try:
_logger.debug("generating param for {}".format(parameter_id))
_logger.debug("generating param for %s", parameter_id)
res = self.generate_parameters(parameter_id, **kwargs)
except nni.NoMoreTrialError:
return result
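
For context, the loop above simply calls generate_parameters() once per id and stops when nni.NoMoreTrialError is raised. A minimal sketch of a tuner implementing this interface follows; the class name, search-space key and import path are illustrative assumptions, not part of this diff.

import random
from nni.tuner import Tuner   # assumption: Tuner's import path at this revision

class RandomChoiceTuner(Tuner):
    """Hypothetical tuner that picks a random value for every trial."""
    def update_search_space(self, search_space):
        self._space = search_space
    def generate_parameters(self, parameter_id, **kwargs):
        # must return a serializable object; raise nni.NoMoreTrialError when exhausted
        return {'lr': random.choice(self._space['lr']['_value'])}
    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        pass
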
@@ -71,6 +70,8 @@ class Tuner(Recoverable):
By default `receive_trial_result()` will only receive results of algorithm-generated hyper-parameters.
If tuners want to receive those of customized parameters as well, they can call this function in `__init__()`.
"""
# pylint: disable=attribute-defined-outside-init
# FIXME: because tuner is designed as interface, this API should not be here
self._accept_customized = accept
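
A hedged illustration of the opt-in described in the docstring above, reusing the Tuner base class from the earlier sketch; the setter's public name is not visible in this hunk, so the call below is an assumption.

class CustomizedAwareTuner(Tuner):
    def __init__(self):
        super().__init__()
        # assumption: the method wrapping `self._accept_customized = accept`
        # is exposed as accept_customized_trials()
        self.accept_customized_trials(True)
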
def trial_end(self, parameter_id, success, **kwargs):
@@ -78,7 +79,6 @@ class Tuner(Recoverable):
parameter_id: int
success: True if the trial successfully completed; False if failed or terminated
"""
pass
def update_search_space(self, search_space):
"""Update the search space of tuner. Must override.
@@ -91,20 +91,19 @@ class Tuner(Recoverable):
path: checkpoint directory for tuner
"""
checkpoin_path = self.get_checkpoint_path()
_logger.info('Load checkpoint ignored by tuner, checkpoint path: %s' % checkpoin_path)
_logger.info('Load checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path)
def save_checkpoint(self):
"""Save the checkpoint of tuner.
path: checkpoint directory for tuner
"""
checkpoin_path = self.get_checkpoint_path()
_logger.info('Save checkpoint ignored by tuner, checkpoint path: %s' % checkpoin_path)
_logger.info('Save checkpoint ignored by tuner, checkpoint path: %s', checkpoin_path)
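
The logging changes in this hunk and in load_checkpoint() above swap eager '%' interpolation for logging's lazy argument passing. A small standalone illustration; the path value is made up.

import logging
_logger = logging.getLogger(__name__)

checkpoint_path = '/tmp/nni/checkpoints'                # hypothetical value
_logger.info('checkpoint path: %s' % checkpoint_path)   # eager: the string is built even if INFO is disabled
_logger.info('checkpoint path: %s', checkpoint_path)    # lazy: formatted only when the record is emitted
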
def import_data(self, data):
"""Import additional data for tuning
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
pass
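
Per the docstring above, a hedged example of the list import_data() expects; the hyper-parameter names and metric values are made up.

previous_trials = [
    {'parameter': {'lr': 0.01, 'batch_size': 32}, 'value': 0.91},
    {'parameter': {'lr': 0.001, 'batch_size': 64}, 'value': {'default': 0.95}},
]
# tuner.import_data(previous_trials)   # `tuner` would be a concrete Tuner subclass instance
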
def _on_exit(self):
pass
......
@@ -84,12 +84,13 @@ def extract_scalar_reward(value, scalar_key='default'):
Incorrect final result: the final result should be float/int,
or a dict which has a key named "default" whose value is float/int.
"""
if isinstance(value, float) or isinstance(value, int):
if isinstance(value, (float, int)):
reward = value
elif isinstance(value, dict) and scalar_key in value and isinstance(value[scalar_key], (float, int)):
reward = value[scalar_key]
else:
raise RuntimeError('Incorrect final result: the final result should be float/int, or a dict which has a key named "default" whose value is float/int.')
raise RuntimeError('Incorrect final result: the final result should be float/int, ' \
'or a dict which has a key named "default" whose value is float/int.')
return reward
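
Behaviour of extract_scalar_reward() on a few sample inputs, per the definition above; the import path is an assumption about this revision's module layout.

from nni.utils import extract_scalar_reward   # assumption: utility module path

extract_scalar_reward(0.93)                             # -> 0.93
extract_scalar_reward({'default': 0.93, 'loss': 0.2})   # -> 0.93
extract_scalar_reward({'accuracy': 0.93})               # raises RuntimeError: no 'default' key
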
@@ -101,7 +102,6 @@ def convert_dict2tuple(value):
for _keys in value:
value[_keys] = convert_dict2tuple(value[_keys])
return tuple(sorted(value.items()))
else:
return value
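
A worked example of convert_dict2tuple() above: nested dicts are converted depth-first, and each level becomes a key-sorted tuple, which makes the result hashable.

convert_dict2tuple({'b': 1, 'a': {'y': 2, 'x': 3}})
# -> (('a', (('x', 3), ('y', 2))), ('b', 1))
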
......
# This pylintrc file is a little stricter than the one in the root of the code directory
# SDK source MUST pass the lint rules in the top-level directory, and SHOULD pass the rules here
[SETTINGS]
max-line-length=140
disable =
missing-docstring,
invalid-name, # C0103
no-member, # E1101: sometimes pylint cannot detect members correctly due to a bug
c-extension-no-member, # I1101
no-self-use, # R0201: many functions in this SDK are designed for override
duplicate-code, # R0801
too-many-instance-attributes, # R0902
too-few-public-methods, # R0903
too-many-public-methods, # R0904
too-many-return-statements, # R0911
too-many-branches, # R0912
too-many-arguments, # R0913
too-many-locals, # R0914
too-many-statements, # R0915
too-many-nested-blocks, # R1702
no-else-return, # R1705
chained-comparison, # R1716
no-else-raise, # R1720
protected-access, # W0212: underscore variables may be protected by whole SDK instead of single module
arguments-differ, # W0221: pylint cannot handle *args and **kwargs
super-init-not-called, # W0231: some interface classes do not expect users to call init
useless-super-delegation, # W0235: derived init may have different docstring
global-statement, # W0603: globals are useful to hide SDK internal states from user
unused-argument, # W0613: many functions in this SDK are designed for override
broad-except, # W0703: the SDK commonly catch exceptions to report error
fixme, # W0511
ignore-patterns=test.*.py
# List of members which are set dynamically and missed by pylint inference
generated-members=numpy.*,torch.*
@@ -100,21 +100,21 @@ class TorchMnist(torch.nn.Module):
class CompressorTestCase(TestCase):
def test_tf_pruner(self):
model = TfMnist()
configure_list = [{'sparsity': 0.8, 'op_types': 'default'}]
configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
tf_compressor.LevelPruner(configure_list).compress_default_graph()
def test_tf_quantizer(self):
model = TfMnist()
tf_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress_default_graph()
tf_compressor.NaiveQuantizer([{'op_types': ['default']}]).compress_default_graph()
def test_torch_pruner(self):
model = TorchMnist()
configure_list = [{'sparsity': 0.8, 'op_types': 'default'}]
configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
torch_compressor.LevelPruner(configure_list).compress(model)
def test_torch_quantizer(self):
model = TorchMnist()
torch_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress(model)
torch_compressor.NaiveQuantizer([{'op_types': ['default']}]).compress(model)
if __name__ == '__main__':
......
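
The test changes above reflect that 'op_types' in a compressor config list is now a list of operation-type names rather than a bare string. A hedged sketch of such a config list; the sparsity values and the extra op-type names are illustrative only.

configure_list = [
    {'sparsity': 0.8, 'op_types': ['default']},           # default op set, as in the tests above
    {'sparsity': 0.6, 'op_types': ['Conv2d', 'Linear']},  # hypothetical per-type entry
]
# torch_compressor.LevelPruner(configure_list).compress(model)   # as exercised by test_torch_pruner
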
@@ -33,7 +33,7 @@ class App extends React.Component<{}, AppState> {
}
changeInterval = (interval: number) => {
this.setState({ interval: interval });
this.setState({ interval });
if (this.timerId === null && interval !== 0) {
window.setTimeout(this.refresh);
} else if (this.timerId !== null && interval === 0) {
......
@@ -154,4 +154,3 @@ def _generate_specific_file(src_path, dst_path, exp_id, trial_id, module):
raise RuntimeError(src_path + ' ' + '\n'.join(str(arg) for arg in exc.args))
else:
raise RuntimeError('Failed to expand annotations for %s: %r' % (src_path, exc))
@@ -20,9 +20,10 @@
import ast
import astor
import numbers
import astor
# pylint: disable=unidiomatic-typecheck
......