Unverified Commit 2d8f925b authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Bug fix of Retiarii hyperparameter mutation (#4751)

parent c22dc0fc
......@@ -66,6 +66,12 @@ class Traceable:
"""
raise NotImplementedError()
def get(self) -> Any:
"""
Get the original object. Usually used together with ``trace_copy``.
"""
raise NotImplementedError()
class Translatable(abc.ABC):
"""
......@@ -136,6 +142,13 @@ class SerializableObject(Traceable):
{k: copy.copy(v) for k, v in self.trace_kwargs.items()},
)
def get(self) -> T:
if not self._get_nni_attr('call_super'):
# Reinitialize
return trace(self.trace_symbol)(*self.trace_args, **self.trace_kwargs)
return self
@property
def trace_symbol(self) -> Any:
return self._get_nni_attr('symbol')
......@@ -202,11 +215,15 @@ def _make_class_traceable(cls: T, create_wrapper: bool = False) -> T:
{k: copy.copy(v) for k, v in self.trace_kwargs.items()},
)
def get(self):
return self
attributes = {
'trace_symbol': property(getter_factory('symbol'), setter_factory('symbol')),
'trace_args': property(getter_factory('args'), setter_factory('args')),
'trace_kwargs': property(getter_factory('kwargs'), setter_factory('kwargs')),
'trace_copy': trace_copy
'trace_copy': trace_copy,
'get': get,
}
if not create_wrapper:
......@@ -562,13 +579,13 @@ class _pickling_object:
# Used in `_trace_cls`.
def __new__(cls, type_, kw_only, data):
type_ = cloudpickle.loads(type_)
type_ = _wrapped_cloudpickle_loads(type_)
# Restore the trace type
type_ = _trace_cls(type_, kw_only)
# restore type
if '_nni_symbol' in data:
data['_nni_symbol'] = cloudpickle.loads(data['_nni_symbol'])
data['_nni_symbol'] = _wrapped_cloudpickle_loads(data['_nni_symbol'])
# https://docs.python.org/3/library/pickle.html#pickling-class-instances
obj = type_.__new__(type_)
......@@ -674,7 +691,7 @@ def _formulate_arguments(func, args, kwargs, kw_only, is_class_init=False):
def _is_function(obj: Any) -> bool:
# https://stackoverflow.com/questions/624926/how-do-i-detect-whether-a-python-variable-is-a-function
return isinstance(obj, (types.FunctionType, types.BuiltinFunctionType, types.MethodType,
types.BuiltinMethodType))
types.BuiltinMethodType)) and obj is not None
def _import_cls_or_func_from_name(target: str) -> Any:
......@@ -727,7 +744,7 @@ def get_hybrid_cls_or_func_name(cls_or_func: Any, pickle_size_limit: int = 4096)
def import_cls_or_func_from_hybrid_name(s: str) -> Any:
if s.startswith('bytes:'):
b = base64.b64decode(s.split(':', 1)[-1])
return cloudpickle.loads(b)
return _wrapped_cloudpickle_loads(b)
if s.startswith('path:'):
s = s.split(':', 1)[-1]
return _import_cls_or_func_from_name(s)
......@@ -800,5 +817,14 @@ def _json_tricks_any_object_decode(obj: Dict[str, Any]) -> Any:
if isinstance(obj, dict) and '__nni_obj__' in obj:
obj = obj['__nni_obj__']
b = base64.b64decode(obj)
return cloudpickle.loads(b)
return _wrapped_cloudpickle_loads(b)
return obj
def _wrapped_cloudpickle_loads(b: bytes) -> Any:
try:
return cloudpickle.loads(b)
except TypeError:
warnings.warn('TypeError encountered during deserializing object. This could be caused by '
'inconsistency between Python versions where dump and load happens.')
raise
......@@ -21,7 +21,9 @@ def set_execution_engine(engine: AbstractExecutionEngine) -> None:
if _execution_engine is None:
_execution_engine = engine
else:
raise RuntimeError('Execution engine is already set.')
raise RuntimeError('Execution engine is already set. '
'You should avoid instantiating RetiariiExperiment twice in one process. '
'If you are running in a Jupyter notebook, please restart the kernel.')
def get_execution_engine() -> AbstractExecutionEngine:
......
......@@ -364,27 +364,27 @@ class EvaluatorValueChoiceMutator(Mutator):
if not is_traceable(obj):
return obj
if not any(isinstance(value, ValueChoiceX) for value in obj.trace_kwargs.values()):
# No valuechoice, not interesting
return obj
# Make a copy
obj = obj.trace_copy()
result = {}
updates = {}
# For each argument that is a composition of value choice
# we find all the leaf-value-choice in the mutation
# and compute the final result
# and compute the final updates
for key, param in obj.trace_kwargs.items():
if isinstance(param, ValueChoiceX):
leaf_node_values = [value_choice_decisions[choice.label] for choice in param.inner_choices()]
result[key] = param.evaluate(leaf_node_values)
updates[key] = param.evaluate(leaf_node_values)
elif is_traceable(param):
# Recursively
result[key] = self._mutate_traceable_object(param, value_choice_decisions)
sub_update = self._mutate_traceable_object(param, value_choice_decisions)
if sub_update is not param: # if mutated
updates[key] = sub_update
if updates:
mutated_obj = obj.trace_copy() # Make a copy
mutated_obj.trace_kwargs.update(updates) # Mutate
mutated_obj = mutated_obj.get() # Instantiate the full mutated object
obj.trace_kwargs.update(result)
return mutated_obj
return obj
......
......@@ -6,7 +6,9 @@ from collections import Counter
import pytest
import nni
import nni.retiarii.evaluator.pytorch.lightning as pl
import nni.retiarii.nn.pytorch as nn
import pytorch_lightning
import torch
import torch.nn.functional as F
from nni.retiarii import InvalidMutation, Sampler, basic_unit
......@@ -1202,11 +1204,31 @@ class Shared(unittest.TestCase):
samplers = [RandomSampler() for _ in range(3)]
for _ in range(10):
model = _apply_all_mutators(init_model, mutators, samplers)
a, v = model.evaluator.trace_kwargs['t'].trace_kwargs['a'], model.evaluator.trace_kwargs['v']
a, v = model.evaluator.trace_kwargs['t'].a, model.evaluator.trace_kwargs['v']
assert v % 10 == a
assert a in [1, 2, 3]
assert v // 10 in [1, 2, 3]
@unittest.skipIf(pytorch_lightning.__version__ < '1.0', 'Legacy PyTorch-lightning not supported')
def test_valuechoice_lightning(self):
@nni.trace
class AnyModule(pl.LightningModule):
pass
evaluator = pl.Lightning(AnyModule(), pl.Trainer(max_epochs=nn.ValueChoice([1, 2, 3])))
mutators = process_evaluator_mutations(evaluator, [])
assert len(mutators) == 2
init_model = Model(_internal=True)
init_model.evaluator = evaluator
samplers = [RandomSampler() for _ in range(2)]
values = []
for _ in range(20):
model = _apply_all_mutators(init_model, mutators, samplers)
values.append(model.evaluator.trainer.max_epochs)
model._dump()
assert len(set(values)) == 3
def test_retiarii_nn_import(self):
dummy = torch.zeros(1, 16, 32, 24)
nn.init.uniform_(dummy)
......
......@@ -353,3 +353,25 @@ def test_subclass():
assert obj.trace_kwargs == {'c': 1, 'd': 2}
assert issubclass(type(obj), Super)
assert isinstance(obj, Super)
def test_get():
@nni.trace
class Foo:
def __init__(self, a = 1):
self._a = a
def bar(self):
return self._a + 1
obj = Foo(3)
assert nni.load(nni.dump(obj)).bar() == 4
obj1 = obj.trace_copy()
with pytest.raises(AttributeError):
obj1.bar()
obj1.trace_kwargs['a'] = 5
obj1 = obj1.get()
assert obj1.bar() == 6
obj2 = obj1.trace_copy()
obj2.trace_kwargs['a'] = -1
assert obj2.get().bar() == 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment