Unverified commit c837bfc0 authored by Yuge Zhang, committed by GitHub

Fix deepcopy of mutables (#4400)

parent 72087f8a
@@ -10,14 +10,13 @@ import torch.nn as nn
from nni.common.serializer import Translatable
from nni.retiarii.serializer import basic_unit
from nni.retiarii.utils import NoContextError
from .utils import generate_new_label, get_fixed_value
from .utils import Mutable, generate_new_label, get_fixed_value
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
class LayerChoice(nn.Module):
class LayerChoice(Mutable):
"""
Layer choice selects one of the ``candidates``, then applies it to the inputs and returns the result.
@@ -60,16 +59,14 @@ class LayerChoice(nn.Module):
# FIXME: prior is designed but not supported yet
def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
try:
chosen = get_fixed_value(label)
if isinstance(candidates, list):
return candidates[int(chosen)]
else:
return candidates[chosen]
except NoContextError:
return super().__new__(cls)
@classmethod
def create_fixed_module(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
label: Optional[str] = None, **kwargs):
chosen = get_fixed_value(label)
if isinstance(candidates, list):
return candidates[int(chosen)]
else:
return candidates[chosen]
def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
@@ -159,7 +156,7 @@ class LayerChoice(nn.Module):
return f'LayerChoice({self.candidates}, label={repr(self.label)})'
class InputChoice(nn.Module):
class InputChoice(Mutable):
"""
Input choice selects ``n_chosen`` inputs from ``choose_from`` (contains ``n_candidates`` keys).
Use ``reduction`` to specify how chosen inputs are reduced into one output. A few options are:
@@ -185,13 +182,10 @@ class InputChoice(nn.Module):
Identifier of the input choice.
"""
def __new__(cls, n_candidates: int, n_chosen: Optional[int] = 1,
reduction: str = 'sum', *,
prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
try:
return ChosenInputs(get_fixed_value(label), reduction=reduction)
except NoContextError:
return super().__new__(cls)
@classmethod
def create_fixed_module(cls, n_candidates: int, n_chosen: Optional[int] = 1, reduction: str = 'sum', *,
prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
return ChosenInputs(get_fixed_value(label), reduction=reduction)
def __init__(self, n_candidates: int, n_chosen: Optional[int] = 1,
reduction: str = 'sum', *,
@@ -234,7 +228,7 @@ class InputChoice(nn.Module):
f'reduction={repr(self.reduction)}, label={repr(self.label)})'
class ValueChoice(Translatable, nn.Module):
class ValueChoice(Translatable, Mutable):
"""
ValueChoice is to choose one from ``candidates``.
@@ -302,11 +296,9 @@ class ValueChoice(Translatable, nn.Module):
# FIXME: prior is designed but not supported yet
def __new__(cls, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
try:
return get_fixed_value(label)
except NoContextError:
return super().__new__(cls)
@classmethod
def create_fixed_module(cls, candidates: List[Any], *, label: Optional[str] = None, **kwargs):
return get_fixed_value(label)
def __init__(self, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
super().__init__()
......
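For context (not part of this commit): the removed ``__new__`` overrides above are what broke ``copy.deepcopy``. ``deepcopy`` reconstructs an object by calling ``cls.__new__(cls)`` with no arguments and then restoring ``__dict__``, so a ``__new__`` that requires the constructor arguments fails before any state is copied. A minimal, self-contained sketch of the failure mode (toy class, no NNI or torch dependency):

```python
import copy

class OldStyleChoice:
    # Mimics the removed pattern: __new__ demands the constructor arguments.
    def __new__(cls, candidates):
        return super().__new__(cls)

    def __init__(self, candidates):
        self.candidates = candidates

choice = OldStyleChoice([1, 2, 3])
try:
    copy.deepcopy(choice)
except TypeError as err:
    # __new__() missing 1 required positional argument: 'candidates'
    print(err)
```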
@@ -9,14 +9,13 @@ from .api import LayerChoice, InputChoice
from .nn import ModuleList
from .nasbench101 import NasBench101Cell, NasBench101Mutator
from .utils import generate_new_label, get_fixed_value
from ...utils import NoContextError
from .utils import Mutable, generate_new_label, get_fixed_value
__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell']
class Repeat(nn.Module):
class Repeat(Mutable):
"""
Repeat a block a variable number of times.
@@ -25,23 +24,29 @@
blocks : function, list of function, module or list of module
The block to be repeated. If not a list, it will be replicated into a list.
If a list, it should be of length ``max_depth``; the modules will be instantiated in order and a prefix of the list will be used.
If a function, it will be called to instantiate a module. Otherwise the module will be deep-copied.
If a function, it will be called with the repetition index as its argument to instantiate a module.
Otherwise, the module will be deep-copied.
depth : int or tuple of int
If an integer, the block will be repeated a fixed number of times. If a tuple, it should be (min, max),
meaning that the block will be repeated at least `min` times and at most `max` times.
"""
def __new__(cls, blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]],
depth: Union[int, Tuple[int, int]], label: Optional[str] = None):
try:
repeat = get_fixed_value(label)
return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
except NoContextError:
return super().__new__(cls)
@classmethod
def create_fixed_module(cls,
blocks: Union[Callable[[int], nn.Module],
List[Callable[[int], nn.Module]],
nn.Module,
List[nn.Module]],
depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None):
repeat = get_fixed_value(label)
return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
def __init__(self,
blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]],
depth: Union[int, Tuple[int, int]], label: Optional[str] = None):
blocks: Union[Callable[[int], nn.Module],
List[Callable[[int], nn.Module]],
nn.Module,
List[nn.Module]],
depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None):
super().__init__()
self._label = generate_new_label(label)
self.min_depth = depth if isinstance(depth, int) else depth[0]
@@ -69,7 +74,7 @@ class Repeat(nn.Module):
assert repeat <= len(blocks), f'Not enough blocks to be used. {repeat} expected, only found {len(blocks)}.'
blocks = blocks[:repeat]
if not isinstance(blocks[0], nn.Module):
blocks = [b() for b in blocks]
blocks = [b(i) for i, b in enumerate(blocks)]
return blocks
......
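The two forms of ``blocks`` differ in how mutables are shared, which the new test added in this commit also exercises. A module is deep-copied into every repetition (one shared decision, exactly the deepcopy path fixed here), while a callable now receives the repetition index and can build an independent mutable per repetition. A short usage sketch, assuming the usual ``model_wrapper`` import from ``nni.retiarii`` and ``nni.retiarii.nn.pytorch`` imported as ``nn``, as in the tests:

```python
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper

class AddOne(nn.Module):
    def forward(self, x):
        return x + 1

@model_wrapper
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared choice: the LayerChoice labeled 'lc' is deep-copied into every
        # repetition, so all repetitions follow the same decision.
        self.shared = nn.Repeat(
            nn.LayerChoice([AddOne(), nn.Identity()], label='lc'), (3, 5), label='rep1')
        # Independent choices: the callable receives the repetition index and builds a
        # fresh LayerChoice per repetition, each with its own auto-generated label.
        self.independent = nn.Repeat(
            lambda index: nn.LayerChoice([AddOne(), nn.Identity()]), (2, 3), label='rep2')

    def forward(self, x):
        return self.independent(self.shared(x))
```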
@@ -6,11 +6,10 @@ import numpy as np
import torch
import torch.nn as nn
from nni.retiarii.mutator import InvalidMutation, Mutator
from nni.retiarii.graph import Model
from .api import InputChoice, ValueChoice, LayerChoice
from .utils import generate_new_label, get_fixed_dict
from ...mutator import InvalidMutation, Mutator
from ...graph import Model
from ...utils import NoContextError
from .utils import Mutable, generate_new_label, get_fixed_dict
_logger = logging.getLogger(__name__)
@@ -218,7 +217,7 @@ class _NasBench101CellFixed(nn.Module):
return outputs
class NasBench101Cell(nn.Module):
class NasBench101Cell(Mutable):
"""
Cell structure that is proposed in NAS-Bench-101 [nasbench101]_ .
@@ -289,23 +288,21 @@ class NasBench101Cell(nn.Module):
return OrderedDict([(str(i), t) for i, t in enumerate(x)])
return OrderedDict(x)
def __new__(cls, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
@classmethod
def create_fixed_module(cls, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
def make_list(x): return x if isinstance(x, list) else [x]
try:
label, selected = get_fixed_dict(label)
op_candidates = cls._make_dict(op_candidates)
num_nodes = selected[f'{label}/num_nodes']
adjacency_list = [make_list(selected[f'{label}/input{i}']) for i in range(1, num_nodes)]
if sum([len(e) for e in adjacency_list]) > max_num_edges:
raise InvalidMutation(f'Expected {max_num_edges} edges, found: {adjacency_list}')
return _NasBench101CellFixed(
[op_candidates[selected[f'{label}/op{i}']] for i in range(1, num_nodes - 1)],
adjacency_list, in_features, out_features, num_nodes, projection)
except NoContextError:
return super().__new__(cls)
label, selected = get_fixed_dict(label)
op_candidates = cls._make_dict(op_candidates)
num_nodes = selected[f'{label}/num_nodes']
adjacency_list = [make_list(selected[f'{label}/input{i}']) for i in range(1, num_nodes)]
if sum([len(e) for e in adjacency_list]) > max_num_edges:
raise InvalidMutation(f'Expected {max_num_edges} edges, found: {adjacency_list}')
return _NasBench101CellFixed(
[op_candidates[selected[f'{label}/op{i}']] for i in range(1, num_nodes - 1)],
adjacency_list, in_features, out_features, num_nodes, projection)
def __init__(self, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
......
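For illustration only (hypothetical label, operation names, and values, not taken from this commit): ``create_fixed_module`` above reads a flat dict keyed by ``{label}/num_nodes``, ``{label}/input{i}``, and ``{label}/op{i}``. A fixed selection for a four-node cell might look like this:

```python
# Hypothetical fixed selection for a NasBench101Cell labeled 'cell' with 4 nodes.
selected = {
    'cell/num_nodes': 4,
    'cell/input1': [0],     # node 1 is fed by node 0 (the input node)
    'cell/input2': [0, 1],  # node 2 is fed by nodes 0 and 1
    'cell/input3': [1, 2],  # node 3 (the output node) is fed by nodes 1 and 2
    'cell/op1': 'conv3x3',  # ops exist only for intermediate nodes 1 .. num_nodes - 2
    'cell/op2': 'conv1x1',
}
# Total edges: 1 + 2 + 2 = 5, within the max_num_edges = 9 budget checked above.
```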
from typing import Any, Optional, Tuple
from typing import Any, Optional, Tuple, Union
from nni.retiarii.utils import ModelNamespace, get_current_context
import torch.nn as nn
from nni.retiarii.utils import NoContextError, ModelNamespace, get_current_context
class Mutable(nn.Module):
"""
This is just an implementation trick for now.
In the future, this could become the base class for all PyTorch mutables, including layer choice, input choice, etc.
It is not considered an interface, but rather a base class consisting of commonly used class/instance methods.
Until the design is finalized, API developers are also not recommended to rely on ``isinstance(module, Mutable)``
to check for mutable modules.
"""
def __new__(cls, *args, **kwargs):
if not args and not kwargs:
# this can be the case of copy/deepcopy
# attributes are assigned afterwards in __dict__
return super().__new__(cls)
try:
return cls.create_fixed_module(*args, **kwargs)
except NoContextError:
return super().__new__(cls)
@classmethod
def create_fixed_module(cls, *args, **kwargs) -> Union[nn.Module, Any]:
"""
Try to create a fixed module from the fixed dict.
If the code is running in a trial, this method will succeed, and a concrete module (instead of a mutable) will be created.
Raises ``NoContextError`` if the creation fails because no fixed context is available.
"""
raise NotImplementedError
def generate_new_label(label: Optional[str]):
......
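A condensed, self-contained sketch of the dispatch above (toy stand-ins for ``NoContextError`` and the subclass; ``torch.nn.Module`` is omitted so the snippet runs without dependencies). The no-argument branch is what makes ``copy.deepcopy`` work again: ``deepcopy`` calls ``cls.__new__(cls)`` with no arguments and restores ``__dict__`` afterwards, so in that case ``__new__`` must simply allocate.

```python
import copy

class NoContextError(Exception):
    """Stand-in for nni.retiarii.utils.NoContextError."""

class Mutable:
    def __new__(cls, *args, **kwargs):
        if not args and not kwargs:
            # copy/deepcopy path: just allocate; attributes are restored from __dict__ later.
            return super().__new__(cls)
        try:
            return cls.create_fixed_module(*args, **kwargs)
        except NoContextError:
            return super().__new__(cls)

    @classmethod
    def create_fixed_module(cls, *args, **kwargs):
        raise NotImplementedError

class ToyChoice(Mutable):
    @classmethod
    def create_fixed_module(cls, candidates, *, label=None):
        raise NoContextError()  # pretend no fixed-architecture context is active

    def __init__(self, candidates, *, label=None):
        self.candidates = candidates
        self.label = label

choice = ToyChoice([1, 2, 3], label='x')
clone = copy.deepcopy(choice)  # no longer raises TypeError
assert clone.candidates == [1, 2, 3] and clone.label == 'x'
```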
@@ -483,6 +483,54 @@ class GraphIR(unittest.TestCase):
self.assertTrue((self._get_converted_pytorch_model(model2)(torch.zeros(1, 16)) == 4).all())
self.assertTrue((self._get_converted_pytorch_model(model3)(torch.zeros(1, 16)) == 5).all())
def test_repeat_complex(self):
class AddOne(nn.Module):
def forward(self, x):
return x + 1
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.block = nn.Repeat(nn.LayerChoice([AddOne(), nn.Identity()], label='lc'), (3, 5), label='rep')
def forward(self, x):
return self.block(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 2)
self.assertEqual(set([mutator.label for mutator in mutators]), {'lc', 'rep'})
sampler = RandomSampler()
for _ in range(10):
new_model = model
for mutator in mutators:
new_model = mutator.bind_sampler(sampler).apply(new_model)
result = self._get_converted_pytorch_model(new_model)(torch.zeros(1, 1)).item()
self.assertIn(result, [0., 3., 4., 5.])
# independent layer choice
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.block = nn.Repeat(lambda index: nn.LayerChoice([AddOne(), nn.Identity()]), (2, 3), label='rep')
def forward(self, x):
return self.block(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 4)
result = []
for _ in range(20):
new_model = model
for mutator in mutators:
new_model = mutator.bind_sampler(sampler).apply(new_model)
result.append(self._get_converted_pytorch_model(new_model)(torch.zeros(1, 1)).item())
self.assertIn(1., result)
def test_cell(self):
@self.get_serializer()
class Net(nn.Module):
......