Unverified commit 867871b2, authored by Yuge Zhang, committed by GitHub

Promote Retiarii to NAS (step 1) - move files (#5020)

parent 481aa292
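For downstream code, the user-visible effect of this step is a package rename: modules that previously lived under nni.retiarii now live under nni.nas, as the import hunks below show. A minimal before/after sketch (only nni.nas.utils is actually visible in this diff; other sub-paths are assumed to move unchanged):

# before this commit
from nni.retiarii.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
# after this commit
from nni.nas.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL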
@@ -3,21 +3,17 @@
 import copy
 import warnings
-from collections import OrderedDict
-from typing import Callable, List, Dict, Union, Tuple, Optional
+from typing import Callable, List, Union, Tuple, Optional

-import torch
 import torch.nn as nn

-from nni.retiarii.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
+from nni.nas.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL

-from .api import LayerChoice, ValueChoice, ValueChoiceX, ChoiceOf
-from .cell import Cell
-from .nasbench101 import NasBench101Cell, NasBench101Mutator
-from .mutation_utils import Mutable, generate_new_label, get_fixed_value
+from .choice import ValueChoice, ValueChoiceX, ChoiceOf
+from .mutation_utils import Mutable, get_fixed_value

-__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell']
+__all__ = ['Repeat']


 class Repeat(Mutable):
@@ -159,77 +155,3 @@ class Repeat(Mutable):

     def __len__(self):
         return self.max_depth
-
-
-class NasBench201Cell(nn.Module):
-    """
-    Cell structure that is proposed in NAS-Bench-201.
-
-    Proposed by `NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search <https://arxiv.org/abs/2001.00326>`__.
-
-    This cell is a densely connected DAG with ``num_tensors`` nodes, where each node is tensor.
-    For every i < j, there is an edge from i-th node to j-th node.
-    Each edge in this DAG is associated with an operation transforming the hidden state from the source node
-    to the target node. All possible operations are selected from a predefined operation set, defined in ``op_candidates``.
-    Each of the ``op_candidates`` should be a callable that accepts input dimension and output dimension,
-    and returns a ``Module``.
-
-    Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`. For example,
-
-    The space size of this cell would be :math:`|op|^{N(N-1)/2}`, where :math:`|op|` is the number of operation candidates,
-    and :math:`N` is defined by ``num_tensors``.
-
-    Parameters
-    ----------
-    op_candidates : list of callable
-        Operation candidates. Each should be a function accepts input feature and output feature, returning nn.Module.
-    in_features : int
-        Input dimension of cell.
-    out_features : int
-        Output dimension of cell.
-    num_tensors : int
-        Number of tensors in the cell (input included). Default: 4
-    label : str
-        Identifier of the cell. Cell sharing the same label will semantically share the same choice.
-    """
-
-    @staticmethod
-    def _make_dict(x):
-        if isinstance(x, list):
-            return OrderedDict([(str(i), t) for i, t in enumerate(x)])
-        return OrderedDict(x)
-
-    def __init__(self, op_candidates: Union[Dict[str, Callable[[int, int], nn.Module]], List[Callable[[int, int], nn.Module]]],
-                 in_features: int, out_features: int, num_tensors: int = 4,
-                 label: Optional[str] = None):
-        super().__init__()
-        self._label = generate_new_label(label)
-
-        self.layers = nn.ModuleList()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.num_tensors = num_tensors
-
-        op_candidates = self._make_dict(op_candidates)
-
-        for tid in range(1, num_tensors):
-            node_ops = nn.ModuleList()
-            for j in range(tid):
-                inp = in_features if j == 0 else out_features
-                op_choices = OrderedDict([(key, cls(inp, out_features))
-                                          for key, cls in op_candidates.items()])
-                node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))  # put __ here to be compatible with base engine
-            self.layers.append(node_ops)
-
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        """
-        The forward of input choice is simply selecting first on all choices.
-        It shouldn't be called directly by users in most cases.
-        """
-        tensors: List[torch.Tensor] = [inputs]
-        for layer in self.layers:
-            current_tensor: List[torch.Tensor] = []
-            for i, op in enumerate(layer):  # type: ignore
-                current_tensor.append(op(tensors[i]))  # type: ignore
-            tensors.append(torch.sum(torch.stack(current_tensor), 0))
-        return tensors[-1]
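Editor's sketch (not part of the diff): the NasBench201Cell removed above appears to be relocated rather than dropped by this file-moving commit, and its docstring translates into usage roughly as follows. The import path after the move is not shown in this hunk, and the operation candidates below are illustrative. With ``num_tensors = 4`` the cell has 4*3/2 = 6 edges, so 3 candidate operations give 3**6 = 729 possible cells.

# Sketch only: assumes NasBench201Cell remains importable somewhere under nni.nas after the move.
import torch
import torch.nn as nn

op_candidates = {  # each candidate: (in_features, out_features) -> nn.Module
    'conv1x1': lambda cin, cout: nn.Conv1d(cin, cout, kernel_size=1),
    'conv3x3': lambda cin, cout: nn.Conv1d(cin, cout, kernel_size=3, padding=1),
    'skip': lambda cin, cout: nn.Identity() if cin == cout else nn.Conv1d(cin, cout, kernel_size=1),
}
# cell = NasBench201Cell(op_candidates, in_features=16, out_features=16, num_tensors=4)
# out = cell(torch.randn(8, 16, 32))  # input [N, C_in, *] -> output [N, C_out, *]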
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

+from typing import cast
+
import torch
import torch.nn as nn
import torch.nn.functional as F

-from nni.nas.pytorch.mutator import Mutator
-from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
-

class StackedLSTMCell(nn.Module):
    def __init__(self, layers, size, bias):
@@ -29,101 +28,80 @@ class StackedLSTMCell(nn.Module):
        return next_h, next_c
-class EnasMutator(Mutator):
+class ReinforceField:
+    """
+    A field with ``name``, with ``total`` choices. ``choose_one`` is true if one and only one is meant to be
+    selected. Otherwise, any number of choices can be chosen.
+    """
+
+    def __init__(self, name, total, choose_one):
+        self.name = name
+        self.total = total
+        self.choose_one = choose_one
+
+    def __repr__(self):
+        return f'ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})'
+
+
+class ReinforceController(nn.Module):
    """
-    A mutator that mutates the graph with RL.
+    A controller that mutates the graph with RL.

    Parameters
    ----------
-    model : nn.Module
-        PyTorch model.
+    fields : list of ReinforceField
+        List of fields to choose.
    lstm_size : int
        Controller LSTM hidden units.
    lstm_num_layers : int
        Number of layers for stacked LSTM.
    tanh_constant : float
        Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
-    cell_exit_extra_step : bool
-        If true, RL controller will perform an extra step at the exit of each MutableScope, dump the hidden state
-        and mark it as the hidden state of this MutableScope. This is to align with the original implementation of paper.
    skip_target : float
-        Target probability that skipconnect will appear.
+        Target probability that skipconnect (chosen by InputChoice) will appear.
+        If the chosen number of inputs is away from the ``skip_connect``, there will be
+        a sample skip penalty which is a KL divergence added.
    temperature : float
        Temperature constant that divides the logits.
-    branch_bias : float
-        Manual bias applied to make some operations more likely to be chosen.
-        Currently this is implemented with a hardcoded match rule that aligns with original repo.
-        If a mutable has a ``reduce`` in its key, all its op choices
-        that contains `conv` in their typename will receive a bias of ``+self.branch_bias`` initially; while others
-        receive a bias of ``-self.branch_bias``.
    entropy_reduction : str
        Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
    """

-    def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
-                 skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"):
-        super().__init__(model)
+    def __init__(self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
+                 skip_target=0.4, temperature=None, entropy_reduction='sum'):
+        super(ReinforceController, self).__init__()
+        self.fields = fields
        self.lstm_size = lstm_size
        self.lstm_num_layers = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature
-        self.cell_exit_extra_step = cell_exit_extra_step
        self.skip_target = skip_target
-        self.branch_bias = branch_bias

        self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
        self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
        self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
-        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)  # pylint: disable=not-callable
-        assert entropy_reduction in ["sum", "mean"], "Entropy reduction must be one of sum and mean."
-        self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
-        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
-        self.bias_dict = nn.ParameterDict()
-
-        self.max_layer_choice = 0
-        for mutable in self.mutables:
-            if isinstance(mutable, LayerChoice):
-                if self.max_layer_choice == 0:
-                    self.max_layer_choice = len(mutable)
-                assert self.max_layer_choice == len(mutable), \
-                    "ENAS mutator requires all layer choice have the same number of candidates."
-                # We are judging by keys and module types to add biases to layer choices. Needs refactor.
-                if "reduce" in mutable.key:
-                    def is_conv(choice):
-                        return "conv" in str(type(choice)).lower()
-                    bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias  # pylint: disable=not-callable
-                                         for choice in mutable])
-                    self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False)
-
-        self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
-        self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False)
-
-    def sample_search(self):
+        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),  # pylint: disable=not-callable
+                                         requires_grad=False)
+        assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
+        self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
+        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
+        self.soft = nn.ModuleDict({
+            field.name: nn.Linear(self.lstm_size, field.total, bias=False) for field in fields
+        })
+        self.embedding = nn.ModuleDict({
+            field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
+        })
+
+    def resample(self):
        self._initialize()
-        self._sample(self.mutables)
-        return self._choices
-
-    def sample_final(self):
-        return self.sample_search()
-
-    def _sample(self, tree):
-        mutable = tree.mutable
-        if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
-            self._choices[mutable.key] = self._sample_layer_choice(mutable)
-        elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
-            self._choices[mutable.key] = self._sample_input_choice(mutable)
-        for child in tree.children:
-            self._sample(child)
-        if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
-            if self.cell_exit_extra_step:
-                self._lstm_next_step()
-            self._mark_anchor(mutable.key)
+        result = dict()
+        for field in self.fields:
+            result[field.name] = self._sample_single(field)
+        return result

    def _initialize(self):
-        self._choices = dict()
-        self._anchors_hid = dict()
        self._inputs = self.g_emb.data
        self._c = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
@@ -131,67 +109,42 @@ class EnasMutator(Mutator):
        self._h = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
-        self.sample_log_prob = 0
-        self.sample_entropy = 0
-        self.sample_skip_penalty = 0
+        self.sample_log_prob: torch.Tensor = cast(torch.Tensor, 0)
+        self.sample_entropy: torch.Tensor = cast(torch.Tensor, 0)
+        self.sample_skip_penalty: torch.Tensor = cast(torch.Tensor, 0)

    def _lstm_next_step(self):
        self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

-    def _mark_anchor(self, key):
-        self._anchors_hid[key] = self._h[-1]
-
-    def _sample_layer_choice(self, mutable):
+    def _sample_single(self, field):
        self._lstm_next_step()
-        logit = self.soft(self._h[-1])
+        logit = self.soft[field.name](self._h[-1])
        if self.temperature is not None:
            logit /= self.temperature
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
-        if mutable.key in self.bias_dict:
-            logit += self.bias_dict[mutable.key]
-        branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
-        log_prob = self.cross_entropy_loss(logit, branch_id)
-        self.sample_log_prob += self.entropy_reduction(log_prob)
-        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
-        self.sample_entropy += self.entropy_reduction(entropy)
-        self._inputs = self.embedding(branch_id)
-        return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)
-
-    def _sample_input_choice(self, mutable):
-        query, anchors = [], []
-        for label in mutable.choose_from:
-            if label not in self._anchors_hid:
-                self._lstm_next_step()
-                self._mark_anchor(label)  # empty loop, fill not found
-            query.append(self.attn_anchor(self._anchors_hid[label]))
-            anchors.append(self._anchors_hid[label])
-        query = torch.cat(query, 0)
-        query = torch.tanh(query + self.attn_query(self._h[-1]))
-        query = self.v_attn(query)
-        if self.temperature is not None:
-            query /= self.temperature
-        if self.tanh_constant is not None:
-            query = self.tanh_constant * torch.tanh(query)
-        if mutable.n_chosen is None:
-            logit = torch.cat([-query, query], 1)  # pylint: disable=invalid-unary-operand-type
-            skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
+        if field.choose_one:
+            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
+            log_prob = self.cross_entropy_loss(logit, sampled)
+            self._inputs = self.embedding[field.name](sampled)
+        else:
+            logit = logit.view(-1, 1)
+            logit = torch.cat([-logit, logit], 1)  # pylint: disable=invalid-unary-operand-type
+            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip_prob = torch.sigmoid(logit)
            kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
-            log_prob = self.cross_entropy_loss(logit, skip)
-            self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0)
-        else:
-            assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
-            logit = query.view(1, -1)
-            index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
-            skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1)
-            log_prob = self.cross_entropy_loss(logit, index)
-            self._inputs = anchors[index.item()]
+            log_prob = self.cross_entropy_loss(logit, sampled)
+            sampled = sampled.nonzero().view(-1)
+            if sampled.sum().item():
+                self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
+            else:
+                self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)  # type: ignore
+
+        sampled = sampled.detach().cpu().numpy().tolist()
        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)
-        return skip.bool()
+        if len(sampled) == 1:
+            sampled = sampled[0]
+        return sampled
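Editor's sketch (not part of the diff): how the new controller is typically driven. Each single-select decision (e.g. a LayerChoice) becomes a ReinforceField with choose_one=True, each multi-select decision one with choose_one=False; resample() then returns one choice per field and accumulates sample_log_prob, sample_entropy and sample_skip_penalty for a REINFORCE-style update. The field names, reward, and loss weights below are placeholders, and the update rule mirrors the usual ENAS recipe rather than anything shown in this hunk.

import torch

fields = [
    ReinforceField('op_1', 5, True),     # a single-select decision with 5 candidates
    ReinforceField('skip_2', 2, False),  # a multi-select decision over 2 predecessors
]
ctrl = ReinforceController(fields, lstm_size=64, lstm_num_layers=1)
optimizer = torch.optim.Adam(ctrl.parameters(), lr=3.5e-4)
baseline = 0.0

sample = ctrl.resample()   # e.g. {'op_1': 3, 'skip_2': [0, 1]}
reward = 0.5               # placeholder: reward of the sampled architecture (e.g. validation accuracy)
baseline = 0.999 * baseline + 0.001 * reward
# sample_log_prob is accumulated from cross-entropy terms, so minimising it scaled by the
# advantage pushes up the probability of well-rewarded samples; the KL-based skip penalty
# keeps the expected number of chosen inputs near skip_target.
loss = ctrl.sample_log_prob * (reward - baseline) + 1e-3 * ctrl.sample_skip_penalty
optimizer.zero_grad()
loss.backward()
optimizer.step()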