Unverified commit bf7daa8f authored by Yuge Zhang, committed by GitHub

Prettify the export format of NAS trainer (#2389)

parent af800213
@@ -156,12 +156,23 @@ model = Net()
 apply_fixed_architecture(model, "model_dir/final_architecture.json")
 ```
-The JSON is simply a mapping from mutable keys to one-hot or multi-hot representation of choices. For example
+The JSON is simply a mapping from mutable keys to choices. A choice can be expressed as:
+
+* A string: select the candidate with the corresponding name.
+* A number: select the candidate with the corresponding index.
+* A list of strings: select the candidates with the corresponding names.
+* A list of numbers: select the candidates with the corresponding indices.
+* A list of boolean values: a multi-hot array.
+
+For example,
 ```json
 {
-  "LayerChoice1": [false, true, false, false],
-  "InputChoice2": [true, true, false]
+  "LayerChoice1": "conv5x5",
+  "LayerChoice2": 6,
+  "InputChoice3": ["layer1", "layer3"],
+  "InputChoice4": [1, 2],
+  "InputChoice5": [false, true, false, false, true]
 }
 ```
...
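Concretely, every accepted form normalizes to a multi-hot boolean array over the candidates. Below is a minimal standalone sketch of that normalization; the `to_multihot` helper and the `ops` list are hypothetical and for illustration only, since NNI's `FixedArchitecture` performs this conversion internally.

```python
def to_multihot(choice, candidates):
    """Normalize an exported choice into a multi-hot boolean array."""
    if isinstance(choice, list) and choice and all(isinstance(c, bool) for c in choice):
        return choice  # already a multi-hot array
    if not isinstance(choice, list):
        choice = [choice]  # a single string/number becomes a one-element list
    # map candidate names to indices; numbers are kept as indices
    indices = [candidates.index(c) if isinstance(c, str) else c for c in choice]
    return [i in indices for i in range(len(candidates))]

ops = ["conv3x3", "conv5x5", "maxpool", "avgpool"]
print(to_multihot("conv5x5", ops))                    # [False, True, False, False]
print(to_multihot(2, ops))                            # [False, False, True, False]
print(to_multihot(["conv3x3", "maxpool"], ops))       # [True, False, True, False]
print(to_multihot([1, 3], ops))                       # [False, True, False, True]
print(to_multihot([True, False, False, True], ops))   # [True, False, False, True]
```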
@@ -3,10 +3,9 @@
 import json
 
-import torch
-
-from nni.nas.pytorch.mutables import MutableScope
-from nni.nas.pytorch.mutator import Mutator
+from .mutables import InputChoice, LayerChoice, MutableScope
+from .mutator import Mutator
+from .utils import to_list
 
 
 class FixedArchitecture(Mutator):
@@ -17,8 +16,8 @@ class FixedArchitecture(Mutator):
     ----------
     model : nn.Module
         A mutable network.
-    fixed_arc : str or dict
-        Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
+    fixed_arc : dict
+        Preloaded architecture object.
     strict : bool
         Force everything that appears in ``fixed_arc`` to be used at least once.
     """
@@ -33,6 +32,34 @@
             raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
         if mutable_keys - fixed_arc_keys:
             raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
+        self._fixed_arc = self._from_human_readable_architecture(self._fixed_arc)
+
+    def _from_human_readable_architecture(self, human_arc):
+        # convert from an exported architecture
+        result_arc = {k: to_list(v) for k, v in human_arc.items()}  # there could be tensors, numpy arrays, etc.
+        # First, convert non-lists to lists, because there could be {"op1": 0} or {"op1": "conv"},
+        # which mean {"op1": [0, ]} or {"op1": ["conv", ]} respectively.
+        result_arc = {k: v if isinstance(v, list) else [v] for k, v in result_arc.items()}
+        # Second, infer which entries are multi-hot arrays and which are in human-readable format.
+        # This is non-trivial: if an array is [0, 1], we cannot know for sure whether it means [false, true] or indices 0 and 1.
+        # Here, we assume a multi-hot array has to be a boolean or float array whose length matches the mutable.
+        for mutable in self.mutables:
+            if mutable.key not in result_arc:
+                continue  # skip silently
+            choice_arr = result_arc[mutable.key]
+            if all(isinstance(v, bool) for v in choice_arr) or all(isinstance(v, float) for v in choice_arr):
+                if (isinstance(mutable, LayerChoice) and len(mutable) == len(choice_arr)) or \
+                        (isinstance(mutable, InputChoice) and mutable.n_candidates == len(choice_arr)):
+                    # already multi-hot, do nothing
+                    continue
+            if isinstance(mutable, LayerChoice):
+                choice_arr = [mutable.names.index(val) if isinstance(val, str) else val for val in choice_arr]
+                choice_arr = [i in choice_arr for i in range(len(mutable))]
+            elif isinstance(mutable, InputChoice):
+                choice_arr = [mutable.choose_from.index(val) if isinstance(val, str) else val for val in choice_arr]
+                choice_arr = [i in choice_arr for i in range(mutable.n_candidates)]
+            result_arc[mutable.key] = choice_arr
+        return result_arc
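The inference rule described in the comments above can be pictured with a small standalone check; the `is_multihot` helper and the example arrays are hypothetical, not part of NNI. A boolean or float list is kept as a multi-hot array only when its length matches the number of candidates; everything else is treated as names or indices.

```python
def is_multihot(choice_arr, n_candidates):
    # keep as multi-hot only if it is a bool/float array of matching length
    typed = all(isinstance(v, bool) for v in choice_arr) or \
        all(isinstance(v, float) for v in choice_arr)
    return typed and len(choice_arr) == n_candidates

print(is_multihot([0.0, 1.0, 0.0, 0.0], 4))  # True  -> used as-is
print(is_multihot([0, 1], 4))                # False -> interpreted as indices 0 and 1
print(is_multihot([True, False], 4))         # False -> length mismatch, also treated as indices
```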
     def sample_search(self):
         """
@@ -47,17 +74,6 @@
         return self._fixed_arc
 
 
-def _encode_tensor(data):
-    if isinstance(data, list):
-        if all(map(lambda o: isinstance(o, bool), data)):
-            return torch.tensor(data, dtype=torch.bool)  # pylint: disable=not-callable
-        else:
-            return torch.tensor(data, dtype=torch.float)  # pylint: disable=not-callable
-    if isinstance(data, dict):
-        return {k: _encode_tensor(v) for k, v in data.items()}
-    return data
-
-
 def apply_fixed_architecture(model, fixed_arc):
     """
     Load architecture from `fixed_arc` and apply to model.
@@ -78,7 +94,6 @@ def apply_fixed_architecture(model, fixed_arc):
     if isinstance(fixed_arc, str):
         with open(fixed_arc) as f:
             fixed_arc = json.load(f)
-    fixed_arc = _encode_tensor(fixed_arc)
     architecture = FixedArchitecture(model, fixed_arc)
     architecture.reset()
     return architecture
@@ -7,7 +7,9 @@ from collections import defaultdict
 import numpy as np
 import torch
 
-from nni.nas.pytorch.base_mutator import BaseMutator
+from .base_mutator import BaseMutator
+from .mutables import LayerChoice, InputChoice
+from .utils import to_list
 
 logger = logging.getLogger(__name__)
@@ -58,7 +60,16 @@
         dict
             A mapping from key of mutables to decisions.
         """
-        return self.sample_final()
+        sampled = self.sample_final()
+        result = dict()
+        for mutable in self.mutables:
+            if not isinstance(mutable, (LayerChoice, InputChoice)):
+                # not supported as built-in
+                continue
+            result[mutable.key] = self._convert_mutable_decision_to_human_readable(mutable, sampled.pop(mutable.key))
+        if sampled:
+            raise ValueError("Unexpected keys returned from 'sample_final()': %s" % list(sampled.keys()))
+        return result
 
     def status(self):
         """
@@ -159,7 +170,7 @@
         mask = self._get_decision(mutable)
         assert len(mask) == len(mutable), \
             "Invalid mask, expected {} to be of length {}.".format(mask, len(mutable))
-        out = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
+        out, mask = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
         return self._tensor_reduction(mutable.reduction, out), mask
 
     def on_forward_input_choice(self, mutable, tensor_list):
@@ -185,17 +196,41 @@
         mask = self._get_decision(mutable)
         assert len(mask) == mutable.n_candidates, \
             "Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)
-        out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
+        out, mask = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
         return self._tensor_reduction(mutable.reduction, out), mask
 
     def _select_with_mask(self, map_fn, candidates, mask):
-        if "BoolTensor" in mask.type():
+        """
+        Select masked tensors and return a list of tensors.
+
+        Parameters
+        ----------
+        map_fn : function
+            Convert candidates to target candidates. Can simply be the identity function.
+        candidates : list of torch.Tensor
+            Tensor list to apply the decision on.
+        mask : list-like object
+            Can be a list, a numpy array or a tensor (recommended). Needs to
+            have the same length as ``candidates``.
+
+        Returns
+        -------
+        tuple of list of torch.Tensor and torch.Tensor
+            Output and mask.
+        """
+        if (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], bool)) or \
+                (isinstance(mask, np.ndarray) and mask.dtype == np.bool) or \
+                "BoolTensor" in mask.type():
             out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
-        elif "FloatTensor" in mask.type():
+        elif (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], (float, int))) or \
+                (isinstance(mask, np.ndarray) and mask.dtype in (np.float32, np.float64, np.int32, np.int64)) or \
+                "FloatTensor" in mask.type():
            out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m]
        else:
-            raise ValueError("Unrecognized mask")
-        return out
+            raise ValueError("Unrecognized mask '%s'" % mask)
+        if not torch.is_tensor(mask):
+            mask = torch.tensor(mask)  # pylint: disable=not-callable
+        return out, mask
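The two mask types behave differently, as the selection logic above shows: a boolean mask simply picks candidates, while a float mask scales them and drops zero-weighted entries. A standalone illustrative snippet (reproducing the selection outside the mutator):

```python
import torch

candidates = [torch.full((2,), float(i)) for i in range(3)]  # [0., 0.], [1., 1.], [2., 2.]

bool_mask = torch.tensor([False, True, True])
picked = [c for c, m in zip(candidates, bool_mask) if m]
# -> [tensor([1., 1.]), tensor([2., 2.])]

float_mask = torch.tensor([0.0, 0.5, 1.0])
weighted = [c * m for c, m in zip(candidates, float_mask) if m]
# -> [tensor([0.5000, 0.5000]), tensor([2., 2.])]; the zero-weight candidate is dropped
```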
     def _tensor_reduction(self, reduction_type, tensor_list):
         if reduction_type == "none":
@@ -237,3 +272,37 @@
         result = self._cache[mutable.key]
         logger.debug("Decision %s: %s", mutable.key, result)
         return result
+
+    def _convert_mutable_decision_to_human_readable(self, mutable, sampled):
+        # Assert the existence of mutable.key in the returned architecture.
+        # Also check if there is anything extra.
+        multihot_list = to_list(sampled)
+        converted = None
+        # If it's a boolean array, we can optimize the representation.
+        if all([t == 0 or t == 1 for t in multihot_list]):
+            if isinstance(mutable, LayerChoice):
+                assert len(multihot_list) == len(mutable), \
+                    "Results returned from 'sample_final()' (%s: %s) are either too short or too long." \
+                    % (mutable.key, multihot_list)
+                # check if all modules have different names and they indeed have names
+                if len(set(mutable.names)) == len(mutable) and not all(d.isdigit() for d in mutable.names):
+                    converted = [name for i, name in enumerate(mutable.names) if multihot_list[i]]
+                else:
+                    converted = [i for i in range(len(multihot_list)) if multihot_list[i]]
+            if isinstance(mutable, InputChoice):
+                assert len(multihot_list) == mutable.n_candidates, \
+                    "Results returned from 'sample_final()' (%s: %s) are either too short or too long." \
+                    % (mutable.key, multihot_list)
+                # check if all input candidates have different names
+                if len(set(mutable.choose_from)) == mutable.n_candidates:
+                    converted = [name for i, name in enumerate(mutable.choose_from) if multihot_list[i]]
+                else:
+                    converted = [i for i in range(len(multihot_list)) if multihot_list[i]]
+        if converted is not None:
+            # if only one element, remove the brackets
+            if len(converted) == 1:
+                converted = converted[0]
+        else:
+            # not a boolean array, keep as-is
+            converted = multihot_list
+        return converted
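The effect of this export-side conversion can be sketched standalone; the `humanize` helper and the candidate names below are illustrative, not NNI API. Unique, non-numeric candidate names are exported as names, otherwise indices are used, and a single selection is unwrapped from its list.

```python
def humanize(multihot, names):
    # names are exported only when they are unique and not purely numeric
    unique_names = len(set(names)) == len(names) and not all(n.isdigit() for n in names)
    if unique_names:
        chosen = [n for n, bit in zip(names, multihot) if bit]
    else:
        chosen = [i for i, bit in enumerate(multihot) if bit]
    return chosen[0] if len(chosen) == 1 else chosen

print(humanize([False, True, False, False], ["conv3x3", "conv5x5", "maxpool", "avgpool"]))  # conv5x5
print(humanize([True, False, True], ["layer0", "layer1", "layer2"]))                        # ['layer0', 'layer2']
print(humanize([False, True], ["0", "1"]))                                                  # 1 (numeric names fall back to indices)
```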
@@ -4,6 +4,7 @@
 import logging
 from collections import OrderedDict
 
+import numpy as np
 import torch
 
 _counter = 0
@@ -45,6 +46,16 @@ def to_device(obj, device):
     raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))
 
 
+def to_list(arr):
+    if torch.is_tensor(arr):
+        return arr.cpu().numpy().tolist()
+    if isinstance(arr, np.ndarray):
+        return arr.tolist()
+    if isinstance(arr, (list, tuple)):
+        return list(arr)
+    return arr
+
+
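A quick check of what `to_list` normalizes; the values below are arbitrary and for illustration only:

```python
import numpy as np
import torch

from nni.nas.pytorch.utils import to_list

print(to_list(torch.tensor([0, 1, 1])))  # [0, 1, 1]
print(to_list(np.array([0.5, 1.0])))     # [0.5, 1.0]
print(to_list((True, False)))            # [True, False]
print(to_list("conv5x5"))                # conv5x5 (scalars and strings pass through unchanged)
```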
 class AverageMeterGroup:
     """
     Average meter group for multiple average meters.
...