"doc/src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "2d22efb764328a3475cf71e60dd4821cda2f6885"
Commit 1cada380 authored by Yuge Zhang, committed by QuanluZhang

Extract base mutator/trainer and support ENAS micro search space (#1739)

parent 3ddab980
import json

import torch

from nni.nas.pytorch.mutator import Mutator


class FixedArchitecture(Mutator):

    def __init__(self, model, fixed_arc, strict=True):
        """
        Initialize a fixed architecture mutator.

        Parameters
        ----------
        model : nn.Module
            A mutable network.
        fixed_arc : str or dict
            Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
        strict : bool
            Force everything that appears in ``fixed_arc`` to be used at least once.
        """
        super().__init__(model)
        if isinstance(fixed_arc, str):
            with open(fixed_arc, "r") as f:
                fixed_arc = json.load(f)  # json.load expects a file object; passing f.read() here would fail
        self._fixed_arc = fixed_arc
        self._strict = strict

    def _encode_tensor(self, data):
        if isinstance(data, list):
            if all(map(lambda o: isinstance(o, bool), data)):
                return torch.tensor(data, dtype=torch.bool)  # pylint: disable=not-callable
            return torch.tensor(data, dtype=torch.float)  # pylint: disable=not-callable
        if isinstance(data, dict):
            return {k: self._encode_tensor(v) for k, v in data.items()}
        return data

    def before_pass(self):
        self._unused_key = set(self._fixed_arc.keys())

    def after_pass(self):
        if self._strict and self._unused_key:
            raise ValueError("{} are never used by the network. "
                             "Set strict=False if you want to disable this check.".format(self._unused_key))

    def _check_key(self, key):
        if key not in self._fixed_arc:
            raise ValueError("\"{}\" is demanded by the network, but not found in saved architecture.".format(key))
        self._unused_key.discard(key)  # mark the key as used so the strict check in after_pass can succeed

    def on_calc_layer_choice_mask(self, mutable):
        self._check_key(mutable.key)
        return self._fixed_arc[mutable.key]

    def on_calc_input_choice_mask(self, mutable, tags):
        self._check_key(mutable.key)
        return self._fixed_arc[mutable.key]
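
For orientation, here is a minimal usage sketch of the class above. The network class, input batch, and checkpoint name are hypothetical placeholders; the JSON file is assumed to map each mutable key to the mask that on_calc_layer_choice_mask / on_calc_input_choice_mask should return.

# Hedged usage sketch: "MySearchSpace", "arch.json" and "batch" are illustrative placeholders, not part of this commit.
# The checkpoint is assumed to look like {"LayerChoice1": [false, true, false], ...}.
model = MySearchSpace()                            # an nn.Module containing LayerChoice / InputChoice mutables
mutator = FixedArchitecture(model, "arch.json")    # a preloaded dict works as well
with mutator.forward_pass():                       # masks are looked up and cached per mutable.key
    logits = model(batch)                          # every choice is resolved by the saved architecture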
@@ -3,7 +3,7 @@ import torch.nn as nn
 
 from nni.nas.utils import global_mutable_counting
 
 
-class PyTorchMutable(nn.Module):
+class Mutable(nn.Module):
     """
     Mutable is designed to function as a normal layer, with all necessary operators' weights.
     States and weights of architectures should be included in mutator, instead of the layer itself.
@@ -24,15 +24,11 @@ class PyTorchMutable(nn.Module):
             self._key = key
         else:
             self._key = self.__class__.__name__ + str(global_mutable_counting())
-        self.name = self.key
+        self.init_hook = self.forward_hook = None
 
     def __deepcopy__(self, memodict=None):
         raise NotImplementedError("Deep copy doesn't work for mutables.")
 
-    def __enter__(self):
-        self._check_built()
-        return super().__enter__()
-
     def __call__(self, *args, **kwargs):
         self._check_built()
         return super().__call__(*args, **kwargs)
@@ -47,8 +43,16 @@ class PyTorchMutable(nn.Module):
     def key(self):
         return self._key
 
+    @property
+    def name(self):
+        return self._name if hasattr(self, "_name") else self._key
+
+    @name.setter
+    def name(self, name):
+        self._name = name
+
     def similar(self, other):
-        return self == other
+        return type(self) == type(other)
 
     def _check_built(self):
         if not hasattr(self, "mutator"):
@@ -56,8 +60,11 @@ class PyTorchMutable(nn.Module):
                 "Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__ "
                 "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))
 
+    def __repr__(self):
+        return "{} ({})".format(self.name, self.key)
+
 
-class MutableScope(PyTorchMutable):
+class MutableScope(Mutable):
     """
     Mutable scope labels a subgraph to help mutators make better decisions. Mutators get notified when a mutable scope
     is entered and exited. Mutators can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch
@@ -67,14 +74,18 @@ class MutableScope(PyTorchMutable):
 
     def __init__(self, key):
         super().__init__(key=key)
 
-    def __enter__(self):
-        self.mutator.enter_mutable_scope(self)
+    def build(self):
+        self.mutator.on_init_mutable_scope(self)
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.mutator.exit_mutable_scope(self)
+    def __call__(self, *args, **kwargs):
+        try:
+            self.mutator.enter_mutable_scope(self)
+            return super().__call__(*args, **kwargs)
+        finally:
+            self.mutator.exit_mutable_scope(self)
 
 
-class LayerChoice(PyTorchMutable):
+class LayerChoice(Mutable):
     def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None):
         super().__init__(key=key)
         self.length = len(op_candidates)
@@ -83,10 +94,10 @@ class LayerChoice(PyTorchMutable):
         self.return_mask = return_mask
 
     def __len__(self):
-        return self.length
+        return len(self.choices)
 
     def forward(self, *inputs):
-        out, mask = self.mutator.on_forward(self, *inputs)
+        out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
         if self.return_mask:
             return out, mask
         return out
 
@@ -95,7 +106,7 @@ class LayerChoice(PyTorchMutable):
         return type(self) == type(other) and self.length == other.length
 
 
-class InputChoice(PyTorchMutable):
+class InputChoice(Mutable):
     def __init__(self, n_candidates, n_selected=None, reduction="mean", return_mask=False, key=None):
         super().__init__(key=key)
         assert n_candidates > 0, "Number of candidates must be greater than 0."
@@ -104,16 +115,21 @@ class InputChoice(PyTorchMutable):
         self.reduction = reduction
         self.return_mask = return_mask
 
-    def forward(self, optional_inputs, semantic_labels=None):
+    def build(self):
+        self.mutator.on_init_input_choice(self)
+
+    def forward(self, optional_inputs, tags=None):
         assert len(optional_inputs) == self.n_candidates, \
             "Length of the input list must be equal to number of candidates."
-        if semantic_labels is None:
-            semantic_labels = ["default_label"] * self.n_candidates
-        out, mask = self.mutator.on_forward(self, optional_inputs, semantic_labels)
+        if tags is None:
+            tags = [""] * self.n_candidates
+        else:
+            assert len(tags) == self.n_candidates, "Length of tags must be equal to number of candidates."
+        out, mask = self.mutator.on_forward_input_choice(self, optional_inputs, tags)
        if self.return_mask:
             return out, mask
         return out
 
     def similar(self, other):
         return type(self) == type(other) and \
             self.n_candidates == other.n_candidates and self.n_selected == other.n_selected
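
To show how these mutables compose (an illustrative sketch, not code from this commit): a search cell can be declared as a MutableScope whose __init__ holds a LayerChoice over candidate ops and an InputChoice over earlier outputs, and whose forward simply calls them; whichever mutator the trainer attaches decides what each call returns. The class name, channel sizes, op list, and tags below are assumptions.

import torch.nn as nn

from nni.nas.pytorch.mutables import MutableScope, LayerChoice, InputChoice  # module path as referenced in the docstrings above


class SearchCell(MutableScope):
    """Illustrative cell: one op choice plus a choice between two incoming tensors."""

    def __init__(self, key, channels):
        super().__init__(key)  # the scope key lets mutators group the decisions made inside this cell
        self.op = LayerChoice([
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.Conv2d(channels, channels, 5, padding=2),
            nn.MaxPool2d(3, stride=1, padding=1),
        ])
        self.input_switch = InputChoice(n_candidates=2, n_selected=1)

    def forward(self, prev, prev_prev):
        # tags give the mutator a stable way to tell the candidate inputs apart
        x = self.input_switch([prev, prev_prev], tags=["prev", "prev_prev"])
        return self.op(x)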

-import logging
 from contextlib import contextmanager
 
 import torch
 import torch.nn as nn
 
-from nni.nas.pytorch.mutables import PyTorchMutable
-from nni.nas.utils import to_snake_case
-
-logger = logging.getLogger(__name__)
-
-
-class PyTorchMutator(nn.Module):
-
-    def __init__(self, model):
-        super().__init__()
-        self.before_build(model)
-        self.parse_search_space(model)
-        self.after_build(model)
-
-    def before_build(self, model):
-        pass
-
-    def after_build(self, model):
-        pass
-
-    def named_mutables(self, model):
-        # if distinct is true, the method will filter out those with duplicated keys
-        key2module = dict()
-        for name, module in model.named_modules():
-            if isinstance(module, PyTorchMutable):
-                distinct = False
-                if module.key in key2module:
-                    assert key2module[module.key].similar(module), \
-                        "Mutable \"{}\" that share the same key must be similar to each other".format(module.key)
-                else:
-                    distinct = True
-                    key2module[module.key] = module
-                yield name, module, distinct
-
-    def __setattr__(self, key, value):
-        if key in ["model", "net", "network"]:
-            logger.warning("Think twice if you are including the network into mutator.")
-        return super().__setattr__(key, value)
-
-    def parse_search_space(self, model):
-        for name, mutable, distinct in self.named_mutables(model):
-            mutable.name = name
-            mutable.set_mutator(self)
-            if not distinct:
-                continue
-            init_method_name = "on_init_{}".format(to_snake_case(mutable.__class__.__name__))
-            if hasattr(self, init_method_name) and callable(getattr(self, init_method_name)):
-                getattr(self, init_method_name)(mutable)
-            else:
-                # fallback to general init
-                self.on_init_general(mutable)
-
-    def on_init_general(self, mutable):
-        pass
+from nni.nas.pytorch.base_mutator import BaseMutator
+
+
+class Mutator(BaseMutator, nn.Module):
+
+    def export(self):
+        if self._in_forward_pass:
+            raise RuntimeError("Still in forward pass. Exporting might induce incompleteness.")
+        if not self._cache:
+            raise RuntimeError("No running history found. You need to call your model at least once before exporting. "
+                               "You might also want to check if there are no valid mutables in your model.")
+        return self._cache
 
     @contextmanager
     def forward_pass(self):
+        self._in_forward_pass = True
+        self._cache = dict()
         self.before_pass()
         try:
             yield self
         finally:
             self.after_pass()
+            self._in_forward_pass = False
 
     def before_pass(self):
-        self._in_forward_pass = True
-        self._cache = dict()
-
-    def after_pass(self):
-        self._in_forward_pass = False
-
-    def enter_mutable_scope(self, mutable_scope):
         pass
 
-    def exit_mutable_scope(self, mutable_scope):
+    def after_pass(self):
         pass
 
-    def forward(self, *inputs):
-        raise NotImplementedError("Mutator is not forward-able")
-
-    def on_forward(self, mutable, *inputs):
-        """Callback on forwarding a mutable"""
+    def _check_in_forward_pass(self):
         if not hasattr(self, "_in_forward_pass") or not self._in_forward_pass:
             raise ValueError("Not in forward pass. Did you forget to call mutator.forward_pass(), or forget to call "
                              "super().before_pass() and after_pass() in your override method?")
-        forward_method_name = "on_forward_{}".format(to_snake_case(mutable.__class__.__name__))
-        if hasattr(self, forward_method_name) and callable(getattr(self, forward_method_name)):
-            return getattr(self, forward_method_name)(mutable, *inputs)
-        else:
-            # fallback to general forward
-            return self.on_forward_general(mutable, *inputs)
-
-    def on_forward_general(self, mutable, *inputs):
-        raise NotImplementedError("Forward has to be implemented")
 
     def on_forward_layer_choice(self, mutable, *inputs):
         """
         Callback of layer choice forward. Override if you are an advanced user.
         By default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers
-        (either by switch or by weights), then it will reduce the list of all tensor outputs with the policy speicified
+        (either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified
         in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`.
 
         Parameters
@@ -111,33 +52,38 @@ class PyTorchMutator(nn.Module):
 
         Returns
         -------
-        torch.Tensor
+        tuple of torch.Tensor and torch.Tensor
         """
+        self._check_in_forward_pass()
+
         def _map_fn(op, *inputs):
             return op(*inputs)
 
         mask = self._cache.setdefault(mutable.key, self.on_calc_layer_choice_mask(mutable))
         out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask)
         return self._tensor_reduction(mutable.reduction, out), mask
 
-    def on_forward_input_choice(self, mutable, tensor_list, semantic_labels):
+    def on_forward_input_choice(self, mutable, tensor_list, tags):
         """
         Callback of input choice forward. Override if you are an advanced user.
-        By default, this method calls :meth:`on_calc_input_choice_mask` with `semantic_labels`
+        By default, this method calls :meth:`on_calc_input_choice_mask` with `tags`
         to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce
-        the list of all tensor outputs with the policy speicified in `mutable.reduction`. It will also cache the
+        the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the
         mask with corresponding `mutable.key`.
 
         Parameters
         ----------
         mutable: InputChoice
-        inputs: list of torch.Tensor
+        tensor_list: list of torch.Tensor
+        tags: list of string
 
         Returns
         -------
-        torch.Tensor
+        tuple of torch.Tensor and torch.Tensor
         """
-        mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, semantic_labels))
-        out = self._select_with_mask(lambda x: x, [(t, ) for t in tensor_list], mask)
+        self._check_in_forward_pass()
+        mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, tags))
+        out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
         return self._tensor_reduction(mutable.reduction, out), mask
 
     def on_calc_layer_choice_mask(self, mutable):
@@ -157,7 +103,7 @@ class PyTorchMutator(nn.Module):
         """
         raise NotImplementedError("Layer choice mask calculation must be implemented")
 
-    def on_calc_input_choice_mask(self, mutable, semantic_labels):
+    def on_calc_input_choice_mask(self, mutable, tags):
         """
         Recommended to override. Calculate a mask tensor for an input choice.
@@ -165,7 +111,7 @@ class PyTorchMutator(nn.Module):
         ----------
         mutable: InputChoice
             Corresponding input choice object.
-        semantic_labels: list of string
+        tags: list of string
             The names of labels of the input tensors given by the user. Usually it's a
             :class:`~nni.nas.pytorch.mutables.MutableScope` marked by the user.
@@ -179,7 +125,7 @@ class PyTorchMutator(nn.Module):
     def _select_with_mask(self, map_fn, candidates, mask):
         if "BoolTensor" in mask.type():
-            # print(candidates[0], len(mask))
             out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
         elif "FloatTensor" in mask.type():
             out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)]
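
The two on_calc_*_mask hooks are the only thing a concrete mutator must provide: a boolean mask hard-selects candidates, while a float mask weights them, and _select_with_mask plus _tensor_reduction handle the rest. Below is a minimal sketch of a uniform-random mutator built on that contract; the class is illustrative and not part of this commit, and "model" and "batch" are placeholders.

import torch

from nni.nas.pytorch.mutator import Mutator  # import path as used earlier in this commit


class RandomMutator(Mutator):
    """Illustrative mutator: picks one candidate uniformly at random for every choice."""

    def _random_mask(self, n):
        mask = torch.zeros(n, dtype=torch.bool)  # BoolTensor -> _select_with_mask keeps only entries flagged True
        mask[torch.randint(n, (1,)).item()] = True
        return mask

    def on_calc_layer_choice_mask(self, mutable):
        return self._random_mask(len(mutable))

    def on_calc_input_choice_mask(self, mutable, tags):
        return self._random_mask(mutable.n_candidates)


# usage sketch: each forward pass samples one architecture
mutator = RandomMutator(model)
with mutator.forward_pass():
    output = model(batch)
arch = mutator.export()  # the cached masks, keyed by mutable.key

Returning a float mask instead (for example softmax weights over the candidates) turns the same hooks into a weighted mixture of candidate outputs, which is how differentiable approaches such as DARTS can reuse this machinery.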

@@ -33,13 +33,13 @@ class PdartsTrainer(Trainer):
         for epoch in range(self.pdarts_epoch):
             layers = self.layers + self.pdarts_num_layers[epoch]
-            model, loss, model_optim, lr_scheduler = self.model_creator(
+            model, loss, model_optim, _ = self.model_creator(
                 layers, n_nodes)
             mutator = PdartsMutator(
                 model, epoch, self.pdarts_num_to_drop, switches)
-            self.trainer = DartsTrainer(model, loss=loss, model_optim=model_optim,
-                                        lr_scheduler=lr_scheduler, mutator=mutator, **self.darts_parameters)
+            self.trainer = DartsTrainer(model, loss=loss, optimizer=model_optim,
+                                        mutator=mutator, **self.darts_parameters)
             print("start pdarts training %s..." % epoch)
             self.trainer.train()

-from abc import ABC, abstractmethod
+from abc import abstractmethod
+
+import torch
 
+from .base_trainer import BaseTrainer
 
-class Trainer(ABC):
+class Trainer(BaseTrainer):
+
+    def __init__(self, model, loss, metrics, optimizer, num_epochs,
+                 dataset_train, dataset_valid, batch_size, workers, device, log_frequency,
+                 mutator, callbacks):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
+        self.model = model
+        self.loss = loss
+        self.metrics = metrics
+        self.optimizer = optimizer
+        self.mutator = mutator
+        self.model.to(self.device)
+        self.loss.to(self.device)
+        self.mutator.to(self.device)
+        self.num_epochs = num_epochs
+        self.dataset_train = dataset_train
+        self.dataset_valid = dataset_valid
+        self.batch_size = batch_size
+        self.workers = workers
+        self.log_frequency = log_frequency
+        self.callbacks = callbacks if callbacks is not None else []
+        for callback in self.callbacks:
+            callback.build(self.model, self.mutator, self)
 
     @abstractmethod
-    def train(self):
-        raise NotImplementedError
+    def train_one_epoch(self, epoch):
+        pass
 
     @abstractmethod
-    def export(self):
-        raise NotImplementedError
+    def validate_one_epoch(self, epoch):
+        pass
+
+    def _train(self, validate):
+        for epoch in range(self.num_epochs):
+            for callback in self.callbacks:
+                callback.on_epoch_begin(epoch)
+            # training
+            print("Epoch {} Training".format(epoch))
+            self.train_one_epoch(epoch)
+            if validate:
+                # validation
+                print("Epoch {} Validating".format(epoch))
+                self.validate_one_epoch(epoch)
+            for callback in self.callbacks:
+                callback.on_epoch_end(epoch)
+
+    def train_and_validate(self):
+        self._train(True)
+
+    def train(self):
+        self._train(False)
+
+    def validate(self):
+        self.validate_one_epoch(-1)
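
The base class above fixes the loop skeleton: _train iterates over epochs, fires each callback's on_epoch_begin / on_epoch_end, and delegates the real work to train_one_epoch and validate_one_epoch, so a concrete trainer only fills in those two methods. Below is a deliberately plain sketch of such a subclass; the per-batch use of mutator.forward_pass() and the accuracy bookkeeping are simplifying assumptions, not the ENAS/DARTS trainers from this commit.

import torch
from torch.utils.data import DataLoader


class PlainTrainer(Trainer):
    """Illustrative subclass: ordinary supervised training of whatever architecture the mutator resolves."""

    def train_one_epoch(self, epoch):
        loader = DataLoader(self.dataset_train, batch_size=self.batch_size,
                            shuffle=True, num_workers=self.workers)
        self.model.train()
        for step, (x, y) in enumerate(loader):
            x, y = x.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            with self.mutator.forward_pass():   # resolve/sample an architecture for this batch
                logits = self.model(x)
            loss = self.loss(logits, y)
            loss.backward()
            self.optimizer.step()
            if self.log_frequency and step % self.log_frequency == 0:
                print("Epoch {} step {} loss {:.4f}".format(epoch, step, loss.item()))

    def validate_one_epoch(self, epoch):
        loader = DataLoader(self.dataset_valid, batch_size=self.batch_size,
                            shuffle=False, num_workers=self.workers)
        self.model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(self.device), y.to(self.device)
                with self.mutator.forward_pass():
                    logits = self.model(x)
                correct += (logits.argmax(dim=1) == y).sum().item()
                total += y.size(0)
        print("Epoch {} validation accuracy {:.4f}".format(epoch, correct / max(total, 1)))

Construction takes the same arguments as the base __init__ above (model, loss, metrics, optimizer, num_epochs, the two datasets, and so on); after that, trainer.train_and_validate() drives the whole loop and the callbacks.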

-import re
 from collections import OrderedDict
 
-import torch
-
 _counter = 0
 
@@ -12,14 +9,6 @@ def global_mutable_counting():
     return _counter
 
 
-def to_snake_case(camel_case):
-    return re.sub('(?!^)([A-Z]+)', r'_\1', camel_case).lower()
-
-
-def auto_device():
-    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
 class AverageMeterGroup(object):
 
     def __init__(self):