Unverified Commit a7846135 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Enable `fixed_arch` on Retiarii (#3972)

parent 08fe2924
......@@ -105,4 +105,6 @@ Retiarii Experiments
Utilities
---------
.. autofunction:: nni.retiarii.serialize
\ No newline at end of file
.. autofunction:: nni.retiarii.serialize
.. autofunction:: nni.retiarii.fixed_arch
......@@ -34,4 +34,10 @@ See `API reference <./ApiReference.rst>`__ for detailed usages. Here, we show an
trainer.fit()
final_architecture = trainer.export()
**Format of the exported architecture.** TBD.
After the searching is done, we can use the exported architecture to instantiate the full network for retraining. Here is an example:
.. code-block:: python
from nni.retiarii import fixed_arch
with fixed_arch('/path/to/checkpoint.json'):
model = Model()
......@@ -16,7 +16,7 @@ A typical example is DartsTrainer, where learnable-parameters are used to combin
class DartsLayerChoice(nn.Module):
def __init__(self, layer_choice):
super(DartsLayerChoice, self).__init__()
self.name = layer_choice.key
self.name = layer_choice.label
self.op_choices = nn.ModuleDict(layer_choice.named_children())
self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)
......
......@@ -7,7 +7,7 @@ import torch
import torch.nn as nn
import ops
from nni.nas.pytorch import mutables
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
class AuxiliaryHead(nn.Module):
......@@ -45,7 +45,7 @@ class Node(nn.Module):
stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append(
mutables.LayerChoice(OrderedDict([
LayerChoice(OrderedDict([
("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)),
("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)),
("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)),
......@@ -53,9 +53,9 @@ class Node(nn.Module):
("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)),
("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False))
]), key=choice_keys[-1]))
]), label=choice_keys[-1]))
self.drop_path = ops.DropPath()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
self.input_switch = InputChoice(n_candidates=len(choice_keys), n_chosen=2, label="{}_switch".format(node_id))
def forward(self, prev_nodes):
assert len(self.ops) == len(prev_nodes)
......
......@@ -12,8 +12,8 @@ from torch.utils.tensorboard import SummaryWriter
import datasets
import utils
from model import CNN
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeter
from nni.retiarii import fixed_arch
logger = logging.getLogger('nni')
......@@ -119,8 +119,8 @@ if __name__ == "__main__":
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16)
model = CNN(32, 3, 36, 10, args.layers, auxiliary=True)
apply_fixed_architecture(model, args.arc_checkpoint)
with fixed_arch(args.arc_checkpoint):
model = CNN(32, 3, 36, 10, args.layers, auxiliary=True)
criterion = nn.CrossEntropyLoss()
model.to(device)
......
......@@ -4,5 +4,6 @@
from .operation import Operation
from .graph import *
from .execution import *
from .fixed import fixed_arch
from .mutator import *
from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls, model_wrapper
import json
import logging
from pathlib import Path
from typing import Union, Dict, Any
from .utils import ContextStack
_logger = logging.getLogger(__name__)
def fixed_arch(fixed_arch: Union[str, Path, Dict[str, Any]], verbose=True):
"""
Load architecture from ``fixed_arch`` and apply to model. This should be used as a context manager. For example,
.. code-block:: python
with fixed_arch('/path/to/export.json'):
model = Model(3, 224, 224)
Parameters
----------
fixed_arc : str, Path or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True
Returns
-------
ContextStack
Context manager that provides a fixed architecture when creates the model.
"""
if isinstance(fixed_arch, (str, Path)):
with open(fixed_arch) as f:
fixed_arch = json.load(f)
if verbose:
_logger.info(f'Fixed architecture: %s', fixed_arch)
return ContextStack('fixed', fixed_arch)
......@@ -3,6 +3,7 @@
import copy
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
......@@ -18,8 +19,8 @@ _logger = logging.getLogger(__name__)
class DartsLayerChoice(nn.Module):
def __init__(self, layer_choice):
super(DartsLayerChoice, self).__init__()
self.name = layer_choice.key
self.op_choices = nn.ModuleDict(layer_choice.named_children())
self.name = layer_choice.label
self.op_choices = nn.ModuleDict(OrderedDict([(name, layer_choice[name]) for name in layer_choice.names]))
self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)
def forward(self, *args, **kwargs):
......@@ -38,13 +39,13 @@ class DartsLayerChoice(nn.Module):
yield name, p
def export(self):
return torch.argmax(self.alpha).item()
return list(self.op_choices.keys())[torch.argmax(self.alpha).item()]
class DartsInputChoice(nn.Module):
def __init__(self, input_choice):
super(DartsInputChoice, self).__init__()
self.name = input_choice.key
self.name = input_choice.label
self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
self.n_chosen = input_choice.n_chosen or 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment