Unverified Commit a0fd0036 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Merge pull request #5036 from microsoft/promote-retiarii-to-nas

[DO NOT SQUASH] Promote retiarii to NAS
parents d6dcb483 bc6d8796
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .evolution import SPOSEvolution
from .mutator import SPOSSupernetTrainingMutator
from .trainer import SPOSSupernetTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import re
from collections import deque
import numpy as np
from nni.tuner import Tuner
from nni.algorithms.nas.pytorch.classic_nas.mutator import LAYER_CHOICE, INPUT_CHOICE
_logger = logging.getLogger(__name__)
class SPOSEvolution(Tuner):
"""
SPOS evolution tuner.
Parameters
----------
max_epochs : int
Maximum number of epochs to run.
num_select : int
Number of survival candidates of each epoch.
num_population : int
Number of candidates at the start of each epoch. If candidates generated by
crossover and mutation are not enough, the rest will be filled with random
candidates.
m_prob : float
The probability of mutation.
num_crossover : int
Number of candidates generated by crossover in each epoch.
num_mutation : int
Number of candidates generated by mutation in each epoch.
"""
def __init__(self, max_epochs=20, num_select=10, num_population=50, m_prob=0.1,
num_crossover=25, num_mutation=25):
assert num_population >= num_select
self.max_epochs = max_epochs
self.num_select = num_select
self.num_population = num_population
self.m_prob = m_prob
self.num_crossover = num_crossover
self.num_mutation = num_mutation
self.epoch = 0
self.candidates = []
self.search_space = None
self.random_state = np.random.RandomState(0)
# async status
self._to_evaluate_queue = deque()
self._sending_parameter_queue = deque()
self._pending_result_ids = set()
self._reward_dict = dict()
self._id2candidate = dict()
self._st_callback = None
def update_search_space(self, search_space):
"""
Handle the initialization/update event of search space.
"""
self._search_space = search_space
self._next_round()
def _next_round(self):
_logger.info("Epoch %d, generating...", self.epoch)
if self.epoch == 0:
self._get_random_population()
self.export_results(self.candidates)
else:
best_candidates = self._select_top_candidates()
self.export_results(best_candidates)
if self.epoch >= self.max_epochs:
return
self.candidates = self._get_mutation(best_candidates) + self._get_crossover(best_candidates)
self._get_random_population()
self.epoch += 1
def _random_candidate(self):
chosen_arch = dict()
for key, val in self._search_space.items():
if val["_type"] == LAYER_CHOICE:
choices = val["_value"]
index = self.random_state.randint(len(choices))
chosen_arch[key] = {"_value": choices[index], "_idx": index}
elif val["_type"] == INPUT_CHOICE:
raise NotImplementedError("Input choice is not implemented yet.")
return chosen_arch
def _add_to_evaluate_queue(self, cand):
_logger.info("Generate candidate %s, adding to eval queue.", self._get_architecture_repr(cand))
self._reward_dict[self._hashcode(cand)] = 0.
self._to_evaluate_queue.append(cand)
def _get_random_population(self):
while len(self.candidates) < self.num_population:
cand = self._random_candidate()
if self._is_legal(cand):
_logger.info("Random candidate generated.")
self._add_to_evaluate_queue(cand)
self.candidates.append(cand)
def _get_crossover(self, best):
result = []
for _ in range(10 * self.num_crossover):
cand_p1 = best[self.random_state.randint(len(best))]
cand_p2 = best[self.random_state.randint(len(best))]
assert cand_p1.keys() == cand_p2.keys()
cand = {k: cand_p1[k] if self.random_state.randint(2) == 0 else cand_p2[k]
for k in cand_p1.keys()}
if self._is_legal(cand):
result.append(cand)
self._add_to_evaluate_queue(cand)
if len(result) >= self.num_crossover:
break
_logger.info("Found %d architectures with crossover.", len(result))
return result
def _get_mutation(self, best):
result = []
for _ in range(10 * self.num_mutation):
cand = best[self.random_state.randint(len(best))].copy()
mutation_sample = np.random.random_sample(len(cand))
for s, k in zip(mutation_sample, cand):
if s < self.m_prob:
choices = self._search_space[k]["_value"]
index = self.random_state.randint(len(choices))
cand[k] = {"_value": choices[index], "_idx": index}
if self._is_legal(cand):
result.append(cand)
self._add_to_evaluate_queue(cand)
if len(result) >= self.num_mutation:
break
_logger.info("Found %d architectures with mutation.", len(result))
return result
def _get_architecture_repr(self, cand):
return re.sub(r"\".*?\": \{\"_idx\": (\d+), \"_value\": \".*?\"\}", r"\1",
self._hashcode(cand))
def _is_legal(self, cand):
if self._hashcode(cand) in self._reward_dict:
return False
return True
def _select_top_candidates(self):
reward_query = lambda cand: self._reward_dict[self._hashcode(cand)]
_logger.info("All candidate rewards: %s", list(map(reward_query, self.candidates)))
result = sorted(self.candidates, key=reward_query, reverse=True)[:self.num_select]
_logger.info("Best candidate rewards: %s", list(map(reward_query, result)))
return result
@staticmethod
def _hashcode(d):
return json.dumps(d, sort_keys=True)
def _bind_and_send_parameters(self):
"""
There are two types of resources: parameter ids and candidates. This function is called at
necessary times to bind these resources to send new trials with st_callback.
"""
result = []
while self._sending_parameter_queue and self._to_evaluate_queue:
parameter_id = self._sending_parameter_queue.popleft()
parameters = self._to_evaluate_queue.popleft()
self._id2candidate[parameter_id] = parameters
result.append(parameters)
self._pending_result_ids.add(parameter_id)
self._st_callback(parameter_id, parameters)
_logger.info("Send parameter [%d] %s.", parameter_id, self._get_architecture_repr(parameters))
return result
def generate_multiple_parameters(self, parameter_id_list, **kwargs):
"""
Callback function necessary to implement a tuner. This will put more parameter ids into the
parameter id queue.
"""
if "st_callback" in kwargs and self._st_callback is None:
self._st_callback = kwargs["st_callback"]
for parameter_id in parameter_id_list:
self._sending_parameter_queue.append(parameter_id)
self._bind_and_send_parameters()
return [] # always not use this. might induce problem of over-sending
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""
Callback function. Receive a trial result.
"""
_logger.info("Candidate %d, reported reward %f", parameter_id, value)
self._reward_dict[self._hashcode(self._id2candidate[parameter_id])] = value
def trial_end(self, parameter_id, success, **kwargs):
"""
Callback function when a trial is ended and resource is released.
"""
self._pending_result_ids.remove(parameter_id)
if not self._pending_result_ids and not self._to_evaluate_queue:
# a new epoch now
self._next_round()
assert self._st_callback is not None
self._bind_and_send_parameters()
def export_results(self, result):
"""
Export a number of candidates to `checkpoints` dir.
Parameters
----------
result : dict
Chosen architectures to be exported.
"""
os.makedirs("checkpoints", exist_ok=True)
for i, cand in enumerate(result):
converted = dict()
for cand_key, cand_val in cand.items():
onehot = [k == cand_val["_idx"] for k in range(len(self._search_space[cand_key]["_value"]))]
converted[cand_key] = onehot
with open(os.path.join("checkpoints", "%03d_%03d.json" % (self.epoch, i)), "w") as fp:
json.dump(converted, fp)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import numpy as np
from nni.algorithms.nas.pytorch.random import RandomMutator
_logger = logging.getLogger(__name__)
class SPOSSupernetTrainingMutator(RandomMutator):
"""
A random mutator with flops limit.
Parameters
----------
model : nn.Module
PyTorch model.
flops_func : callable
Callable that takes a candidate from `sample_search` and returns its candidate. When `flops_func`
is None, functions related to flops will be deactivated.
flops_lb : number
Lower bound of flops.
flops_ub : number
Upper bound of flops.
flops_bin_num : number
Number of bins divided for the interval of flops to ensure the uniformity. Bigger number will be more
uniform, but the sampling will be slower.
flops_sample_timeout : int
Maximum number of attempts to sample before giving up and use a random candidate.
"""
def __init__(self, model, flops_func=None, flops_lb=None, flops_ub=None,
flops_bin_num=7, flops_sample_timeout=500):
super().__init__(model)
self._flops_func = flops_func
if self._flops_func is not None:
self._flops_bin_num = flops_bin_num
self._flops_bins = [flops_lb + (flops_ub - flops_lb) / flops_bin_num * i for i in range(flops_bin_num + 1)]
self._flops_sample_timeout = flops_sample_timeout
def sample_search(self):
"""
Sample a candidate for training. When `flops_func` is not None, candidates will be sampled uniformly
relative to flops.
Returns
-------
dict
"""
if self._flops_func is not None:
for times in range(self._flops_sample_timeout):
idx = np.random.randint(self._flops_bin_num)
cand = super().sample_search()
if self._flops_bins[idx] <= self._flops_func(cand) <= self._flops_bins[idx + 1]:
_logger.debug("Sampled candidate flops %f in %d times.", cand, times)
return cand
_logger.warning("Failed to sample a flops-valid candidate within %d tries.", self._flops_sample_timeout)
return super().sample_search()
def sample_final(self):
"""
Implement only to suffice the interface of Mutator.
"""
return self.sample_search()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import SPOSSupernetTrainingMutator
logger = logging.getLogger(__name__)
class SPOSSupernetTrainer(Trainer):
"""
This trainer trains a supernet that can be used for evolution search.
Parameters
----------
model : nn.Module
Model with mutables.
mutator : nni.nas.pytorch.mutator.Mutator
A mutator object that has been initialized with the model.
loss : callable
Called with logits and targets. Returns a loss tensor.
metrics : callable
Returns a dict that maps metrics keys to metrics data.
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
train_loader : iterable
Data loader of training. Raise ``StopIteration`` when one epoch is exhausted.
dataset_valid : iterable
Data loader of validation. Raise ``StopIteration`` when one epoch is exhausted.
batch_size : int
Batch size.
workers: int
Number of threads for data preprocessing. Not used for this trainer. Maybe removed in future.
device : torch.device
Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, trainer will
automatic detects GPU and selects GPU first.
log_frequency : int
Number of mini-batches to log metrics.
callbacks : list of Callback
Callbacks to plug into the trainer. See Callbacks.
"""
def __init__(self, model, loss, metrics,
optimizer, num_epochs, train_loader, valid_loader,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
callbacks=None):
assert torch.cuda.is_available()
super().__init__(model, mutator if mutator is not None else SPOSSupernetTrainingMutator(model),
loss, metrics, optimizer, num_epochs, None, None,
batch_size, workers, device, log_frequency, callbacks)
self.train_loader = train_loader
self.valid_loader = valid_loader
def train_one_epoch(self, epoch):
self.model.train()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad()
self.mutator.reset()
logits = self.model(x)
loss = self.loss(logits, y)
loss.backward()
self.optimizer.step()
metrics = self.metrics(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def validate_one_epoch(self, epoch):
self.model.eval()
meters = AverageMeterGroup()
with torch.no_grad():
for step, (x, y) in enumerate(self.valid_loader):
x, y = x.to(self.device), y.to(self.device)
self.mutator.reset()
logits = self.model(x)
loss = self.loss(logits, y)
metrics = self.metrics(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.valid_loader), meters)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import get_and_apply_next_architecture
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
import json
import logging
import os
import sys
import tensorflow as tf
import nni
from nni.runtime.env_vars import trial_env_vars
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice, MutableScope
from nni.nas.tensorflow.mutator import Mutator
logger = logging.getLogger(__name__)
NNI_GEN_SEARCH_SPACE = "NNI_GEN_SEARCH_SPACE"
LAYER_CHOICE = "layer_choice"
INPUT_CHOICE = "input_choice"
def get_and_apply_next_architecture(model):
"""
Wrapper of :class:`~nni.nas.tensorflow.classic_nas.mutator.ClassicMutator` to make it more meaningful,
similar to ``get_next_parameter`` for HPO.
Tt will generate search space based on ``model``.
If env ``NNI_GEN_SEARCH_SPACE`` exists, this is in dry run mode for
generating search space for the experiment.
If not, there are still two mode, one is nni experiment mode where users
use ``nnictl`` to start an experiment. The other is standalone mode
where users directly run the trial command, this mode chooses the first
one(s) for each LayerChoice and InputChoice.
Parameters
----------
model : nn.Module
User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
"""
ClassicMutator(model)
class ClassicMutator(Mutator):
"""
This mutator is to apply the architecture chosen from tuner.
It implements the forward function of LayerChoice and InputChoice,
to only activate the chosen ones.
Parameters
----------
model : nn.Module
User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
"""
def __init__(self, model):
super(ClassicMutator, self).__init__(model)
self._chosen_arch = {}
self._search_space = self._generate_search_space()
if NNI_GEN_SEARCH_SPACE in os.environ:
# dry run for only generating search space
self._dump_search_space(os.environ[NNI_GEN_SEARCH_SPACE])
sys.exit(0)
if trial_env_vars.NNI_PLATFORM is None:
logger.warning("This is in standalone mode, the chosen are the first one(s).")
self._chosen_arch = self._standalone_generate_chosen()
else:
# get chosen arch from tuner
self._chosen_arch = nni.get_next_parameter()
if self._chosen_arch is None:
if trial_env_vars.NNI_PLATFORM == "unittest":
# happens if NNI_PLATFORM is intentionally set, e.g., in UT
logger.warning("`NNI_PLATFORM` is set but `param` is None. Falling back to standalone mode.")
self._chosen_arch = self._standalone_generate_chosen()
else:
raise RuntimeError("Chosen architecture is None. This may be a platform error.")
self.reset()
def _sample_layer_choice(self, mutable, idx, value, search_space_item):
"""
Convert layer choice to tensor representation.
Parameters
----------
mutable : Mutable
idx : int
Number `idx` of list will be selected.
value : str
The verbose representation of the selected value.
search_space_item : list
The list for corresponding search space.
"""
# doesn't support multihot for layer choice yet
assert 0 <= idx < len(mutable) and search_space_item[idx] == value, \
"Index '{}' in search space '{}' is not '{}'".format(idx, search_space_item, value)
mask = tf.one_hot(idx, len(mutable))
return tf.cast(tf.reshape(mask, [-1]), tf.bool)
def _sample_input_choice(self, mutable, idx, value, search_space_item):
"""
Convert input choice to tensor representation.
Parameters
----------
mutable : Mutable
idx : int
Number `idx` of list will be selected.
value : str
The verbose representation of the selected value.
search_space_item : list
The list for corresponding search space.
"""
candidate_repr = search_space_item["candidates"]
multihot_list = [False] * mutable.n_candidates
for i, v in zip(idx, value):
assert 0 <= i < mutable.n_candidates and candidate_repr[i] == v, \
"Index '{}' in search space '{}' is not '{}'".format(i, candidate_repr, v)
assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
multihot_list[i] = True
return tf.cast(multihot_list, tf.bool) # pylint: disable=not-callable
def sample_search(self):
"""
See :meth:`sample_final`.
"""
return self.sample_final()
def sample_final(self):
"""
Convert the chosen arch and apply it on model.
"""
assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
"Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
self._chosen_arch.keys())
result = dict()
for mutable in self.mutables:
if isinstance(mutable, (LayerChoice, InputChoice)):
assert mutable.key in self._chosen_arch, \
"Expected '{}' in chosen arch, but not found.".format(mutable.key)
data = self._chosen_arch[mutable.key]
assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
"'{}' is not a valid choice.".format(data)
if isinstance(mutable, LayerChoice):
result[mutable.key] = self._sample_layer_choice(mutable, data["_idx"], data["_value"],
self._search_space[mutable.key]["_value"])
elif isinstance(mutable, InputChoice):
result[mutable.key] = self._sample_input_choice(mutable, data["_idx"], data["_value"],
self._search_space[mutable.key]["_value"])
elif isinstance(mutable, MutableScope):
logger.info("Mutable scope '%s' is skipped during parsing choices.", mutable.key)
else:
raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
return result
def _standalone_generate_chosen(self):
"""
Generate the chosen architecture for standalone mode,
i.e., choose the first one(s) for LayerChoice and InputChoice.
::
{ key_name: {"_value": "conv1",
"_idx": 0} }
{ key_name: {"_value": ["in1"],
"_idx": [0]} }
Returns
-------
dict
the chosen architecture
"""
chosen_arch = {}
for key, val in self._search_space.items():
if val["_type"] == LAYER_CHOICE:
choices = val["_value"]
chosen_arch[key] = {"_value": choices[0], "_idx": 0}
elif val["_type"] == INPUT_CHOICE:
choices = val["_value"]["candidates"]
n_chosen = val["_value"]["n_chosen"]
if n_chosen is None:
n_chosen = len(choices)
chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}
else:
raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
return chosen_arch
def _generate_search_space(self):
"""
Generate search space from mutables.
Here is the search space format:
::
{ key_name: {"_type": "layer_choice",
"_value": ["conv1", "conv2"]} }
{ key_name: {"_type": "input_choice",
"_value": {"candidates": ["in1", "in2"],
"n_chosen": 1}} }
Returns
-------
dict
the generated search space
"""
search_space = {}
for mutable in self.mutables:
# for now we only generate flattened search space
if isinstance(mutable, LayerChoice):
key = mutable.key
val = mutable.names
search_space[key] = {"_type": LAYER_CHOICE, "_value": val}
elif isinstance(mutable, InputChoice):
key = mutable.key
search_space[key] = {"_type": INPUT_CHOICE,
"_value": {"candidates": mutable.choose_from,
"n_chosen": mutable.n_chosen}}
elif isinstance(mutable, MutableScope):
logger.info("Mutable scope '%s' is skipped during generating search space.", mutable.key)
else:
raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
return search_space
def _dump_search_space(self, file_path):
with open(file_path, "w") as ss_file:
json.dump(self._search_space, ss_file, sort_keys=True, indent=2)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import EnasMutator
from .trainer import EnasTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
import tensorflow as tf
from tensorflow.keras.layers import Dense, Embedding, LSTMCell, RNN
from tensorflow.keras.losses import SparseCategoricalCrossentropy, Reduction
from nni.nas.tensorflow.mutator import Mutator
from nni.nas.tensorflow.mutables import LayerChoice, InputChoice, MutableScope
class EnasMutator(Mutator):
def __init__(self, model,
lstm_size=64,
lstm_num_layers=1,
tanh_constant=1.5,
cell_exit_extra_step=False,
skip_target=0.4,
temperature=None,
branch_bias=0.25,
entropy_reduction='sum'):
super().__init__(model)
self.tanh_constant = tanh_constant
self.temperature = temperature
self.cell_exit_extra_step = cell_exit_extra_step
cells = [LSTMCell(units=lstm_size, use_bias=False) for _ in range(lstm_num_layers)]
self.lstm = RNN(cells, stateful=True)
self.g_emb = tf.random.normal((1, 1, lstm_size)) * 0.1
self.skip_targets = tf.constant([1.0 - skip_target, skip_target])
self.max_layer_choice = 0
self.bias_dict = {}
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
if self.max_layer_choice == 0:
self.max_layer_choice = len(mutable)
assert self.max_layer_choice == len(mutable), \
"ENAS mutator requires all layer choice have the same number of candidates."
if 'reduce' in mutable.key:
bias = []
for choice in mutable.choices:
if 'conv' in str(type(choice)).lower():
bias.append(branch_bias)
else:
bias.append(-branch_bias)
self.bias_dict[mutable.key] = tf.constant(bias)
# exposed for trainer
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
# internal nn layers
self.embedding = Embedding(self.max_layer_choice + 1, lstm_size)
self.soft = Dense(self.max_layer_choice, use_bias=False)
self.attn_anchor = Dense(lstm_size, use_bias=False)
self.attn_query = Dense(lstm_size, use_bias=False)
self.v_attn = Dense(1, use_bias=False)
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
self.entropy_reduction = tf.reduce_sum if entropy_reduction == 'sum' else tf.reduce_mean
self.cross_entropy_loss = SparseCategoricalCrossentropy(from_logits=True, reduction=Reduction.NONE)
self._first_sample = True
def sample_search(self):
self._initialize()
self._sample(self.mutables)
self._first_sample = False
return self._choices
def sample_final(self):
return self.sample_search()
def _sample(self, tree):
mutable = tree.mutable
if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_layer_choice(mutable)
elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_input_choice(mutable)
for child in tree.children:
self._sample(child)
if self.cell_exit_extra_step and isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
self._anchors_hid[mutable.key] = self.lstm(self._inputs, 1)
def _initialize(self):
self._choices = {}
self._anchors_hid = {}
self._inputs = self.g_emb
# seems the `input_shape` parameter of RNN does not work
# workaround it by omitting `reset_states` for first run
if not self._first_sample:
self.lstm.reset_states()
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
def _sample_layer_choice(self, mutable):
logit = self.soft(self.lstm(self._inputs))
if self.temperature is not None:
logit /= self.temperature
if self.tanh_constant is not None:
logit = self.tanh_constant * tf.tanh(logit)
if mutable.key in self.bias_dict:
logit += self.bias_dict[mutable.key]
softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
branch_id = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [1])
log_prob = self.cross_entropy_loss(branch_id, logit)
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = log_prob * tf.math.exp(-log_prob)
self.sample_entropy += self.entropy_reduction(entropy)
self._inputs = tf.reshape(self.embedding(branch_id), [1, 1, -1])
mask = tf.one_hot(branch_id, self.max_layer_choice)
return tf.cast(tf.reshape(mask, [-1]), tf.bool)
def _sample_input_choice(self, mutable):
query, anchors = [], []
for label in mutable.choose_from:
if label not in self._anchors_hid:
self._anchors_hid[label] = self.lstm(self._inputs)
query.append(self.attn_anchor(self._anchors_hid[label]))
anchors.append(self._anchors_hid[label])
query = tf.concat(query, axis=0)
query = tf.tanh(query + self.attn_query(anchors[-1]))
query = self.v_attn(query)
if self.temperature is not None:
query /= self.temperature
if self.tanh_constant is not None:
query = self.tanh_constant * tf.tanh(query)
if mutable.n_chosen is None:
logit = tf.concat([-query, query], axis=1)
softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
skip = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [-1])
skip_prob = tf.math.sigmoid(logit)
kl = tf.reduce_sum(skip_prob * tf.math.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(skip, logit)
skip = tf.cast(skip, tf.float32)
inputs = tf.tensordot(skip, tf.concat(anchors, 0), 1) / (1. + tf.reduce_sum(skip))
self._inputs = tf.reshape(inputs, [1, 1, -1])
else:
assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
logit = tf.reshape(query, [1, -1])
softmax_logit = tf.math.log(tf.nn.softmax(logit, axis=-1))
index = tf.reshape(tf.random.categorical(softmax_logit, num_samples=1), [-1])
skip = tf.reshape(tf.one_hot(index, mutable.n_candidates), [-1])
# when the size is 1, tf does not accept tensor here, complaining the shape is wrong
# but using a numpy array seems fine
log_prob = self.cross_entropy_loss(logit, query.numpy())
self._inputs = tf.reshape(anchors[index.numpy()[0]], [1, 1, -1])
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = log_prob * tf.exp(-log_prob)
self.sample_entropy += self.entropy_reduction(entropy)
assert len(skip) == mutable.n_candidates, (skip, mutable.n_candidates, mutable.n_chosen)
return tf.cast(skip, tf.bool)
This diff is collapsed.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .execution import *
from .fixed import fixed_arch
from .mutable import *
from .utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
from .evaluator import Evaluator
from .functional import FunctionalEvaluator
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
__all__ = ['Evaluator']
import abc
from typing import Any, Callable, Type, Union, cast
class Evaluator(abc.ABC):
"""
Evaluator of a model. An evaluator should define where the training code is, and the configuration of
training code. The configuration includes basic runtime information trainer needs to know (such as number of GPUs)
or tune-able parameters (such as learning rate), depending on the implementation of training code.
Each config should define how it is interpreted in ``_execute()``, taking only one argument which is the mutated model class.
For example, functional evaluator might directly import the function and call the function.
"""
def evaluate(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
"""To run evaluation of a model. The model could be either a concrete model or a callable returning a model.
The concrete implementation of evaluate depends on the implementation of ``_execute()`` in sub-class.
"""
return self._execute(model_cls)
def __repr__(self):
items = ', '.join(['%s=%r' % (k, v) for k, v in self.__dict__.items()])
return f'{self.__class__.__name__}({items})'
@staticmethod
def _load(ir: Any) -> 'Evaluator':
evaluator_type = ir.get('type')
if isinstance(evaluator_type, str):
# for debug purposes only
for subclass in Evaluator.__subclasses__():
if subclass.__name__ == evaluator_type:
evaluator_type = subclass
break
assert issubclass(cast(type, evaluator_type), Evaluator)
return cast(Type[Evaluator], evaluator_type)._load(ir)
@abc.abstractmethod
def _dump(self) -> Any:
"""
Subclass implements ``_dump`` for their own serialization.
They should return a dict, with a key ``type`` which equals ``self.__class__``,
and optionally other keys.
"""
pass
@abc.abstractmethod
def _execute(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
pass
@abc.abstractmethod
def __eq__(self, other) -> bool:
pass
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import nni
from .evaluator import Evaluator
@nni.trace
class FunctionalEvaluator(Evaluator):
"""
Functional evaluator that directly takes a function and thus should be general.
Attributes
----------
function
The full name of the function.
arguments
Keyword arguments for the function other than model.
"""
def __init__(self, function, **kwargs):
self.function = function
self.arguments = kwargs
@staticmethod
def _load(ir):
return FunctionalEvaluator(ir['function'], **ir['arguments'])
def _dump(self):
return {
'type': self.__class__,
'function': self.function,
'arguments': self.arguments
}
def _execute(self, model_cls):
return self.function(model_cls, **self.arguments)
def __eq__(self, other):
return self.function == other.function and self.arguments == other.arguments
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from .mutator import RandomMutator from .lightning import *
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from .nasbench201 import NASBench201Cell from .api import *
from .common import *
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment