Unverified commit 867871b2 authored by Yuge Zhang, committed by GitHub

Promote Retiarii to NAS (step 1) - move files (#5020)

parent 481aa292
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from collections import OrderedDict

import numpy as np
import torch

_counter = 0

_logger = logging.getLogger(__name__)


def global_mutable_counting():
    """
    A program-level counter starting from 1.
    """
    global _counter
    _counter += 1
    return _counter


def _reset_global_mutable_counting():
    """
    Reset the global mutable counting to count from 1. Useful when defining multiple models with default keys.
    """
    global _counter
    _counter = 0
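
# Illustrative sketch (not part of the original file): default mutable keys are
# built from this counter, so anonymous mutables defined in sequence receive
# distinct keys. The 'LayerChoice' prefix below is just an example.
def _demo_default_keys():
    _reset_global_mutable_counting()
    first = 'LayerChoice_{}'.format(global_mutable_counting())   # 'LayerChoice_1'
    second = 'LayerChoice_{}'.format(global_mutable_counting())  # 'LayerChoice_2'
    return first, second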

def to_device(obj, device):
    """
    Move a tensor, tuple, list, or dict onto device.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, tuple):
        return tuple(to_device(t, device) for t in obj)
    if isinstance(obj, list):
        return [to_device(t, device) for t in obj]
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (int, float, str)):
        return obj
    raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))

def to_list(arr):
    if torch.is_tensor(arr):
        return arr.cpu().numpy().tolist()
    if isinstance(arr, np.ndarray):
        return arr.tolist()
    if isinstance(arr, (list, tuple)):
        return list(arr)
    return arr

class AverageMeterGroup:
    """
    Average meter group for multiple average meters.
    """

    def __init__(self):
        self.meters = OrderedDict()

    def update(self, data):
        """
        Update the meter group with a dict of metrics.
        Average meters that do not exist yet are created automatically.
        """
        for k, v in data.items():
            if k not in self.meters:
                self.meters[k] = AverageMeter(k, ":4f")
            self.meters[k].update(v)

    def __getattr__(self, item):
        return self.meters[item]

    def __getitem__(self, item):
        return self.meters[item]

    def __str__(self):
        return " ".join(str(v) for v in self.meters.values())

    def summary(self):
        """
        Return a summary string of group data.
        """
        return " ".join(v.summary() for v in self.meters.values())

class AverageMeter:
    """
    Computes and stores the average and current value.

    Parameters
    ----------
    name : str
        Name to display.
    fmt : str
        Format string to print the values.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """
        Reset the meter.
        """
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Update with value and weight.

        Parameters
        ----------
        val : float or int
            The new value to be accounted for.
        n : int
            The weight of the new value.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        fmtstr = '{name}: {avg' + self.fmt + '}'
        return fmtstr.format(**self.__dict__)
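
# Illustrative sketch: `n` weights each update, which is how per-batch means
# can be combined into an epoch mean when batch sizes differ.
def _demo_weighted_average():
    meter = AverageMeter('loss', ':.4f')
    meter.update(1.0, n=32)  # mean loss over a 32-sample batch
    meter.update(0.0, n=8)   # mean loss over an 8-sample batch
    return meter.avg         # (1.0 * 32 + 0.0 * 8) / 40 = 0.8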

class StructuredMutableTreeNode:
    """
    A structured representation of a search space.

    A search space comes with a root (with `None` stored in its `mutable`), and a bunch of children in its `children`.
    This tree can be seen as a "flattened" version of the module tree. Since nested mutable entities are not supported yet,
    the following must be true: each subtree corresponds to a ``MutableScope`` and each leaf corresponds to a
    ``Mutable`` (other than ``MutableScope``).

    Parameters
    ----------
    mutable : nni.nas.pytorch.mutables.Mutable
        The mutable that current node is linked with.
    """

    def __init__(self, mutable):
        self.mutable = mutable
        self.children = []

    def add_child(self, mutable):
        """
        Add a tree node to the children list of current node.
        """
        self.children.append(StructuredMutableTreeNode(mutable))
        return self.children[-1]

    def type(self):
        """
        Return the ``type`` of mutable content.
        """
        return type(self.mutable)

    def __iter__(self):
        return self.traverse()

    def traverse(self, order="pre", deduplicate=True, memo=None):
        """
        Return a generator over the mutables in this tree.

        Parameters
        ----------
        order : str
            pre or post. If pre, the current mutable is yielded before its children. Otherwise after.
        deduplicate : bool
            If true, mutables with the same key will not appear after their first appearance.
        memo : set
            An auxiliary set that memorizes keys seen before, so that deduplication is possible.

        Returns
        -------
        generator of Mutable
        """
        if memo is None:
            memo = set()
        assert order in ["pre", "post"]
        if order == "pre":
            if self.mutable is not None:
                if not deduplicate or self.mutable.key not in memo:
                    memo.add(self.mutable.key)
                    yield self.mutable
        for child in self.children:
            for m in child.traverse(order=order, deduplicate=deduplicate, memo=memo):
                yield m
        if order == "post":
            if self.mutable is not None:
                if not deduplicate or self.mutable.key not in memo:
                    memo.add(self.mutable.key)
                    yield self.mutable
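
# Illustrative sketch with a stub in place of a real Mutable: pre-order
# traversal yields a node before its children, and `deduplicate` suppresses
# repeated keys after their first appearance.
def _demo_traverse():
    class _Stub:
        def __init__(self, key):
            self.key = key

    root = StructuredMutableTreeNode(None)
    scope = root.add_child(_Stub('scope'))
    scope.add_child(_Stub('leaf'))
    scope.add_child(_Stub('leaf'))  # duplicate key, skipped during traversal
    return [m.key for m in root.traverse(order='pre')]  # ['scope', 'leaf']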
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from tensorflow.keras import Model

from .mutables import Mutable, MutableScope, InputChoice
from .utils import StructuredMutableTreeNode


class BaseMutator(Model):
    def __init__(self, model):
        super().__init__()
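        # Store via __dict__ on purpose: a plain attribute assignment would make
        # Keras track `model` as a sub-layer and pull all of its weights into
        # the mutator (see __setattr__ below).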
        self.__dict__['model'] = model
        self._structured_mutables = self._parse_search_space(self.model)

    def _parse_search_space(self, module, root=None, prefix='', memo=None, nested_detection=None):
        if memo is None:
            memo = set()
        if root is None:
            root = StructuredMutableTreeNode(None)
        if module not in memo:
            memo.add(module)
            if isinstance(module, Mutable):
                if nested_detection is not None:
                    raise RuntimeError('Cannot have nested search space. Error at {} in {}'
                                       .format(module, nested_detection))
                module.name = prefix
                module.set_mutator(self)
                root = root.add_child(module)
                if not isinstance(module, MutableScope):
                    nested_detection = module
                if isinstance(module, InputChoice):
                    for k in module.choose_from:
                        if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
                            raise RuntimeError('"{}" required by "{}" not found in keys that appeared before, and is not NO_KEY.'
                                               .format(k, module.key))
            for submodule in module.layers:
                if not isinstance(submodule, Model):
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + submodule.name
                self._parse_search_space(submodule, root, submodule_prefix, memo=memo, nested_detection=nested_detection)
        return root

    @property
    def mutables(self):
        return self._structured_mutables

    def undedup_mutables(self):
        return self._structured_mutables.traverse(deduplicate=False)

    def call(self, *inputs):
        raise RuntimeError('Call is undefined for mutators.')

    def __setattr__(self, name, value):
        if name == 'model':
            raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to "
                                 "include your network, as it will include all parameters in model into the mutator.")
        return super().__setattr__(name, value)

    def enter_mutable_scope(self, mutable_scope):
        pass

    def exit_mutable_scope(self, mutable_scope):
        pass

    def on_forward_layer_choice(self, mutable, *inputs):
        raise NotImplementedError

    def on_forward_input_choice(self, mutable, tensor_list):
        raise NotImplementedError

    def export(self):
        raise NotImplementedError
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from collections import OrderedDict

from tensorflow.keras import Model

from .utils import global_mutable_counting

_logger = logging.getLogger(__name__)


class Mutable(Model):
    def __init__(self, key=None):
        super().__init__()
        if key is None:
            self._key = '{}_{}'.format(type(self).__name__, global_mutable_counting())
        elif isinstance(key, str):
            self._key = key
        else:
            self._key = str(key)
            _logger.warning('Key "%s" is not a string, converted to string.', key)
        self.init_hook = None
        self.forward_hook = None

    def __deepcopy__(self, memodict=None):
        raise NotImplementedError("Deep copy doesn't work for mutables.")

    def set_mutator(self, mutator):
        if hasattr(self, 'mutator'):
            raise RuntimeError('`set_mutator` is called more than once. '
                               'Did you parse the search space multiple times? '
                               'Or did you apply multiple fixed architectures?')
        self.mutator = mutator

    def call(self, *inputs):
        raise NotImplementedError('Method `call` of Mutable must be overridden')

    def build(self, input_shape):
        self._check_built()

    @property
    def key(self):
        return self._key

    @property
    def name(self):
        return self._name if hasattr(self, '_name') else self._key

    @name.setter
    def name(self, name):
        self._name = name

    def _check_built(self):
        if not hasattr(self, 'mutator'):
            raise ValueError(
                "Mutator not set for {}. You might have forgotten to initialize and apply your mutator. "
                "Or did you initialize a mutable on the fly in the forward pass? Move it to `__init__` "
                "so that the trainer can locate all your mutables. See NNI docs for more details.".format(self))

    def __repr__(self):
        return '{} ({})'.format(self.name, self.key)


class MutableScope(Mutable):
    def __call__(self, *args, **kwargs):
        try:
            self.mutator.enter_mutable_scope(self)
            return super().__call__(*args, **kwargs)
        finally:
            self.mutator.exit_mutable_scope(self)


class LayerChoice(Mutable):
    def __init__(self, op_candidates, reduction='sum', return_mask=False, key=None):
        super().__init__(key=key)
        self.names = []
        if isinstance(op_candidates, OrderedDict):
            for name in op_candidates:
                assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
                    "Please don't use a reserved name '{}' for your module.".format(name)
                self.names.append(name)
        elif isinstance(op_candidates, list):
            for i, _ in enumerate(op_candidates):
                self.names.append(str(i))
        else:
            raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates)))
        self.length = len(op_candidates)
        # Keep candidates as a list of ops so that iteration (in `build` and in
        # the mutator) yields modules rather than dict keys.
        self.choices = list(op_candidates.values()) if isinstance(op_candidates, OrderedDict) else op_candidates
        self.reduction = reduction
        self.return_mask = return_mask

    def call(self, *inputs):
        out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
        if self.return_mask:
            return out, mask
        return out

    def build(self, input_shape):
        self._check_built()
        for op in self.choices:
            op.build(input_shape)

    def __len__(self):
        return len(self.choices)
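
# Illustrative sketch (hypothetical ops): candidates can be a list or an
# OrderedDict; with a list, names default to '0', '1', ... A mutator must be
# applied before the choice can actually be called.
def _demo_layer_choice():
    from tensorflow.keras.layers import Conv2D
    return LayerChoice(OrderedDict([
        ('conv3x3', Conv2D(16, 3, padding='same')),
        ('conv5x5', Conv2D(16, 5, padding='same')),
    ]), reduction='sum')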

class InputChoice(Mutable):
    NO_KEY = ''

    def __init__(self, n_candidates=None, choose_from=None, n_chosen=None, reduction='sum', return_mask=False, key=None):
        super().__init__(key=key)
        assert n_candidates is not None or choose_from is not None, \
            'At least one of `n_candidates` and `choose_from` must not be None.'
        if choose_from is not None and n_candidates is None:
            n_candidates = len(choose_from)
        elif choose_from is None and n_candidates is not None:
            choose_from = [self.NO_KEY] * n_candidates
        assert n_candidates == len(choose_from), 'Number of candidates must be equal to the length of `choose_from`.'
        assert n_candidates > 0, 'Number of candidates must be greater than 0.'
        assert n_chosen is None or 0 <= n_chosen <= n_candidates, \
            '`n_chosen` must be None, or between 0 and the number of candidates.'
        self.n_candidates = n_candidates
        self.choose_from = choose_from.copy()
        self.n_chosen = n_chosen
        self.reduction = reduction
        self.return_mask = return_mask

    def call(self, optional_inputs):
        optional_input_list = optional_inputs
        if isinstance(optional_inputs, dict):
            optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
        assert isinstance(optional_input_list, list), \
            'Optional input list must be a list, not a {}.'.format(type(optional_input_list))
        assert len(optional_input_list) == self.n_candidates, \
            'Length of the input list must be equal to the number of candidates.'
        out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
        if self.return_mask:
            return out, mask
        return out
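
# Illustrative sketch: keys in `choose_from` refer to earlier mutables, and a
# dict input is unpacked in `choose_from` order. The keys are hypothetical,
# and a mutator must already be attached for the call to succeed.
def _demo_input_choice(tensor_a, tensor_b):
    skip = InputChoice(choose_from=['cell_0', 'cell_1'], n_chosen=1, reduction='sum')
    return skip({'cell_0': tensor_a, 'cell_1': tensor_b})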
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

import tensorflow as tf

from .base_mutator import BaseMutator

_logger = logging.getLogger(__name__)


class Mutator(BaseMutator):
    def __init__(self, model):
        super().__init__(model)
        self._cache = {}

    def sample_search(self):
        raise NotImplementedError('Method `sample_search` must be overridden')

    def sample_final(self):
        raise NotImplementedError('Method `sample_final` must be overridden for exporting')

    def reset(self):
        self._cache = self.sample_search()

    def export(self):
        return self.sample_final()

    # TODO: status
    # TODO: graph

    def on_forward_layer_choice(self, mutable, *inputs):
        mask = self._get_decision(mutable)
        assert len(mask) == len(mutable), \
            'Invalid mask, expected {} to be of length {}.'.format(mask, len(mutable))
        out = self._select_with_mask(lambda choice: choice(*inputs), mutable.choices, mask)
        return self._tensor_reduction(mutable.reduction, out), mask

    def on_forward_input_choice(self, mutable, tensor_list):
        mask = self._get_decision(mutable)
        assert len(mask) == mutable.n_candidates, \
            'Invalid mask, expected {} to be of length {}.'.format(mask, mutable.n_candidates)
        out = self._select_with_mask(lambda tensor: tensor, tensor_list, mask)
        return self._tensor_reduction(mutable.reduction, out), mask

    def _select_with_mask(self, map_fn, candidates, mask):
        if mask.dtype.is_bool:
            out = [map_fn(cand) for cand, m in zip(candidates, mask) if m]
        elif mask.dtype.is_floating:
            out = [map_fn(cand) * m for cand, m in zip(candidates, mask) if m]
        else:
            raise ValueError('Unrecognized mask, dtype is {}'.format(mask.dtype.name))
        return out
    def _tensor_reduction(self, reduction_type, tensor_list):
        if reduction_type == 'none':
            return tensor_list
        if not tensor_list:
            return None
        if len(tensor_list) == 1:
            return tensor_list[0]
        if reduction_type == 'sum':
            return sum(tensor_list)
        if reduction_type == 'mean':
            return sum(tensor_list) / len(tensor_list)
        if reduction_type == 'concat':
            image_data_format = tf.keras.backend.image_data_format()
            if image_data_format == 'channels_first':
                axis = 1  # channel axis of a batched NCHW tensor
            else:
                axis = -1
            # pylint issue #3613
            return tf.concat(tensor_list, axis=axis)  # pylint: disable=E1120,E1123
        raise ValueError('Unrecognized reduction policy: "{}"'.format(reduction_type))

    def _get_decision(self, mutable):
        if mutable.key not in self._cache:
            raise ValueError('"{}" not found in decision cache.'.format(mutable.key))
        result = self._cache[mutable.key]
        _logger.debug('Decision %s: %s', mutable.key, result)
        return result
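
# Illustrative sketch of a concrete mutator (not part of the original file):
# sample a uniform one-hot boolean mask for every mutable, roughly what a
# random-search mutator would do in eager mode.
from .mutables import InputChoice, LayerChoice  # assumed available in this package


class _DemoRandomMutator(Mutator):
    def sample_search(self):
        result = {}
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                n = mutable.length
            elif isinstance(mutable, InputChoice):
                n = mutable.n_candidates
            else:
                continue
            index = tf.random.uniform([], maxval=n, dtype=tf.int32)
            result[mutable.key] = tf.one_hot(index, n) > 0.5  # boolean one-hot mask
        return result

    def sample_final(self):
        return self.sample_search()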
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import tensorflow as tf

_counter = 0


def global_mutable_counting():
    global _counter
    _counter += 1
    return _counter


class AverageMeter:
    def __init__(self, name):
        self.name = name
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val):
        self.val = val
        self.sum += val
        self.count += 1
        self.avg = self.sum / self.count

    def __str__(self):
        return '{name} {val:4f} ({avg:4f})'.format(**self.__dict__)

    def summary(self):
        return '{name}: {avg:4f}'.format(**self.__dict__)


class AverageMeterGroup:
    def __init__(self):
        self.meters = {}

    def update(self, data):
        for k, v in data.items():
            if k not in self.meters:
                self.meters[k] = AverageMeter(k)
            self.meters[k].update(v)

    def __str__(self):
        return ' '.join(str(v) for v in self.meters.values())

    def summary(self):
        return ' '.join(v.summary() for v in self.meters.values())


class StructuredMutableTreeNode:
    def __init__(self, mutable):
        self.mutable = mutable
        self.children = []

    def add_child(self, mutable):
        self.children.append(StructuredMutableTreeNode(mutable))
        return self.children[-1]

    def type(self):
        return type(self.mutable)

    def __iter__(self):
        return self.traverse()

    def traverse(self, order="pre", deduplicate=True, memo=None):
        if memo is None:
            memo = set()
        assert order in ["pre", "post"]
        if order == "pre":
            if self.mutable is not None:
                if not deduplicate or self.mutable.key not in memo:
                    memo.add(self.mutable.key)
                    yield self.mutable
        for child in self.children:
            for m in child.traverse(order=order, deduplicate=deduplicate, memo=memo):
                yield m
        if order == "post":
            if self.mutable is not None:
                if not deduplicate or self.mutable.key not in memo:
                    memo.add(self.mutable.key)
                    yield self.mutable


def fill_zero_grads(grads, weights):
    ret = []
    for grad, weight in zip(grads, weights):
        if grad is not None:
            ret.append(grad)
        else:
            ret.append(tf.zeros_like(weight))
    return ret
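
# Illustrative sketch: gradients returned by `tape.gradient` are `None` for
# weights that did not participate in the forward pass (e.g. unselected
# candidate ops); zero-filling them keeps `apply_gradients` happy.
def _demo_apply_gradients(model, optimizer, loss_fn, x, y):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x))
    grads = tape.gradient(loss, model.trainable_weights)
    grads = fill_zero_grads(grads, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss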
@@ -18,148 +18,6 @@ from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice

_logger = logging.getLogger(__name__)

class StackedLSTMCell(nn.Module):
    def __init__(self, layers, size, bias):
        super().__init__()
        self.lstm_num_layers = layers
        self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
                                           for _ in range(self.lstm_num_layers)])

    def forward(self, inputs, hidden):
        prev_h, prev_c = hidden
        next_h, next_c = [], []
        for i, m in enumerate(self.lstm_modules):
            curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
            next_c.append(curr_c)
            next_h.append(curr_h)
            # the current implementation only supports a batch size of 1,
            # but the algorithm does not necessarily have this limitation
            inputs = curr_h[-1].view(1, -1)
        return next_h, next_c
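
# Illustrative sketch: hidden state is carried as per-layer lists of
# (1, hidden_size) tensors, matching the batch-size-1 note above.
def _demo_stacked_lstm_shapes():
    cell = StackedLSTMCell(layers=2, size=64, bias=False)
    h = [torch.zeros(1, 64) for _ in range(2)]
    c = [torch.zeros(1, 64) for _ in range(2)]
    h, c = cell(torch.zeros(1, 64), (h, c))
    return h[-1].shape  # torch.Size([1, 64])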

class ReinforceField:
    """
    A field with ``name``, with ``total`` choices. ``choose_one`` is true if one and only one is meant to be
    selected. Otherwise, any number of choices can be chosen.
    """

    def __init__(self, name, total, choose_one):
        self.name = name
        self.total = total
        self.choose_one = choose_one

    def __repr__(self):
        return f'ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})'

class ReinforceController(nn.Module):
    """
    A controller that mutates the graph with RL.

    Parameters
    ----------
    fields : list of ReinforceField
        List of fields to choose.
    lstm_size : int
        Controller LSTM hidden units.
    lstm_num_layers : int
        Number of layers for stacked LSTM.
    tanh_constant : float
        Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
    skip_target : float
        Target probability that skip-connect (chosen by InputChoice) will appear.
        If the sampled proportion of chosen inputs is far from ``skip_target``, a sample skip penalty,
        in the form of a KL-divergence term, is added.
    temperature : float
        Temperature constant that divides the logits.
    entropy_reduction : str
        Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
    """

    def __init__(self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
                 skip_target=0.4, temperature=None, entropy_reduction='sum'):
        super(ReinforceController, self).__init__()
        self.fields = fields
        self.lstm_size = lstm_size
        self.lstm_num_layers = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature
        self.skip_target = skip_target

        self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
        self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
        self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),  # pylint: disable=not-callable
                                         requires_grad=False)
        assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
        self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
        self.soft = nn.ModuleDict({
            field.name: nn.Linear(self.lstm_size, field.total, bias=False) for field in fields
        })
        self.embedding = nn.ModuleDict({
            field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
        })

    def resample(self):
        self._initialize()
        result = dict()
        for field in self.fields:
            result[field.name] = self._sample_single(field)
        return result

    def _initialize(self):
        self._inputs = self.g_emb.data
        self._c = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self._h = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self.sample_log_prob: torch.Tensor = cast(torch.Tensor, 0)
        self.sample_entropy: torch.Tensor = cast(torch.Tensor, 0)
        self.sample_skip_penalty: torch.Tensor = cast(torch.Tensor, 0)

    def _lstm_next_step(self):
        self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

    def _sample_single(self, field):
        self._lstm_next_step()
        logit = self.soft[field.name](self._h[-1])
        if self.temperature is not None:
            logit /= self.temperature
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        if field.choose_one:
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            log_prob = self.cross_entropy_loss(logit, sampled)
            self._inputs = self.embedding[field.name](sampled)
        else:
            logit = logit.view(-1, 1)
            logit = torch.cat([-logit, logit], 1)  # pylint: disable=invalid-unary-operand-type
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip_prob = torch.sigmoid(logit)
            kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(logit, sampled)
            sampled = sampled.nonzero().view(-1)
            if sampled.sum().item():
                self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
            else:
                self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)  # type: ignore

        sampled = sampled.detach().cpu().numpy().tolist()
        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)
        if len(sampled) == 1:
            sampled = sampled[0]
        return sampled
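
# Illustrative sketch: one `resample()` pass steps the LSTM once per field and
# returns a dict from field name to the sampled decision. The field specs are
# hypothetical; `sample_log_prob` accumulates across fields for the REINFORCE
# update.
def _demo_resample():
    fields = [ReinforceField('layer_1', 4, True),   # pick exactly one of 4 ops
              ReinforceField('skip_2', 2, False)]   # any subset of 2 inputs
    controller = ReinforceController(fields)
    decision = controller.resample()  # e.g. {'layer_1': 3, 'skip_2': [0, 1]}
    return decision, controller.sample_log_prob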

class EnasTrainer(BaseOneShotTrainer):
    """
    ENAS trainer.
...