Unverified commit 2d8f925b, authored by Yuge Zhang and committed by GitHub
Browse files

Bug fix of Retiarii hyperparameter mutation (#4751)

parent c22dc0fc
...@@ -66,6 +66,12 @@ class Traceable: ...@@ -66,6 +66,12 @@ class Traceable:
""" """
raise NotImplementedError() raise NotImplementedError()
def get(self) -> Any:
    """
    Return the original object.

    Typically called on the result of ``trace_copy``. This base version
    provides no implementation and always raises ``NotImplementedError``.
    """
    raise NotImplementedError
class Translatable(abc.ABC): class Translatable(abc.ABC):
""" """
...@@ -136,6 +142,13 @@ class SerializableObject(Traceable): ...@@ -136,6 +142,13 @@ class SerializableObject(Traceable):
{k: copy.copy(v) for k, v in self.trace_kwargs.items()}, {k: copy.copy(v) for k, v in self.trace_kwargs.items()},
) )
def get(self) -> T:
    """
    Return the object this trace record stands for.

    If the ``call_super`` attribute is set, this object is returned as is.
    Otherwise a new object is created by tracing ``trace_symbol`` and calling
    the result with the recorded ``trace_args`` and ``trace_kwargs``.
    """
    if self._get_nni_attr('call_super'):
        return self
    # Reinitialize
    rebuild = trace(self.trace_symbol)
    return rebuild(*self.trace_args, **self.trace_kwargs)
@property @property
def trace_symbol(self) -> Any: def trace_symbol(self) -> Any:
return self._get_nni_attr('symbol') return self._get_nni_attr('symbol')
...@@ -202,11 +215,15 @@ def _make_class_traceable(cls: T, create_wrapper: bool = False) -> T: ...@@ -202,11 +215,15 @@ def _make_class_traceable(cls: T, create_wrapper: bool = False) -> T:
{k: copy.copy(v) for k, v in self.trace_kwargs.items()}, {k: copy.copy(v) for k, v in self.trace_kwargs.items()},
) )
def get(self):
    """Return the object itself; nothing is re-instantiated here."""
    return self
attributes = { attributes = {
'trace_symbol': property(getter_factory('symbol'), setter_factory('symbol')), 'trace_symbol': property(getter_factory('symbol'), setter_factory('symbol')),
'trace_args': property(getter_factory('args'), setter_factory('args')), 'trace_args': property(getter_factory('args'), setter_factory('args')),
'trace_kwargs': property(getter_factory('kwargs'), setter_factory('kwargs')), 'trace_kwargs': property(getter_factory('kwargs'), setter_factory('kwargs')),
'trace_copy': trace_copy 'trace_copy': trace_copy,
'get': get,
} }
if not create_wrapper: if not create_wrapper:
...@@ -562,13 +579,13 @@ class _pickling_object: ...@@ -562,13 +579,13 @@ class _pickling_object:
# Used in `_trace_cls`. # Used in `_trace_cls`.
def __new__(cls, type_, kw_only, data): def __new__(cls, type_, kw_only, data):
type_ = cloudpickle.loads(type_) type_ = _wrapped_cloudpickle_loads(type_)
# Restore the trace type # Restore the trace type
type_ = _trace_cls(type_, kw_only) type_ = _trace_cls(type_, kw_only)
# restore type # restore type
if '_nni_symbol' in data: if '_nni_symbol' in data:
data['_nni_symbol'] = cloudpickle.loads(data['_nni_symbol']) data['_nni_symbol'] = _wrapped_cloudpickle_loads(data['_nni_symbol'])
# https://docs.python.org/3/library/pickle.html#pickling-class-instances # https://docs.python.org/3/library/pickle.html#pickling-class-instances
obj = type_.__new__(type_) obj = type_.__new__(type_)
...@@ -674,7 +691,7 @@ def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False): ...@@ -674,7 +691,7 @@ def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
def _is_function(obj: Any) -> bool: def _is_function(obj: Any) -> bool:
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function # https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return isinstance(obj, (types.FunctionType, types.BuiltinFunctionType, types.MethodType, return isinstance(obj, (types.FunctionType, types.BuiltinFunctionType, types.MethodType,
types.BuiltinMethodType)) types.BuiltinMethodType)) and obj is not None
def _import_cls_or_func_from_name(target: str) -> Any: def _import_cls_or_func_from_name(target: str) -> Any:
...@@ -727,7 +744,7 @@ def get_hybrid_cls_or_func_name(cls_or_func: Any, pickle_size_limit: int = 4096) ...@@ -727,7 +744,7 @@ def get_hybrid_cls_or_func_name(cls_or_func: Any, pickle_size_limit: int = 4096)
def import_cls_or_func_from_hybrid_name(s: str) -> Any: def import_cls_or_func_from_hybrid_name(s: str) -> Any:
if s.startswith('bytes:'): if s.startswith('bytes:'):
b = base64.b64decode(s.split(':', 1)[-1]) b = base64.b64decode(s.split(':', 1)[-1])
return cloudpickle.loads(b) return _wrapped_cloudpickle_loads(b)
if s.startswith('path:'): if s.startswith('path:'):
s = s.split(':', 1)[-1] s = s.split(':', 1)[-1]
return _import_cls_or_func_from_name(s) return _import_cls_or_func_from_name(s)
...@@ -800,5 +817,14 @@ def _json_tricks_any_object_decode(obj: Dict[str, Any]) -> Any: ...@@ -800,5 +817,14 @@ def _json_tricks_any_object_decode(obj: Dict[str, Any]) -> Any:
if isinstance(obj, dict) and '__nni_obj__' in obj: if isinstance(obj, dict) and '__nni_obj__' in obj:
obj = obj['__nni_obj__'] obj = obj['__nni_obj__']
b = base64.b64decode(obj) b = base64.b64decode(obj)
return cloudpickle.loads(b) return _wrapped_cloudpickle_loads(b)
return obj return obj
def _wrapped_cloudpickle_loads(b: bytes) -> Any:
    """
    Deserialize ``b`` with ``cloudpickle.loads``.

    A ``TypeError`` raised while loading triggers a warning that hints at a
    possible Python version mismatch between dump and load; the exception is
    then re-raised unchanged. Any other exception propagates without a warning.

    NOTE(review): unpickling can execute arbitrary code, so this should only
    be fed trusted bytes.
    """
    try:
        loaded = cloudpickle.loads(b)
    except TypeError:
        hint = ('TypeError encountered during deserializing object. This could be caused by '
                'inconsistency between Python versions where dump and load happens.')
        warnings.warn(hint)
        raise
    return loaded
...@@ -21,7 +21,9 @@ def set_execution_engine(engine: AbstractExecutionEngine) -> None: ...@@ -21,7 +21,9 @@ def set_execution_engine(engine: AbstractExecutionEngine) -> None:
if _execution_engine is None: if _execution_engine is None:
_execution_engine = engine _execution_engine = engine
else: else:
raise RuntimeError('Execution engine is already set.') raise RuntimeError('Execution engine is already set. '
'You should avoid instantiating RetiariiExperiment twice in one process. '
'If you are running in a Jupyter notebook, please restart the kernel.')
def get_execution_engine() -> AbstractExecutionEngine: def get_execution_engine() -> AbstractExecutionEngine:
......
...@@ -364,27 +364,27 @@ class EvaluatorValueChoiceMutator(Mutator): ...@@ -364,27 +364,27 @@ class EvaluatorValueChoiceMutator(Mutator):
if not is_traceable(obj): if not is_traceable(obj):
return obj return obj
if not any(isinstance(value, ValueChoiceX) for value in obj.trace_kwargs.values()): updates = {}
# No valuechoice, not interesting
return obj
# Make a copy
obj = obj.trace_copy()
result = {}
# For each argument that is a composition of value choice # For each argument that is a composition of value choice
# we find all the leaf-value-choice in the mutation # we find all the leaf-value-choice in the mutation
# and compute the final result # and compute the final updates
for key, param in obj.trace_kwargs.items(): for key, param in obj.trace_kwargs.items():
if isinstance(param, ValueChoiceX): if isinstance(param, ValueChoiceX):
leaf_node_values = [value_choice_decisions[choice.label] for choice in param.inner_choices()] leaf_node_values = [value_choice_decisions[choice.label] for choice in param.inner_choices()]
result[key] = param.evaluate(leaf_node_values) updates[key] = param.evaluate(leaf_node_values)
elif is_traceable(param): elif is_traceable(param):
# Recursively # Recursively
result[key] = self._mutate_traceable_object(param, value_choice_decisions) sub_update = self._mutate_traceable_object(param, value_choice_decisions)
if sub_update is not param: # if mutated
updates[key] = sub_update
if updates:
mutated_obj = obj.trace_copy() # Make a copy
mutated_obj.trace_kwargs.update(updates) # Mutate
mutated_obj = mutated_obj.get() # Instantiate the full mutated object
obj.trace_kwargs.update(result) return mutated_obj
return obj return obj
......
...@@ -6,7 +6,9 @@ from collections import Counter ...@@ -6,7 +6,9 @@ from collections import Counter
import pytest import pytest
import nni import nni
import nni.retiarii.evaluator.pytorch.lightning as pl
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
import pytorch_lightning
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from nni.retiarii import InvalidMutation, Sampler, basic_unit from nni.retiarii import InvalidMutation, Sampler, basic_unit
...@@ -1202,11 +1204,31 @@ class Shared(unittest.TestCase): ...@@ -1202,11 +1204,31 @@ class Shared(unittest.TestCase):
samplers = [RandomSampler() for _ in range(3)] samplers = [RandomSampler() for _ in range(3)]
for _ in range(10): for _ in range(10):
model = _apply_all_mutators(init_model, mutators, samplers) model = _apply_all_mutators(init_model, mutators, samplers)
a, v = model.evaluator.trace_kwargs['t'].trace_kwargs['a'], model.evaluator.trace_kwargs['v'] a, v = model.evaluator.trace_kwargs['t'].a, model.evaluator.trace_kwargs['v']
assert v % 10 == a assert v % 10 == a
assert a in [1, 2, 3] assert a in [1, 2, 3]
assert v // 10 in [1, 2, 3] assert v // 10 in [1, 2, 3]
@unittest.skipIf(pytorch_lightning.__version__ < '1.0', 'Legacy PyTorch-lightning not supported')
def test_valuechoice_lightning(self):
    """
    Mutate a Lightning evaluator whose trainer takes a value choice.

    ``max_epochs`` is a ``ValueChoice`` over ``[1, 2, 3]``. The mutators are
    applied 20 times with random samplers; every mutated model is dumped, and
    all three candidate values must show up on ``evaluator.trainer.max_epochs``.
    """
    @nni.trace
    class AnyModule(pl.LightningModule):
        pass

    module = AnyModule()
    trainer = pl.Trainer(max_epochs=nn.ValueChoice([1, 2, 3]))
    evaluator = pl.Lightning(module, trainer)

    mutators = process_evaluator_mutations(evaluator, [])
    assert len(mutators) == 2

    init_model = Model(_internal=True)
    init_model.evaluator = evaluator
    samplers = [RandomSampler() for _ in range(2)]

    observed_epochs = set()
    for _ in range(20):
        model = _apply_all_mutators(init_model, mutators, samplers)
        observed_epochs.add(model.evaluator.trainer.max_epochs)
        model._dump()

    assert len(observed_epochs) == 3
def test_retiarii_nn_import(self): def test_retiarii_nn_import(self):
dummy = torch.zeros(1, 16, 32, 24) dummy = torch.zeros(1, 16, 32, 24)
nn.init.uniform_(dummy) nn.init.uniform_(dummy)
......
...@@ -353,3 +353,25 @@ def test_subclass(): ...@@ -353,3 +353,25 @@ def test_subclass():
assert obj.trace_kwargs == {'c': 1, 'd': 2} assert obj.trace_kwargs == {'c': 1, 'd': 2}
assert issubclass(type(obj), Super) assert issubclass(type(obj), Super)
assert isinstance(obj, Super) assert isinstance(obj, Super)
def test_get():
    """
    Check ``get()`` on traced objects and on their ``trace_copy`` results.

    A trace copy cannot run ``bar`` until ``get()`` is called on it; after
    ``get()``, the object reflects any edits made to ``trace_kwargs``.
    """
    @nni.trace
    class Foo:
        def __init__(self, a=1):
            self._a = a

        def bar(self):
            return self._a + 1

    original = Foo(3)
    # A dump/load round trip yields an object whose methods work.
    assert nni.load(nni.dump(original)).bar() == 4

    copied = original.trace_copy()
    # Calling ``bar`` on the bare copy is expected to fail.
    with pytest.raises(AttributeError):
        copied.bar()

    copied.trace_kwargs['a'] = 5
    rebuilt = copied.get()
    assert rebuilt.bar() == 6

    second_copy = rebuilt.trace_copy()
    second_copy.trace_kwargs['a'] = -1
    assert second_copy.get().bar() == 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment