Unverified Commit a7846135 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Enable `fixed_arch` on Retiarii (#3972)

parent 08fe2924
...@@ -106,3 +106,5 @@ Utilities ...@@ -106,3 +106,5 @@ Utilities
--------- ---------
.. autofunction:: nni.retiarii.serialize .. autofunction:: nni.retiarii.serialize
.. autofunction:: nni.retiarii.fixed_arch
...@@ -34,4 +34,10 @@ See `API reference <./ApiReference.rst>`__ for detailed usages. Here, we show an ...@@ -34,4 +34,10 @@ See `API reference <./ApiReference.rst>`__ for detailed usages. Here, we show an
trainer.fit() trainer.fit()
final_architecture = trainer.export() final_architecture = trainer.export()
**Format of the exported architecture.** TBD. After the searching is done, we can use the exported architecture to instantiate the full network for retraining. Here is an example:
.. code-block:: python
from nni.retiarii import fixed_arch
with fixed_arch('/path/to/checkpoint.json'):
model = Model()
...@@ -16,7 +16,7 @@ A typical example is DartsTrainer, where learnable-parameters are used to combin ...@@ -16,7 +16,7 @@ A typical example is DartsTrainer, where learnable-parameters are used to combin
class DartsLayerChoice(nn.Module): class DartsLayerChoice(nn.Module):
def __init__(self, layer_choice): def __init__(self, layer_choice):
super(DartsLayerChoice, self).__init__() super(DartsLayerChoice, self).__init__()
self.name = layer_choice.key self.name = layer_choice.label
self.op_choices = nn.ModuleDict(layer_choice.named_children()) self.op_choices = nn.ModuleDict(layer_choice.named_children())
self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3) self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import ops import ops
from nni.nas.pytorch import mutables from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
class AuxiliaryHead(nn.Module): class AuxiliaryHead(nn.Module):
...@@ -45,7 +45,7 @@ class Node(nn.Module): ...@@ -45,7 +45,7 @@ class Node(nn.Module):
stride = 2 if i < num_downsample_connect else 1 stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i)) choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append( self.ops.append(
mutables.LayerChoice(OrderedDict([ LayerChoice(OrderedDict([
("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)), ("maxpool", ops.PoolBN('max', channels, 3, stride, 1, affine=False)),
("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)), ("avgpool", ops.PoolBN('avg', channels, 3, stride, 1, affine=False)),
("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)), ("skipconnect", nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False)),
...@@ -53,9 +53,9 @@ class Node(nn.Module): ...@@ -53,9 +53,9 @@ class Node(nn.Module):
("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)), ("sepconv5x5", ops.SepConv(channels, channels, 5, stride, 2, affine=False)),
("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)), ("dilconv3x3", ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)) ("dilconv5x5", ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False))
]), key=choice_keys[-1])) ]), label=choice_keys[-1]))
self.drop_path = ops.DropPath() self.drop_path = ops.DropPath()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id)) self.input_switch = InputChoice(n_candidates=len(choice_keys), n_chosen=2, label="{}_switch".format(node_id))
def forward(self, prev_nodes): def forward(self, prev_nodes):
assert len(self.ops) == len(prev_nodes) assert len(self.ops) == len(prev_nodes)
......
...@@ -12,8 +12,8 @@ from torch.utils.tensorboard import SummaryWriter ...@@ -12,8 +12,8 @@ from torch.utils.tensorboard import SummaryWriter
import datasets import datasets
import utils import utils
from model import CNN from model import CNN
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeter from nni.nas.pytorch.utils import AverageMeter
from nni.retiarii import fixed_arch
logger = logging.getLogger('nni') logger = logging.getLogger('nni')
...@@ -119,8 +119,8 @@ if __name__ == "__main__": ...@@ -119,8 +119,8 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16) dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16)
with fixed_arch(args.arc_checkpoint):
model = CNN(32, 3, 36, 10, args.layers, auxiliary=True) model = CNN(32, 3, 36, 10, args.layers, auxiliary=True)
apply_fixed_architecture(model, args.arc_checkpoint)
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
model.to(device) model.to(device)
......
...@@ -4,5 +4,6 @@ ...@@ -4,5 +4,6 @@
from .operation import Operation from .operation import Operation
from .graph import * from .graph import *
from .execution import * from .execution import *
from .fixed import fixed_arch
from .mutator import * from .mutator import *
from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls, model_wrapper from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls, model_wrapper
import json
import logging
from pathlib import Path
from typing import Union, Dict, Any
from .utils import ContextStack
_logger = logging.getLogger(__name__)
def fixed_arch(fixed_arch: Union[str, Path, Dict[str, Any]], verbose=True):
"""
Load architecture from ``fixed_arch`` and apply to model. This should be used as a context manager. For example,
.. code-block:: python
with fixed_arch('/path/to/export.json'):
model = Model(3, 224, 224)
Parameters
----------
fixed_arc : str, Path or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True
Returns
-------
ContextStack
Context manager that provides a fixed architecture when creates the model.
"""
if isinstance(fixed_arch, (str, Path)):
with open(fixed_arch) as f:
fixed_arch = json.load(f)
if verbose:
_logger.info(f'Fixed architecture: %s', fixed_arch)
return ContextStack('fixed', fixed_arch)
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import copy import copy
import logging import logging
from collections import OrderedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -18,8 +19,8 @@ _logger = logging.getLogger(__name__) ...@@ -18,8 +19,8 @@ _logger = logging.getLogger(__name__)
class DartsLayerChoice(nn.Module): class DartsLayerChoice(nn.Module):
def __init__(self, layer_choice): def __init__(self, layer_choice):
super(DartsLayerChoice, self).__init__() super(DartsLayerChoice, self).__init__()
self.name = layer_choice.key self.name = layer_choice.label
self.op_choices = nn.ModuleDict(layer_choice.named_children()) self.op_choices = nn.ModuleDict(OrderedDict([(name, layer_choice[name]) for name in layer_choice.names]))
self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3) self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
...@@ -38,13 +39,13 @@ class DartsLayerChoice(nn.Module): ...@@ -38,13 +39,13 @@ class DartsLayerChoice(nn.Module):
yield name, p yield name, p
def export(self): def export(self):
return torch.argmax(self.alpha).item() return list(self.op_choices.keys())[torch.argmax(self.alpha).item()]
class DartsInputChoice(nn.Module): class DartsInputChoice(nn.Module):
def __init__(self, input_choice): def __init__(self, input_choice):
super(DartsInputChoice, self).__init__() super(DartsInputChoice, self).__init__()
self.name = input_choice.key self.name = input_choice.label
self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3) self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
self.n_chosen = input_choice.n_chosen or 1 self.n_chosen = input_choice.n_chosen or 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment