Commit 1cada380 authored by Yuge Zhang, committed by QuanluZhang

Extract base mutator/trainer and support ENAS micro search space (#1739)

parent 3ddab980
import json
import torch
from nni.nas.pytorch.mutator import Mutator
class FixedArchitecture(Mutator):
def __init__(self, model, fixed_arc, strict=True):
"""
Initialize a fixed architecture mutator.
Parameters
----------
model: nn.Module
A mutable network.
fixed_arc: str or dict
Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
strict: bool
Force everything that appears in `fixed_arc` to be used at least once.
"""
super().__init__(model)
if isinstance(fixed_arc, str):
with open(fixed_arc, "r") as f:
fixed_arc = json.load(f)
self._fixed_arc = fixed_arc
self._strict = strict
def _encode_tensor(self, data):
if isinstance(data, list):
if all(map(lambda o: isinstance(o, bool), data)):
return torch.tensor(data, dtype=torch.bool) # pylint: disable=not-callable
else:
return torch.tensor(data, dtype=torch.float) # pylint: disable=not-callable
if isinstance(data, dict):
return {k: self._encode_tensor(v) for k, v in data.items()}
return data
def before_pass(self):
self._unused_key = set(self._fixed_arc.keys())
def after_pass(self):
if self._strict:
if self._unused_key:
raise ValueError("{} are never used by the network. "
"Set strict=False if you want to disable this check.".format(self._unused_key))
def _check_key(self, key):
if key not in self._fixed_arc:
raise ValueError("\"{}\" is demanded by the network, but not found in saved architecture.".format(key))
def on_calc_layer_choice_mask(self, mutable):
self._check_key(mutable.key)
return self._fixed_arc[mutable.key]
def on_calc_input_choice_mask(self, mutable, tags):
self._check_key(mutable.key)
return self._fixed_arc[mutable.key]
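# ---- Illustrative usage (not part of the commit; a minimal sketch) ----
# `MyMutableNet` and "arch.json" are hypothetical: any nn.Module containing mutables,
# plus a JSON file mapping mutable keys to their chosen masks, would do.
#
#     model = MyMutableNet()
#     mutator = FixedArchitecture(model, "arch.json")
#     with mutator.forward_pass():
#         output = model(torch.randn(1, 3, 32, 32))  # runs with the saved choices applied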
@@ -3,7 +3,7 @@ import torch.nn as nn
from nni.nas.utils import global_mutable_counting
class PyTorchMutable(nn.Module):
class Mutable(nn.Module):
"""
Mutable is designed to function as a normal layer, with all necessary operators' weights.
States and weights of the architecture should be kept in the mutator, not in the layer itself.
@@ -24,15 +24,11 @@ class PyTorchMutable(nn.Module):
self._key = key
else:
self._key = self.__class__.__name__ + str(global_mutable_counting())
self.name = self.key
self.init_hook = self.forward_hook = None
def __deepcopy__(self, memodict=None):
raise NotImplementedError("Deep copy doesn't work for mutables.")
def __enter__(self):
self._check_built()
return super().__enter__()
def __call__(self, *args, **kwargs):
self._check_built()
return super().__call__(*args, **kwargs)
@@ -47,8 +43,16 @@ class PyTorchMutable(nn.Module):
def key(self):
return self._key
@property
def name(self):
return self._name if hasattr(self, "_name") else "_key"
@name.setter
def name(self, name):
self._name = name
def similar(self, other):
return self == other
return type(self) == type(other)
def _check_built(self):
if not hasattr(self, "mutator"):
@@ -56,8 +60,11 @@ class PyTorchMutable(nn.Module):
"Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__"
"so that trainer can locate all your mutables. See NNI docs for more details.".format(self))
def __repr__(self):
return "{} ({})".format(self.name, self.key)
class MutableScope(PyTorchMutable):
class MutableScope(Mutable):
"""
Mutable scope labels a subgraph to help mutators make better decisions. Mutators get notified when a mutable scope
is entered and exited. Mutators can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch
@@ -67,14 +74,18 @@ class MutableScope(PyTorchMutable):
def __init__(self, key):
super().__init__(key=key)
def __enter__(self):
self.mutator.enter_mutable_scope(self)
def build(self):
self.mutator.on_init_mutable_scope(self)
def __exit__(self, exc_type, exc_val, exc_tb):
self.mutator.exit_mutable_scope(self)
def __call__(self, *args, **kwargs):
try:
self.mutator.enter_mutable_scope(self)
return super().__call__(*args, **kwargs)
finally:
self.mutator.exit_mutable_scope(self)
class LayerChoice(PyTorchMutable):
class LayerChoice(Mutable):
def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None):
super().__init__(key=key)
self.length = len(op_candidates)
@@ -83,10 +94,10 @@ class LayerChoice(PyTorchMutable):
self.return_mask = return_mask
def __len__(self):
return self.length
return len(self.choices)
def forward(self, *inputs):
out, mask = self.mutator.on_forward(self, *inputs)
out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
if self.return_mask:
return out, mask
return out
@@ -95,7 +106,7 @@ class LayerChoice(PyTorchMutable):
return type(self) == type(other) and self.length == other.length
class InputChoice(PyTorchMutable):
class InputChoice(Mutable):
def __init__(self, n_candidates, n_selected=None, reduction="mean", return_mask=False, key=None):
super().__init__(key=key)
assert n_candidates > 0, "Number of candidates must be greater than 0."
@@ -104,16 +115,21 @@ class InputChoice(PyTorchMutable):
self.reduction = reduction
self.return_mask = return_mask
def forward(self, optional_inputs, semantic_labels=None):
def build(self):
self.mutator.on_init_input_choice(self)
def forward(self, optional_inputs, tags=None):
assert len(optional_inputs) == self.n_candidates, \
"Length of the input list must be equal to number of candidates."
if semantic_labels is None:
semantic_labels = ["default_label"] * self.n_candidates
out, mask = self.mutator.on_forward(self, optional_inputs, semantic_labels)
if tags is None:
tags = [""] * self.n_candidates
else:
assert len(tags) == self.n_candidates, "Length of tags must be equal to number of candidates."
out, mask = self.mutator.on_forward_input_choice(self, optional_inputs, tags)
if self.return_mask:
return out, mask
return out
def similar(self, other):
return type(self) == type(other) and \
self.n_candidates == other.n_candidates and self.n_selected == other.n_selected
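# ---- Illustrative sketch (not part of the commit) ----
# A toy module composed from the mutables above, assuming 16-channel feature maps.
# The cell only declares the search space; a mutator must parse it and drive the
# choices during a forward pass. All names below are made up for illustration.
class _ToyCell(nn.Module):
    def __init__(self):
        super().__init__()
        # pick one of two candidate operations for this layer
        self.op = LayerChoice([nn.Conv2d(16, 16, 3, padding=1),
                               nn.MaxPool2d(3, stride=1, padding=1)], key="toy_op")
        # pick one of two candidate inputs to pass through
        self.skip = InputChoice(n_candidates=2, n_selected=1, key="toy_skip")

    def forward(self, prev, prev_prev):
        out = self.op(prev)
        return self.skip([out, prev_prev], tags=["out", "prev_prev"])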
import logging
from contextlib import contextmanager
import torch
import torch.nn as nn
from nni.nas.pytorch.mutables import PyTorchMutable
from nni.nas.utils import to_snake_case
from nni.nas.pytorch.base_mutator import BaseMutator
logger = logging.getLogger(__name__)
class Mutator(BaseMutator, nn.Module):
class PyTorchMutator(nn.Module):
def __init__(self, model):
super().__init__()
self.before_build(model)
self.parse_search_space(model)
self.after_build(model)
def before_build(self, model):
pass
def after_build(self, model):
pass
def named_mutables(self, model):
# yields (name, module, distinct); distinct is False for mutables whose key has already been seen
key2module = dict()
for name, module in model.named_modules():
if isinstance(module, PyTorchMutable):
distinct = False
if module.key in key2module:
assert key2module[module.key].similar(module), \
"Mutable \"{}\" that share the same key must be similar to each other".format(module.key)
else:
distinct = True
key2module[module.key] = module
yield name, module, distinct
def __setattr__(self, key, value):
if key in ["model", "net", "network"]:
logger.warning("Think twice if you are including the network into mutator.")
return super().__setattr__(key, value)
def parse_search_space(self, model):
for name, mutable, distinct in self.named_mutables(model):
mutable.name = name
mutable.set_mutator(self)
if not distinct:
continue
init_method_name = "on_init_{}".format(to_snake_case(mutable.__class__.__name__))
if hasattr(self, init_method_name) and callable(getattr(self, init_method_name)):
getattr(self, init_method_name)(mutable)
else:
# fallback to general init
self.on_init_general(mutable)
def on_init_general(self, mutable):
pass
def export(self):
if self._in_forward_pass:
raise RuntimeError("Still in forward pass. Exporting might induce incompleteness.")
if not self._cache:
raise RuntimeError("No running history found. You need to call your model at least once before exporting. "
"You might also want to check if there are no valid mutables in your model.")
return self._cache
@contextmanager
def forward_pass(self):
self._in_forward_pass = True
self._cache = dict()
self.before_pass()
try:
yield self
finally:
self.after_pass()
self._in_forward_pass = False
def before_pass(self):
self._in_forward_pass = True
self._cache = dict()
def after_pass(self):
self._in_forward_pass = False
def enter_mutable_scope(self, mutable_scope):
pass
def exit_mutable_scope(self, mutable_scope):
pass
def forward(self, *inputs):
raise NotImplementedError("Mutator is not forward-able")
def on_forward(self, mutable, *inputs):
"""Callback on forwarding a mutable"""
forward_method_name = "on_forward_{}".format(to_snake_case(mutable.__class__.__name__))
if hasattr(self, forward_method_name) and callable(getattr(self, forward_method_name)):
return getattr(self, forward_method_name)(mutable, *inputs)
else:
# fallback to general forward
return self.on_forward_general(mutable, *inputs)
def on_forward_general(self, mutable, *inputs):
raise NotImplementedError("Forward has to be implemented")
def _check_in_forward_pass(self):
if not hasattr(self, "_in_forward_pass") or not self._in_forward_pass:
raise ValueError("Not in forward pass. Did you forget to call mutator.forward_pass(), or forget to call "
"super().before_pass() and after_pass() in your override method?")
def on_forward_layer_choice(self, mutable, *inputs):
"""
Callback of layer choice forward. Override if you are an advanced user.
By default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers
(either by switch or by weights), then it will reduce the list of all tensor outputs with the policy speicified
(either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified
in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`.
Parameters
@@ -111,33 +52,38 @@ class PyTorchMutator(nn.Module):
Returns
-------
torch.Tensor
tuple of torch.Tensor and torch.Tensor
"""
self._check_in_forward_pass()
def _map_fn(op, *inputs):
return op(*inputs)
mask = self._cache.setdefault(mutable.key, self.on_calc_layer_choice_mask(mutable))
out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask)
return self._tensor_reduction(mutable.reduction, out), mask
def on_forward_input_choice(self, mutable, tensor_list, semantic_labels):
def on_forward_input_choice(self, mutable, tensor_list, tags):
"""
Callback of input choice forward. Override if you are an advanced user.
By default, this method calls :meth:`on_calc_input_choice_mask` with `semantic_labels`
By default, this method calls :meth:`on_calc_input_choice_mask` with `tags`
to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce
the list of all tensor outputs with the policy speicified in `mutable.reduction`. It will also cache the
the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the
mask with corresponding `mutable.key`.
Parameters
----------
mutable: InputChoice
inputs: list of torch.Tensor
tensor_list: list of torch.Tensor
tags: list of string
Returns
-------
torch.Tensor
tuple of torch.Tensor and torch.Tensor
"""
mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, semantic_labels))
out = self._select_with_mask(lambda x: x, [(t, ) for t in tensor_list], mask)
self._check_in_forward_pass()
mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, tags))
out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
return self._tensor_reduction(mutable.reduction, out), mask
def on_calc_layer_choice_mask(self, mutable):
@@ -157,7 +103,7 @@ class PyTorchMutator(nn.Module):
"""
raise NotImplementedError("Layer choice mask calculation must be implemented")
def on_calc_input_choice_mask(self, mutable, semantic_labels):
def on_calc_input_choice_mask(self, mutable, tags):
"""
Recommended to override. Calculate a mask tensor for an input choice.
@@ -165,7 +111,7 @@ class PyTorchMutator(nn.Module):
----------
mutable: InputChoice
Corresponding input choice object.
semantic_labels: list of string
tags: list of string
The labels of the input tensors given by the user. Usually a label refers to a
:class:`~nni.nas.pytorch.mutables.MutableScope` marked by the user.
@@ -179,7 +125,6 @@ class PyTorchMutator(nn.Module):
def _select_with_mask(self, map_fn, candidates, mask):
if "BoolTensor" in mask.type():
out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
elif "FloatTensor" in mask.type():
out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)]
......
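# ---- Illustrative sketch (not part of the commit) ----
# A minimal random-search mutator built on the callbacks above. The base class is
# written as PyTorchMutator here (its pre-rename name in this diff); substitute
# Mutator as appropriate. Masks are boolean tensors so that _select_with_mask takes
# the "BoolTensor" branch; n_selected is ignored for simplicity.
import torch


class _RandomMutator(PyTorchMutator):
    def on_calc_layer_choice_mask(self, mutable):
        # one-hot mask: pick exactly one candidate op at random
        chosen = torch.randint(mutable.length, (1,)).item()
        return torch.arange(mutable.length) == chosen

    def on_calc_input_choice_mask(self, mutable, tags):
        # pick exactly one input at random, ignoring the tags
        chosen = torch.randint(mutable.n_candidates, (1,)).item()
        return torch.arange(mutable.n_candidates) == chosen

# Usage sketch: each forward pass under the mutator samples a fresh architecture.
#     mutator = _RandomMutator(model)
#     with mutator.forward_pass():
#         logits = model(x)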
@@ -33,13 +33,13 @@ class PdartsTrainer(Trainer):
for epoch in range(self.pdarts_epoch):
layers = self.layers+self.pdarts_num_layers[epoch]
model, loss, model_optim, lr_scheduler = self.model_creator(
model, loss, model_optim, _ = self.model_creator(
layers, n_nodes)
mutator = PdartsMutator(
model, epoch, self.pdarts_num_to_drop, switches)
self.trainer = DartsTrainer(model, loss=loss, model_optim=model_optim,
lr_scheduler=lr_scheduler, mutator=mutator, **self.darts_parameters)
self.trainer = DartsTrainer(model, loss=loss, optimizer=model_optim,
mutator=mutator, **self.darts_parameters)
print("start pdrats training %s..." % epoch)
self.trainer.train()
......
from abc import ABC, abstractmethod
from abc import abstractmethod
import torch
from .base_trainer import BaseTrainer
class Trainer(ABC):
class Trainer(BaseTrainer):
def __init__(self, model, loss, metrics, optimizer, num_epochs,
dataset_train, dataset_valid, batch_size, workers, device, log_frequency,
mutator, callbacks):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
self.model = model
self.loss = loss
self.metrics = metrics
self.optimizer = optimizer
self.mutator = mutator
self.model.to(self.device)
self.loss.to(self.device)
self.mutator.to(self.device)
self.num_epochs = num_epochs
self.dataset_train = dataset_train
self.dataset_valid = dataset_valid
self.batch_size = batch_size
self.workers = workers
self.log_frequency = log_frequency
self.callbacks = callbacks if callbacks is not None else []
for callback in self.callbacks:
callback.build(self.model, self.mutator, self)
@abstractmethod
def train(self):
raise NotImplementedError
def train_one_epoch(self, epoch):
pass
@abstractmethod
def export(self):
raise NotImplementedError
def validate_one_epoch(self, epoch):
pass
def _train(self, validate):
for epoch in range(self.num_epochs):
for callback in self.callbacks:
callback.on_epoch_begin(epoch)
# training
print("Epoch {} Training".format(epoch))
self.train_one_epoch(epoch)
if validate:
# validation
print("Epoch {} Validating".format(epoch))
self.validate_one_epoch(epoch)
for callback in self.callbacks:
callback.on_epoch_end(epoch)
def train_and_validate(self):
self._train(True)
def train(self):
self._train(False)
def validate(self):
self.validate_one_epoch(-1)
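# ---- Illustrative sketch (not part of the commit) ----
# A minimal concrete trainer built on the base class above. The DataLoader settings,
# the loss-only "validation" and the export behaviour are simplified assumptions; the
# point being illustrated is the per-batch mutator.forward_pass() pattern.
import torch
import torch.utils.data


class _SimpleTrainer(Trainer):
    def train_one_epoch(self, epoch):
        loader = torch.utils.data.DataLoader(self.dataset_train, batch_size=self.batch_size,
                                             shuffle=True, num_workers=self.workers)
        self.model.train()
        for step, (x, y) in enumerate(loader):
            x, y = x.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            with self.mutator.forward_pass():  # sample an architecture for this batch
                logits = self.model(x)
            loss = self.loss(logits, y)
            loss.backward()
            self.optimizer.step()
            if self.log_frequency and step % self.log_frequency == 0:
                print("Epoch {} Step {} loss {:.4f}".format(epoch, step, loss.item()))

    def validate_one_epoch(self, epoch):
        loader = torch.utils.data.DataLoader(self.dataset_valid, batch_size=self.batch_size,
                                             num_workers=self.workers)
        self.model.eval()
        with torch.no_grad(), self.mutator.forward_pass():
            for x, y in loader:
                x, y = x.to(self.device), y.to(self.device)
                logits = self.model(x)
                print("Epoch {} val loss {:.4f}".format(epoch, self.loss(logits, y).item()))

    def export(self):
        # assumption: the mutator's cached decisions are what we want to export
        return self.mutator.export()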
import re
from collections import OrderedDict
import torch
_counter = 0
@@ -12,14 +9,6 @@ def global_mutable_counting():
return _counter
def to_snake_case(camel_case):
return re.sub('(?!^)([A-Z]+)', r'_\1', camel_case).lower()
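# For example, to_snake_case("LayerChoice") returns "layer_choice"; this is what lets the
# mutator dispatch by class name to hooks such as on_forward_layer_choice above.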
def auto_device():
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
class AverageMeterGroup(object):
def __init__(self):
......