Commit 9f40659d authored by Yuge Zhang, committed by QuanluZhang

Fix a few issues related to fixed arc and from-tuner arc (#1876)

parent db91e8e6
@@ -10,7 +10,7 @@ import torch
 
 import nni
 from nni.env_vars import trial_env_vars
-from nni.nas.pytorch.mutables import LayerChoice, InputChoice
+from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
 from nni.nas.pytorch.mutator import Mutator
 
 logger = logging.getLogger(__name__)
@@ -104,10 +104,11 @@ class ClassicMutator(Mutator):
         search_space_item : list
             The list for corresponding search space.
         """
+        candidate_repr = search_space_item["candidates"]
         multihot_list = [False] * mutable.n_candidates
         for i, v in zip(idx, value):
-            assert 0 <= i < mutable.n_candidates and search_space_item[i] == v, \
-                "Index '{}' in search space '{}' is not '{}'".format(i, search_space_item, v)
+            assert 0 <= i < mutable.n_candidates and candidate_repr[i] == v, \
+                "Index '{}' in search space '{}' is not '{}'".format(i, candidate_repr, v)
             assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
             multihot_list[i] = True
         return torch.tensor(multihot_list, dtype=torch.bool)  # pylint: disable=not-callable
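Note: after this hunk, `search_space_item` is the `_value` dict of the search-space entry, so the candidate list is read from its `"candidates"` key before the index/value assertions run. A minimal standalone sketch of the multi-hot encoding this hunk implements (the helper name and sample values below are illustrative, not part of the patch):

    import torch

    def encode_multihot(n_candidates, idx, value, search_space_item):
        # Candidates now live under the "candidates" key of the "_value" dict.
        candidate_repr = search_space_item["candidates"]
        multihot = [False] * n_candidates
        for i, v in zip(idx, value):
            # Each chosen index must be in range and match its candidate name.
            assert 0 <= i < n_candidates and candidate_repr[i] == v
            # The same index may be selected at most once.
            assert not multihot[i]
            multihot[i] = True
        return torch.tensor(multihot, dtype=torch.bool)

    # e.g. encode_multihot(3, [1], ["in_1"],
    #                      {"candidates": ["in_0", "in_1", "in_2"], "n_chosen": 1})
    # -> tensor([False,  True, False])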
@@ -121,17 +122,20 @@ class ClassicMutator(Mutator):
                             self._chosen_arch.keys())
         result = dict()
         for mutable in self.mutables:
-            assert mutable.key in self._chosen_arch, "Expected '{}' in chosen arch, but not found.".format(mutable.key)
-            data = self._chosen_arch[mutable.key]
-            assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
-                "'{}' is not a valid choice.".format(data)
-            value = data["_value"]
-            idx = data["_idx"]
-            search_space_item = self._search_space[mutable.key]["_value"]
+            if isinstance(mutable, (LayerChoice, InputChoice)):
+                assert mutable.key in self._chosen_arch, \
+                    "Expected '{}' in chosen arch, but not found.".format(mutable.key)
+                data = self._chosen_arch[mutable.key]
+                assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
+                    "'{}' is not a valid choice.".format(data)
             if isinstance(mutable, LayerChoice):
-                result[mutable.key] = self._sample_layer_choice(mutable, idx, value, search_space_item)
+                result[mutable.key] = self._sample_layer_choice(mutable, data["_idx"], data["_value"],
+                                                                self._search_space[mutable.key]["_value"])
             elif isinstance(mutable, InputChoice):
-                result[mutable.key] = self._sample_input_choice(mutable, idx, value, search_space_item)
+                result[mutable.key] = self._sample_input_choice(mutable, data["_idx"], data["_value"],
+                                                                self._search_space[mutable.key]["_value"])
+            elif isinstance(mutable, MutableScope):
+                logger.info("Mutable scope '%s' is skipped during parsing choices.", mutable.key)
             else:
                 raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
         return result
@@ -190,6 +194,8 @@ class ClassicMutator(Mutator):
                 search_space[key] = {"_type": INPUT_CHOICE,
                                      "_value": {"candidates": mutable.choose_from,
                                                 "n_chosen": mutable.n_chosen}}
+            elif isinstance(mutable, MutableScope):
+                logger.info("Mutable scope '%s' is skipped during generating search space.", mutable.key)
             else:
                 raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
         return search_space
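Note: with `MutableScope` handled explicitly in both code paths, only `LayerChoice` and `InputChoice` contribute search-space entries, and scopes are logged and skipped instead of hitting the `TypeError` branch. For reference, a hedged sketch of the JSON shapes exchanged with the tuner, inferred from the hunks above (the key name and the literal `"input_choice"` type string are assumptions):

    # Search-space entry generated for an InputChoice (per the INPUT_CHOICE hunk above):
    search_space["some_mutable_key"] = {
        "_type": "input_choice",  # assumed string value of the INPUT_CHOICE constant
        "_value": {"candidates": ["in_0", "in_1", "in_2"], "n_chosen": 1},
    }

    # Matching entry in the chosen architecture returned by the tuner, which must
    # carry both "_idx" and "_value" (per the assertion in the hunk above):
    chosen_arch["some_mutable_key"] = {"_idx": [1], "_value": ["in_1"]}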
...
@@ -41,18 +41,18 @@ class FixedArchitecture(Mutator):
         return self._fixed_arc
 
 
-def _encode_tensor(data, device):
+def _encode_tensor(data):
     if isinstance(data, list):
         if all(map(lambda o: isinstance(o, bool), data)):
-            return torch.tensor(data, dtype=torch.bool, device=device)  # pylint: disable=not-callable
+            return torch.tensor(data, dtype=torch.bool)  # pylint: disable=not-callable
         else:
-            return torch.tensor(data, dtype=torch.float, device=device)  # pylint: disable=not-callable
+            return torch.tensor(data, dtype=torch.float)  # pylint: disable=not-callable
     if isinstance(data, dict):
-        return {k: _encode_tensor(v, device) for k, v in data.items()}
+        return {k: _encode_tensor(v) for k, v in data.items()}
     return data
 
 
-def apply_fixed_architecture(model, fixed_arc_path, device=None):
+def apply_fixed_architecture(model, fixed_arc_path):
     """
     Load architecture from `fixed_arc_path` and apply to model.
 
@@ -62,21 +62,16 @@ def apply_fixed_architecture(model, fixed_arc_path, device=None):
         Model with mutables.
     fixed_arc_path : str
         Path to the JSON that stores the architecture.
-    device : torch.device
-        Architecture weights will be transfered to `device`.
 
     Returns
     -------
     FixedArchitecture
     """
-    if device is None:
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     if isinstance(fixed_arc_path, str):
         with open(fixed_arc_path, "r") as f:
             fixed_arc = json.load(f)
-    fixed_arc = _encode_tensor(fixed_arc, device)
+    fixed_arc = _encode_tensor(fixed_arc)
     architecture = FixedArchitecture(model, fixed_arc)
-    architecture.to(device)
     architecture.reset()
     return architecture
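Note: since `_encode_tensor` no longer takes a `device` and the `architecture.to(device)` call is gone, the decoded architecture tensors stay on CPU and any device placement is left to the caller. A call-site sketch under the new signature (the model class and JSON path below are placeholders):

    from nni.nas.pytorch.fixed import apply_fixed_architecture

    model = MyModelWithMutables()  # placeholder: any module containing mutables
    apply_fixed_architecture(model, "checkpoints/arch.json")  # hypothetical path
    # If the architecture must live on GPU, the caller moves the model
    # afterwards, e.g. model.cuda().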