Unverified commit 867871b2, authored by Yuge Zhang and committed by GitHub
Browse files

Promote Retiarii to NAS (step 1) - move files (#5020)

parent 481aa292
......@@ -3,21 +3,17 @@
import copy
import warnings
from collections import OrderedDict
from typing import Callable, List, Dict, Union, Tuple, Optional
from typing import Callable, List, Union, Tuple, Optional
import torch
import torch.nn as nn
from nni.retiarii.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
from nni.nas.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
from .api import LayerChoice, ValueChoice, ValueChoiceX, ChoiceOf
from .cell import Cell
from .nasbench101 import NasBench101Cell, NasBench101Mutator
from .mutation_utils import Mutable, generate_new_label, get_fixed_value
from .choice import ValueChoice, ValueChoiceX, ChoiceOf
from .mutation_utils import Mutable, get_fixed_value
__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell']
__all__ = ['Repeat']
class Repeat(Mutable):
......@@ -159,77 +155,3 @@ class Repeat(Mutable):
def __len__(self):
return self.max_depth
class NasBench201Cell(nn.Module):
    """
    Cell structure that is proposed in NAS-Bench-201.

    Proposed by `NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search <https://arxiv.org/abs/2001.00326>`__.

    The cell is a densely connected DAG with ``num_tensors`` nodes, where each node is a tensor.
    For every i < j, there is an edge from the i-th node to the j-th node.
    Each edge is associated with an operation transforming the hidden state from the source node
    to the target node. The operation on every edge is chosen from the same predefined set, ``op_candidates``.
    Each of the ``op_candidates`` should be a callable that accepts input dimension and output dimension,
    and returns a ``Module``.

    Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`.

    The space size of this cell is :math:`|op|^{N(N-1)/2}`, where :math:`|op|` is the number of operation candidates,
    and :math:`N` is defined by ``num_tensors``.

    Parameters
    ----------
    op_candidates : list of callable, or dict of str to callable
        Operation candidates. Each should be a function that accepts input feature and output feature,
        returning nn.Module. List entries are keyed by their position (as a string).
    in_features : int
        Input dimension of cell.
    out_features : int
        Output dimension of cell.
    num_tensors : int
        Number of tensors in the cell (input included). Default: 4
    label : str
        Identifier of the cell. Cell sharing the same label will semantically share the same choice.
    """

    @staticmethod
    def _make_dict(x):
        """Normalize candidates into an ``OrderedDict``; list items are keyed by their index as a string."""
        if not isinstance(x, list):
            return OrderedDict(x)
        return OrderedDict((str(index), item) for index, item in enumerate(x))

    def __init__(self, op_candidates: Union[Dict[str, Callable[[int, int], nn.Module]], List[Callable[[int, int], nn.Module]]],
                 in_features: int, out_features: int, num_tensors: int = 4,
                 label: Optional[str] = None):
        super().__init__()
        self._label = generate_new_label(label)

        self.layers = nn.ModuleList()
        self.in_features = in_features
        self.out_features = out_features
        self.num_tensors = num_tensors

        candidates = self._make_dict(op_candidates)

        # One ModuleList per non-input node; it holds one LayerChoice per incoming edge.
        for target in range(1, num_tensors):
            incoming_edges = nn.ModuleList()
            for source in range(target):
                # Only edges leaving node 0 (the cell input) see ``in_features`` channels.
                edge_in = in_features if source == 0 else out_features
                edge_ops = OrderedDict(
                    (name, factory(edge_in, out_features)) for name, factory in candidates.items()
                )
                # put __ here to be compatible with base engine
                edge_label = f'{self._label}__{source}_{target}'
                incoming_edges.append(LayerChoice(edge_ops, label=edge_label))
            self.layers.append(incoming_edges)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Compute each node as the sum of its incoming edge outputs and return the last node.

        The i-th edge operation of a node is applied to the i-th previously computed tensor
        (the cell input being tensor 0); the results are stacked and summed over the stacking axis.
        """
        tensors: List[torch.Tensor] = [inputs]
        for node_edges in self.layers:
            edge_outputs: List[torch.Tensor] = [
                op(tensors[index]) for index, op in enumerate(node_edges)  # type: ignore
            ]
            tensors.append(torch.sum(torch.stack(edge_outputs), 0))
        return tensors[-1]
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
class StackedLSTMCell(nn.Module):
def __init__(self, layers, size, bias):
......@@ -29,101 +28,80 @@ class StackedLSTMCell(nn.Module):
return next_h, next_c
class EnasMutator(Mutator):
class ReinforceField:
    """
    One decision for the controller: a field called ``name`` offering ``total`` choices.

    When ``choose_one`` is true, one and only one choice is meant to be selected.
    Otherwise, any number of choices can be chosen.
    """

    def __init__(self, name, total, choose_one):
        self.name, self.total, self.choose_one = name, total, choose_one

    def __repr__(self):
        name, total, choose_one = self.name, self.total, self.choose_one
        return f'ReinforceField(name={name}, total={total}, choose_one={choose_one})'
class ReinforceController(nn.Module):
"""
A mutator that mutates the graph with RL.
A controller that mutates the graph with RL.
Parameters
----------
model : nn.Module
PyTorch model.
fields : list of ReinforceField
List of fields to choose.
lstm_size : int
Controller LSTM hidden units.
lstm_num_layers : int
Number of layers for stacked LSTM.
tanh_constant : float
Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
cell_exit_extra_step : bool
If true, RL controller will perform an extra step at the exit of each MutableScope, dump the hidden state
and mark it as the hidden state of this MutableScope. This is to align with the original implementation of paper.
skip_target : float
Target probability that skipconnect will appear.
Target probability that skipconnect (chosen by InputChoice) will appear.
If the chosen number of inputs deviates from ``skip_target``, a sample skip penalty
(a KL divergence term) is added.
temperature : float
Temperature constant that divides the logits.
branch_bias : float
Manual bias applied to make some operations more likely to be chosen.
Currently this is implemented with a hardcoded match rule that aligns with original repo.
If a mutable has a ``reduce`` in its key, all its op choices
that contains `conv` in their typename will receive a bias of ``+self.branch_bias`` initially; while others
receive a bias of ``-self.branch_bias``.
entropy_reduction : str
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
"""
def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"):
super().__init__(model)
def __init__(self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
skip_target=0.4, temperature=None, entropy_reduction='sum'):
super(ReinforceController, self).__init__()
self.fields = fields
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
self.cell_exit_extra_step = cell_exit_extra_step
self.skip_target = skip_target
self.branch_bias = branch_bias
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) # pylint: disable=not-callable
assert entropy_reduction in ["sum", "mean"], "Entropy reduction must be one of sum and mean."
self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
self.bias_dict = nn.ParameterDict()
self.max_layer_choice = 0
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
if self.max_layer_choice == 0:
self.max_layer_choice = len(mutable)
assert self.max_layer_choice == len(mutable), \
"ENAS mutator requires all layer choice have the same number of candidates."
# We are judging by keys and module types to add biases to layer choices. Needs refactor.
if "reduce" in mutable.key:
def is_conv(choice):
return "conv" in str(type(choice)).lower()
bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias # pylint: disable=not-callable
for choice in mutable])
self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False)
self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False)
def sample_search(self):
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), # pylint: disable=not-callable
requires_grad=False)
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
self.soft = nn.ModuleDict({
field.name: nn.Linear(self.lstm_size, field.total, bias=False) for field in fields
})
self.embedding = nn.ModuleDict({
field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
})
def resample(self):
self._initialize()
self._sample(self.mutables)
return self._choices
def sample_final(self):
return self.sample_search()
def _sample(self, tree):
mutable = tree.mutable
if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_layer_choice(mutable)
elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_input_choice(mutable)
for child in tree.children:
self._sample(child)
if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
if self.cell_exit_extra_step:
self._lstm_next_step()
self._mark_anchor(mutable.key)
result = dict()
for field in self.fields:
result[field.name] = self._sample_single(field)
return result
def _initialize(self):
self._choices = dict()
self._anchors_hid = dict()
self._inputs = self.g_emb.data
self._c = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
......@@ -131,67 +109,42 @@ class EnasMutator(Mutator):
self._h = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
self.sample_log_prob: torch.Tensor = cast(torch.Tensor, 0)
self.sample_entropy: torch.Tensor = cast(torch.Tensor, 0)
self.sample_skip_penalty: torch.Tensor = cast(torch.Tensor, 0)
def _lstm_next_step(self):
self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
def _mark_anchor(self, key):
self._anchors_hid[key] = self._h[-1]
def _sample_layer_choice(self, mutable):
def _sample_single(self, field):
self._lstm_next_step()
logit = self.soft(self._h[-1])
logit = self.soft[field.name](self._h[-1])
if self.temperature is not None:
logit /= self.temperature
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
if mutable.key in self.bias_dict:
logit += self.bias_dict[mutable.key]
branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, branch_id)
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += self.entropy_reduction(entropy)
self._inputs = self.embedding(branch_id)
return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)
def _sample_input_choice(self, mutable):
query, anchors = [], []
for label in mutable.choose_from:
if label not in self._anchors_hid:
self._lstm_next_step()
self._mark_anchor(label) # empty loop, fill not found
query.append(self.attn_anchor(self._anchors_hid[label]))
anchors.append(self._anchors_hid[label])
query = torch.cat(query, 0)
query = torch.tanh(query + self.attn_query(self._h[-1]))
query = self.v_attn(query)
if self.temperature is not None:
query /= self.temperature
if self.tanh_constant is not None:
query = self.tanh_constant * torch.tanh(query)
if mutable.n_chosen is None:
logit = torch.cat([-query, query], 1) # pylint: disable=invalid-unary-operand-type
skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
if field.choose_one:
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, sampled)
self._inputs = self.embedding[field.name](sampled)
else:
logit = logit.view(-1, 1)
logit = torch.cat([-logit, logit], 1) # pylint: disable=invalid-unary-operand-type
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip_prob = torch.sigmoid(logit)
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(logit, skip)
self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0)
else:
assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
logit = query.view(1, -1)
index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1)
log_prob = self.cross_entropy_loss(logit, index)
self._inputs = anchors[index.item()]
log_prob = self.cross_entropy_loss(logit, sampled)
sampled = sampled.nonzero().view(-1)
if sampled.sum().item():
self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
else:
self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device) # type: ignore
sampled = sampled.detach().cpu().numpy().tolist()
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += self.entropy_reduction(entropy)
return skip.bool()
if len(sampled) == 1:
sampled = sampled[0]
return sampled
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment