Unverified commit 77bc0e33, authored by Tab Zhang, committed by GitHub

Search Space Zoo nas bench 201 (#2766)

parent ec5af41f
@@ -64,6 +64,7 @@ Search Space Zoo contains the following NAS cells:
* [DartsCell](./SearchSpaceZoo.md#DartsCell)
* [ENAS micro](./SearchSpaceZoo.md#ENASMicroLayer)
* [ENAS macro](./SearchSpaceZoo.md#ENASMacroLayer)
* [NAS Bench 201](./SearchSpaceZoo.md#nas-bench-201)
## Using NNI API to Write Your Search Space
@@ -2,13 +2,13 @@
## DartsCell
DartsCell is extracted from the [CNN model](./DARTS.md) designed [here](https://github.com/microsoft/nni/tree/master/examples/nas/darts). A DartsCell is a directed acyclic graph containing an ordered sequence of N nodes, where each node stands for a latent representation (e.g. a feature map in a convolutional network). A directed edge from Node 1 to Node 2 is associated with an operation that transforms Node 1, and the result is stored on Node 2. The [candidate operators](#predefined-operations-darts) between nodes are predefined and unchangeable. Each edge represents one operation chosen from the predefined ones to be applied to the starting node of the edge. One cell contains two input nodes, a single output node, and `n_node` other nodes. The input nodes are defined as the cell outputs of the previous two layers. The output of the cell is obtained by applying a reduction operation (e.g. concatenation) to all the intermediate nodes. To make the search space continuous, the categorical choice of a particular operation is relaxed to a softmax over all possible operations. By adjusting the softmax weights on every node, the operation with the highest probability is chosen to be part of the final structure. A CNN model can be formed by stacking several cells together, which builds the search space. Note that in the DARTS paper, all cells in the model share the same structure.
One structure in the Darts search space is shown below. Note that NNI merges the last of the four intermediate nodes with the output node.
![](../../img/NAS_Darts_cell.svg)
The predefined operators are shown [here](#predefined-operations-darts).
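To make the continuous relaxation concrete, here is a minimal sketch of the idea (an illustration, not NNI's internal implementation; `MixedOp` is a hypothetical name): each edge holds one architecture weight per candidate operator and outputs a softmax-weighted mixture.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixedOp(nn.Module):
    """Softmax-weighted mixture over the candidate operations on one edge."""
    def __init__(self, candidate_ops):
        super().__init__()
        self.ops = nn.ModuleList(candidate_ops)
        # one learnable architecture weight per candidate operator
        self.alpha = nn.Parameter(torch.zeros(len(candidate_ops)))

    def forward(self, x):
        weights = F.softmax(self.alpha, dim=0)
        # the edge output is the weighted sum of all candidate outputs
        return sum(w * op(x) for w, op in zip(weights, self.ops))
```

After search, the operator with the largest weight on each edge is kept in the final architecture.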
```eval_rst
.. autoclass:: nni.nas.pytorch.search_space_zoo.DartsCell
@@ -28,9 +28,9 @@ python3 darts_example.py
<a name="predefined-operations-darts"></a>
### Candidate operators
All supported operators for Darts are listed below.
* MaxPool / AvgPool
* MaxPool: Call `torch.nn.MaxPool2d`. This operation applies a 2D max pooling over all input channels. Its parameters `kernel_size=3` and `padding=1` are fixed. The pooling result passes through a BatchNorm2d before being returned.
@@ -65,11 +65,11 @@ This layer is extracted from the model designed [here](https://github.com/micros
ENAS micro employs a DAG with N nodes in one cell, where the nodes represent local computations and the edges represent the flow of information between them. One cell contains two input nodes and a single output node. Each of the remaining nodes chooses two previous nodes as inputs, applies one operation from the [predefined ones](#predefined-operations-enas) to each, and sums the two results as its output. For example, Node 4 may choose Node 1 and Node 3 as inputs, apply `MaxPool` and `AvgPool` respectively, and sum the two results as its output. Nodes that do not serve as input to any other node are treated as outputs of the layer; if there are multiple output nodes, the model averages them to produce the layer output.
The ENAS micro search space is shown below.
![](../../img/NAS_ENAS_micro.svg)
The predefined operators can be seen [here](#predefined-operations-enas).
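A node's computation can be sketched as follows (`enas_micro_node` is a hypothetical helper for illustration only; the inputs, indices, and operations are those the controller has chosen):

```python
def enas_micro_node(prev_outputs, input_ids, ops):
    """prev_outputs: outputs of all previous nodes (tensors);
    input_ids: the two chosen input node indices;
    ops: the two chosen candidate operations."""
    left = ops[0](prev_outputs[input_ids[0]])
    right = ops[1](prev_outputs[input_ids[1]])
    # the node's output is the sum of the two transformed inputs
    return left + right
```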
```eval_rst
.. autoclass:: nni.nas.pytorch.search_space_zoo.ENASMicroLayer
@@ -91,9 +91,9 @@ python3 enas_micro_example.py
<a name="predefined-operations-enas"></a>
### Candidate operators
All supported operators for ENAS micro search are listed below.
* MaxPool / AvgPool
* MaxPool: Call `torch.nn.MaxPool2d`. This operation applies a 2D max pooling over all input channels followed by BatchNorm2d. Its parameters are fixed to `kernel_size=3`, `stride=1` and `padding=1`.
@@ -116,7 +116,7 @@ All supported operations for ENAS micro search are listed below.
## ENASMacroLayer
In macro search, the controller makes two decisions for each layer: i) the [operation](#macro-operations) to perform on the result of the previous layer, and ii) which previous layer to connect to for skip connections. ENAS uses the controller to design the whole model architecture instead of one of its components. The output of the chosen operation is concatenated with the tensors of the layers chosen for skip connections. NNI provides [predefined operators](#macro-operations) for macro search.
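The skip-connection concatenation can be sketched as below (an illustrative example, not NNI's internal code; `macro_layer_output` and its arguments are hypothetical names):

```python
import torch

def macro_layer_output(op_out, prev_outputs, skip_mask):
    """op_out: this layer's operation result;
    prev_outputs: outputs of all previous layers;
    skip_mask: one 0/1 skip-connection decision per previous layer."""
    chosen = [out for out, use in zip(prev_outputs, skip_mask) if use]
    # concatenate the operation output with the chosen layers along channels
    return torch.cat([op_out] + chosen, dim=1)
```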
Part of one structure in the ENAS macro search space is shown below.
@@ -147,9 +147,9 @@ python3 enas_macro_example.py
<a name="macro-operations"></a>
### Candidate operators
All supported operators for ENAS macro search are listed below.
* ConvBranch
@@ -172,4 +172,65 @@ All supported operations for ENAS macro search are listed below.
.. autoclass:: nni.nas.pytorch.search_space_zoo.enas_ops.PoolBranch
```
## NAS-Bench-201
NAS-Bench-201 defines a unified, algorithm-agnostic search space. The predefined skeleton consists of a stack of cells that all share the same architecture. Every cell contains four nodes connected into a DAG, where each node represents the sum of its incoming feature maps and each edge represents an operation that transforms the tensor from the source node to the target node. The predefined candidate operators can be found in [Candidate operators](#nas-bench-201-reference).
The search space of NAS Bench 201 is shown below.
![](../../img/NAS_Bench_201.svg)
```eval_rst
.. autoclass:: nni.nas.pytorch.nasbench201.NASBench201Cell
:members:
```
### Example code
[example code](https://github.com/microsoft/nni/tree/master/examples/nas/search_space_zoo/nas_bench_201.py)
```bash
# for structure searching
git clone https://github.com/Microsoft/nni.git
cd nni/examples/nas/search_space_zoo
python3 nas_bench_201.py
```
<a name="nas-bench-201-reference"></a>
### Candidate operators
All supported operators for NAS-Bench-201 are listed below.
* AvgPool
Call `torch.nn.AvgPool2d`. This operation applies a 2D average pooling over all input channels; its parameters are fixed to `kernel_size=3` and `padding=1`. If the number of input channels is not equal to the number of output channels, the input first passes through a `ReLUConvBN` layer with `kernel_size=1`, `stride=1`, `padding=0`, and `dilation=0`.
```eval_rst
.. autoclass:: nni.nas.pytorch.nasbench201.nasbench201_ops.Pooling
:members:
```
* Conv
* Conv1x1: Consists of a sequence of ReLU, `nn.Conv2d` and BatchNorm. The convolution's parameters are fixed to `kernel_size=1`, `padding=0`, and `dilation=1`.
* Conv3x3: Consists of a sequence of ReLU, `nn.Conv2d` and BatchNorm. The convolution's parameters are fixed to `kernel_size=3`, `padding=1`, and `dilation=1`.
```eval_rst
.. autoclass:: nni.nas.pytorch.nasbench201.nasbench201_ops.ReLUConvBN
:members:
```
* SkipConnect
Call `torch.nn.Identity` to connect the source node directly to the target node. If the input and output shapes differ, a `FactorizedReduce` is used instead.
* Zeroize
Generate a zero tensor, indicating that there is no connection from the source node to the target node.
```eval_rst
.. autoclass:: nni.nas.pytorch.nasbench201.nasbench201_ops.Zero
:members:
```
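Besides defining the search space, NAS-Bench-201 ships precomputed training statistics that can be queried for a fixed architecture. A minimal sketch, mirroring the retraining branch of the example script (the file name `architecture.json` is illustrative):

```python
import json
import pprint

from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats

# Load an edge -> operator JSON (the format shown in the example architecture file).
with open('architecture.json') as f:
    arch = json.load(f)

# Query statistics of the 200-epoch trainings of this architecture on cifar100.
for trial in query_nb201_trial_stats(arch, 200, 'cifar100'):
    pprint.pprint(trial)
```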
<svg id="SvgjsSvg1006" width="709" height="537.0000305175781" xmlns="http://www.w3.org/2000/svg" version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns:svgjs="http://svgjs.com/svgjs"><defs id="SvgjsDefs1007"><marker id="SvgjsMarker1034" markerWidth="16" markerHeight="12" refX="16" refY="6" viewBox="0 0 16 12" orient="auto" markerUnits="userSpaceOnUse" stroke-dasharray="0,0"><path id="SvgjsPath1035" d="M0,2 L14,6 L0,11 L0,2" fill="#323232" stroke="#323232" stroke-width="2"></path></marker><marker id="SvgjsMarker1038" markerWidth="16" markerHeight="12" refX="16" refY="6" viewBox="0 0 16 12" orient="auto" markerUnits="userSpaceOnUse" stroke-dasharray="0,0"><path id="SvgjsPath1039" d="M0,2 L14,6 L0,11 L0,2" fill="#323232" stroke="#323232" stroke-width="2"></path></marker><marker id="SvgjsMarker1042" markerWidth="16" markerHeight="12" refX="16" refY="6" viewBox="0 0 16 12" orient="auto" markerUnits="userSpaceOnUse" stroke-dasharray="0,0"><path id="SvgjsPath1043" d="M0,2 L14,6 L0,11 L0,2" fill="#323232" stroke="#323232" stroke-width="2"></path></marker><marker id="SvgjsMarker1046" markerWidth="16" markerHeight="12" refX="16" refY="6" viewBox="0 0 16 12" orient="auto" markerUnits="userSpaceOnUse" stroke-dasharray="0,0"><path id="SvgjsPath1047" d="M0,2 L14,6 L0,11 L0,2" fill="#323232" stroke="#323232" stroke-width="2"></path></marker><marker id="SvgjsMarker1050" markerWidth="16" markerHeight="12" refX="16" refY="6" viewBox="0 0 16 12" orient="auto" markerUnits="userSpaceOnUse" stroke-dasharray="0,0"><path id="SvgjsPath1051" d="M0,2 L14,6 L0,11 L0,2" fill="#323232" stroke="#323232" stroke-width="2"></path></marker><marker id="SvgjsMarker1054" markerWidth="16" markerHeight="12" refX="16" refY="6" viewBox="0 0 16 12" orient="auto" markerUnits="userSpaceOnUse" stroke-dasharray="0,0"><path id="SvgjsPath1055" d="M0,2 L14,6 L0,11 L0,2" fill="#323232" stroke="#323232" stroke-width="2"></path></marker><marker id="SvgjsMarker1080" markerWidth="16" markerHeight="12" refX="16" refY="6" viewBox="0 0 16 12" orient="auto" markerUnits="userSpaceOnUse" stroke-dasharray="0,0"><path id="SvgjsPath1081" d="M0,2 L14,6 L0,11 L0,2" fill="#323232" stroke="#323232" stroke-width="2"></path></marker></defs><g id="SvgjsG1008" transform="translate(25,24.999980449676514)"><path id="SvgjsPath1009" d="M 0 35C 0 -11.666666666666666 70 -11.666666666666666 70 35C 70 81.66666666666667 0 81.66666666666667 0 35Z" stroke="rgba(50,50,50,1)" stroke-width="2" fill-opacity="1" fill="#00cc00"></path><g id="SvgjsG1010"><text id="SvgjsText1011" font-family="微软雅黑" text-anchor="middle" font-size="13px" width="50px" fill="#323232" font-weight="400" align="middle" anchor="middle" family="微软雅黑" size="13px" weight="400" font-style="" opacity="1" y="25.55" transform="rotate(0)"><tspan id="SvgjsTspan1012" dy="16" x="35"><tspan id="SvgjsTspan1013" style="text-decoration:;">input</tspan></tspan></text></g></g><g id="SvgjsG1014" transform="translate(267,140.9999804496765)"><path id="SvgjsPath1015" d="M 0 35C 0 -11.666666666666666 70 -11.666666666666666 70 35C 70 81.66666666666667 0 81.66666666666667 0 35Z" stroke="rgba(50,50,50,1)" stroke-width="2" fill-opacity="1" fill="#ffff00"></path><g id="SvgjsG1016"><text id="SvgjsText1017" font-family="微软雅黑" text-anchor="middle" font-size="13px" width="50px" fill="#323232" font-weight="400" align="middle" anchor="middle" family="微软雅黑" size="13px" weight="400" font-style="" opacity="1" y="25.55" transform="rotate(0)"><tspan id="SvgjsTspan1018" dy="16" x="35"><tspan id="SvgjsTspan1019" 
style="text-decoration:;">node 1</tspan></tspan></text></g></g><g id="SvgjsG1020" transform="translate(25,293.9999804496765)"><path id="SvgjsPath1021" d="M 0 35C 0 -11.666666666666666 70 -11.666666666666666 70 35C 70 81.66666666666667 0 81.66666666666667 0 35Z" stroke="rgba(50,50,50,1)" stroke-width="2" fill-opacity="1" fill="#ffff00"></path><g id="SvgjsG1022"><text id="SvgjsText1023" font-family="微软雅黑" text-anchor="middle" font-size="13px" width="50px" fill="#323232" font-weight="400" align="middle" anchor="middle" family="微软雅黑" size="13px" weight="400" font-style="" opacity="1" y="25.55" transform="rotate(0)"><tspan id="SvgjsTspan1024" dy="16" x="35"><tspan id="SvgjsTspan1025" style="text-decoration:;">node 2</tspan></tspan></text></g></g><g id="SvgjsG1026" transform="translate(267,441.9999804496765)"><path id="SvgjsPath1027" d="M 0 35C 0 -11.666666666666666 70 -11.666666666666666 70 35C 70 81.66666666666667 0 81.66666666666667 0 35Z" stroke="rgba(50,50,50,1)" stroke-width="2" fill-opacity="1" fill="#ffff00"></path><g id="SvgjsG1028"><text id="SvgjsText1029" font-family="微软雅黑" text-anchor="middle" font-size="13px" width="50px" fill="#323232" font-weight="400" align="middle" anchor="middle" family="微软雅黑" size="13px" weight="400" font-style="" opacity="1" y="25.55" transform="rotate(0)"><tspan id="SvgjsTspan1030" dy="16" x="35"><tspan id="SvgjsTspan1031" style="text-decoration:;">node 3</tspan></tspan></text></g></g><g id="SvgjsG1032"><path id="SvgjsPath1033" d="M60 94.99998044967651C 45 172.86665678024292 44 214.86665678024292 60 293.9999804496765" stroke-dasharray="8,5" stroke="#323232" stroke-width="2" fill="none" marker-end="url(#SvgjsMarker1034)"></path></g><g id="SvgjsG1036"><path id="SvgjsPath1037" d="M60 94.99998044967651C 74.7485647305657 263.5768467107039 316.7485647305658 273.42311418864915 302 441.9999804496765" stroke-dasharray="8,5" stroke="#323232" stroke-width="2" fill="none" marker-end="url(#SvgjsMarker1038)"></path></g><g id="SvgjsG1040"><path id="SvgjsPath1041" d="M302 210.9999804496765C 310.0531906298836 303.04837055335383 310.0531906298836 349.9515903459992 302 441.9999804496765" stroke-dasharray="8,5" stroke="#323232" stroke-width="2" fill="none" marker-end="url(#SvgjsMarker1042)"></path></g><g id="SvgjsG1044"><path id="SvgjsPath1045" d="M60 363.9999804496765C 68.86407754147842 465.3168503645069 310.8640775414784 340.68311053484615 302 441.9999804496765" stroke-dasharray="8,5" stroke="#323232" stroke-width="2" fill="none" marker-end="url(#SvgjsMarker1046)"></path></g><g id="SvgjsG1048"><path id="SvgjsPath1049" d="M60 94.99998044967651C 68.58773798957196 193.15827483289445 310.587737989572 42.841686066458635 302 140.9999804496765" stroke-dasharray="8,5" stroke="#323232" stroke-width="2" fill="none" marker-end="url(#SvgjsMarker1050)"></path></g><g id="SvgjsG1052"><path id="SvgjsPath1053" d="M302 210.9999804496765C 310.9190947616536 312.94570006866127 68.91909476165358 192.05426083069176 60 293.9999804496765" stroke-dasharray="8,5" stroke="#323232" stroke-width="2" fill="none" marker-end="url(#SvgjsMarker1054)"></path></g><g id="SvgjsG1056" transform="translate(382,279.8666567802429)"><path id="SvgjsPath1057" d="M 0 26C 0 -8.666666666666666 52 -8.666666666666666 52 26C 52 60.666666666666664 0 60.666666666666664 0 26Z" stroke="rgba(50,50,50,1)" stroke-width="2" fill-opacity="1" fill="#ffff00"></path><g id="SvgjsG1058"><text id="SvgjsText1059" font-family="微软雅黑" text-anchor="middle" font-size="13px" width="32px" fill="#323232" font-weight="400" align="middle" 
anchor="middle" family="微软雅黑" size="13px" weight="400" font-style="" opacity="1" y="16.55" transform="rotate(0)"></text></g></g><g id="SvgjsG1060" transform="translate(455,295.8666567802429)"><path id="SvgjsPath1061" d="M 0 0L 220 0L 220 20L 0 20Z" stroke="none" fill="none"></path><g id="SvgjsG1062"><text id="SvgjsText1063" font-family="微软雅黑" text-anchor="start" font-size="13px" width="220px" fill="#323232" font-weight="400" align="middle" anchor="start" family="微软雅黑" size="13px" weight="400" font-style="" opacity="1" y="-7.45" transform="rotate(0)"><tspan id="SvgjsTspan1064" dy="16" x="0"><tspan id="SvgjsTspan1065" style="text-decoration:;">choose none/one/multiple input(s) </tspan></tspan><tspan id="SvgjsTspan1066" dy="16" x="0"><tspan id="SvgjsTspan1067" style="text-decoration:;">then add them as output</tspan></tspan></text></g></g><g id="SvgjsG1068" transform="translate(382,210.9999804496765)"><path id="SvgjsPath1069" d="M 0 26C 0 -8.666666666666666 52 -8.666666666666666 52 26C 52 60.666666666666664 0 60.666666666666664 0 26Z" stroke="rgba(50,50,50,1)" stroke-width="2" fill-opacity="1" fill="#00cc00"></path><g id="SvgjsG1070"><text id="SvgjsText1071" font-family="微软雅黑" text-anchor="middle" font-size="13px" width="32px" fill="#323232" font-weight="400" align="middle" anchor="middle" family="微软雅黑" size="13px" weight="400" font-style="" opacity="1" y="16.55" transform="rotate(0)"></text></g></g><g id="SvgjsG1072" transform="translate(455,226.9999804496765)"><path id="SvgjsPath1073" d="M 0 0L 220 0L 220 20L 0 20Z" stroke="none" fill="none"></path><g id="SvgjsG1074"><text id="SvgjsText1075" font-family="微软雅黑" text-anchor="start" font-size="13px" width="220px" fill="#323232" font-weight="400" align="middle" anchor="start" family="微软雅黑" size="13px" weight="400" font-style="" opacity="1" y="0.55" transform="rotate(0)"><tspan id="SvgjsTspan1076" dy="16" x="0"><tspan id="SvgjsTspan1077" style="text-decoration:;">the output of the previous cell</tspan></tspan></text></g></g><g id="SvgjsG1078"><path id="SvgjsPath1079" d="M376 374.8666567802429L404.5 374.8666567802429L404.5 374.8666567802429L433 374.8666567802429" stroke-dasharray="8,5" stroke="#323232" stroke-width="2" fill="none" marker-end="url(#SvgjsMarker1080)"></path></g><g id="SvgjsG1082" transform="translate(463,363.9999804496765)"><path id="SvgjsPath1083" d="M 0 0L 220 0L 220 20L 0 20Z" stroke="none" fill="none"></path><g id="SvgjsG1084"><text id="SvgjsText1085" font-family="微软雅黑" text-anchor="start" font-size="13px" width="220px" fill="#323232" font-weight="400" align="middle" anchor="start" family="微软雅黑" size="13px" weight="400" font-style="" opacity="1" y="-14.95" transform="rotate(0)"><tspan id="SvgjsTspan1086" dy="16" x="0"><tspan id="SvgjsTspan1087" style="text-decoration:;">choose one operation from </tspan></tspan><tspan id="SvgjsTspan1088" dy="16" x="0"><tspan id="SvgjsTspan1089" style="text-decoration:;">MaxPool, AvgPool,, Conv1x1, </tspan></tspan><tspan id="SvgjsTspan1090" dy="16" x="0"><tspan id="SvgjsTspan1091" style="text-decoration:;">Conv3x3, SkipConnect, Zeroize</tspan></tspan></text></g></g></svg>
\ No newline at end of file
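The commit also adds an example architecture file for retraining (passed via `--arch`): a JSON object mapping each edge `source_target` to its chosen operator, matching the `LayerChoice` keys used in `NASBench201Cell`.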
{
"0_1": "avg_pool_3x3",
"0_2": "conv_1x1",
"1_2": "skip_connect",
"0_3": "conv_1x1",
"1_3": "skip_connect",
"2_3": "skip_connect"
}
import argparse
import json
import logging
import pprint
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from nni.nas.pytorch import enas
from nni.nas.pytorch.nasbench201 import NASBench201Cell
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
from utils import accuracy, reward_accuracy
import datasets
logger = logging.getLogger('nni')
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation,
bn_affine=True, bn_momentum=0.1, bn_track_running_stats=True):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(C_out, affine=bn_affine, momentum=bn_momentum,
track_running_stats=bn_track_running_stats)
)
def forward(self, x):
return self.op(x)
class ResNetBasicBlock(nn.Module):
def __init__(self, inplanes, planes, stride, bn_affine=True,
bn_momentum=0.1, bn_track_running_stats=True):
super(ResNetBasicBlock, self).__init__()
assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1, bn_affine, bn_momentum, bn_track_running_stats)
self.conv_b = ReLUConvBN(planes, planes, 3, 1, 1, 1, bn_affine, bn_momentum, bn_track_running_stats)
if stride == 2:
self.downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
elif inplanes != planes:
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1, bn_affine, bn_momentum, bn_track_running_stats)
else:
self.downsample = None
self.in_dim = inplanes
self.out_dim = planes
self.stride = stride
self.num_conv = 2
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
inputs = self.downsample(inputs)
return inputs + basicblock
class NASBench201Network(nn.Module):
def __init__(self, stem_out_channels, num_modules_per_stack, bn_affine=True, bn_momentum=0.1, bn_track_running_stats=True):
super(NASBench201Network, self).__init__()
self.channels = C = stem_out_channels
self.num_modules = N = num_modules_per_stack
self.num_labels = 10
self.bn_momentum = bn_momentum
self.bn_affine = bn_affine
self.bn_track_running_stats = bn_track_running_stats
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C, momentum=self.bn_momentum)
)
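# Stack layout: N searched cells, a ResNet reduction block (stride 2, doubled channels),
# N more cells, another reduction block, then a final N cells.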
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev = C
self.cells = nn.ModuleList()
for i, (C_curr, reduction) in enumerate(zip(layer_channels, layer_reductions)):
if reduction:
cell = ResNetBasicBlock(C_prev, C_curr, 2, self.bn_affine, self.bn_momentum, self.bn_track_running_stats)
else:
cell = NASBench201Cell(i, C_prev, C_curr, 1, self.bn_affine, self.bn_momentum, self.bn_track_running_stats)
self.cells.append(cell)
C_prev = C_curr
self.lastact = nn.Sequential(
nn.BatchNorm2d(C_prev, momentum=self.bn_momentum),
nn.ReLU(inplace=True)
)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, self.num_labels)
def forward(self, inputs):
feature = self.stem(inputs)
for cell in self.cells:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return logits
def train(args, model, train_dataloader, valid_dataloader, criterion, optimizer, device):
model = model.to(device)
model.train()
for epoch in range(args.epochs):
for batch_idx, (data, target) in enumerate(train_dataloader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_frequency == 0:
logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_dataloader.dataset),
100. * batch_idx / len(train_dataloader), loss.item()))
model.eval()
correct = 0
test_loss = 0.0
for data, target in valid_dataloader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(valid_dataloader.dataset)
accuracy = 100. * correct / len(valid_dataloader.dataset)
logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct,
len(valid_dataloader.dataset), accuracy))
model.train()
if __name__ == '__main__':
parser = argparse.ArgumentParser("nb201")
parser.add_argument('--stem_out_channels', default=16, type=int)
parser.add_argument('--unrolled', default=False, action='store_true')
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--num_modules_per_stack', default=5, type=int)
parser.add_argument('--log-frequency', default=10, type=int)
parser.add_argument('--bn_momentum', default=0.1, type=float)
parser.add_argument('--bn_affine', default=True, type=bool)
parser.add_argument('--bn_track_running_stats', default=True, type=bool)
parser.add_argument('--arch', default=None, help='json file which should meet requirements in NAS-Bench-201')
parser.add_argument('--visualization', default=False, action='store_true')
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
model = NASBench201Network(stem_out_channels=args.stem_out_channels,
num_modules_per_stack=args.num_modules_per_stack,
bn_affine=args.bn_affine,
bn_momentum=args.bn_momentum,
bn_track_running_stats=args.bn_track_running_stats)
optim = torch.optim.SGD(model.parameters(), 0.025)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)
criterion = nn.CrossEntropyLoss()
if args.arch is not None:
logger.info('model retraining...')
with open(args.arch, 'r') as f:
arch = json.load(f)
for trial in query_nb201_trial_stats(arch, 200, 'cifar100'):
pprint.pprint(trial)
apply_fixed_architecture(model, args.arch)
dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=0)
dataloader_valid = DataLoader(dataset_valid, batch_size=args.batch_size, shuffle=True, num_workers=0)
train(args, model, dataloader_train, dataloader_valid, criterion, optim,
torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
exit(0)
trainer = enas.EnasTrainer(model,
loss=criterion,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
reward_function=reward_accuracy,
optimizer=optim,
callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
batch_size=args.batch_size,
num_epochs=args.epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
log_frequency=args.log_frequency)
if args.visualization:
trainer.enable_visualization()
trainer.train()
from .nasbench201 import NASBench201Cell
from collections import OrderedDict
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice
from .nasbench201_ops import Pooling, ReLUConvBN, Zero, FactorizedReduce
class NASBench201Cell(nn.Module):
"""
Builtin cell structure of NAS-Bench-201. One cell contains four nodes. The first node serves as the input node,
accepting the output of the previous cell. Every other node connects to all previous nodes with an edge that
represents an operation chosen from a predefined set; the operation transforms the tensor from the source node to the target node.
Every node sums all of its inputs as its output.
Parameters
---
cell_id: str
the name of this cell
C_in: int
the number of input channels of the cell
C_out: int
the number of output channels of the cell
stride: int
stride of all convolution operations in the cell
bn_affine: bool
If set to ``True``, all ``torch.nn.BatchNorm2d`` in this cell will have learnable affine parameters. Default: True
bn_momentum: float
the value used for the running_mean and running_var computation. Default: 0.1
bn_track_running_stats: bool
When set to ``True``, all ``torch.nn.BatchNorm2d`` in this cell track the running mean and variance. Default: True
"""
def __init__(self, cell_id, C_in, C_out, stride, bn_affine=True, bn_momentum=0.1, bn_track_running_stats=True):
super(NASBench201Cell, self).__init__()
self.NUM_NODES = 4
self.layers = nn.ModuleList()
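# Candidate operators for every edge. The cell's stride is applied only on edges
# whose source is node 0 (layer_idx == 0); all other edges use stride 1.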
OPS = lambda layer_idx: OrderedDict([
("none", Zero(C_in, C_out, stride)),
("avg_pool_3x3", Pooling(C_in, C_out, stride if layer_idx == 0 else 1, bn_affine, bn_momentum,
bn_track_running_stats)),
("conv_3x3", ReLUConvBN(C_in, C_out, 3, stride if layer_idx == 0 else 1, 1, 1, bn_affine, bn_momentum,
bn_track_running_stats)),
("conv_1x1", ReLUConvBN(C_in, C_out, 1, stride if layer_idx == 0 else 1, 0, 1, bn_affine, bn_momentum,
bn_track_running_stats)),
("skip_connect", nn.Identity() if stride == 1 and C_in == C_out
else FactorizedReduce(C_in, C_out, stride if layer_idx == 0 else 1, bn_affine, bn_momentum,
bn_track_running_stats))
])
for i in range(self.NUM_NODES):
node_ops = nn.ModuleList()
for j in range(0, i):
node_ops.append(LayerChoice(OPS(j), key="%d_%d" % (j, i), reduction="mean"))
self.layers.append(node_ops)
self.in_dim = C_in
self.out_dim = C_out
self.cell_id = cell_id
def forward(self, input): # pylint: disable=W0622
"""
Parameters
---
input: torch.Tensor
the output of the previous layer
"""
nodes = [input]
for i in range(1, self.NUM_NODES):
node_feature = sum(self.layers[i][k](nodes[k]) for k in range(i))
nodes.append(node_feature)
return nodes[-1]
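# Hedged usage sketch (illustrative file name and shapes, not part of the module):
# fix the cell's LayerChoice decisions with an edge -> operator JSON, then run it.
if __name__ == "__main__":
    import torch
    from nni.nas.pytorch.fixed import apply_fixed_architecture
    cell = NASBench201Cell(0, C_in=16, C_out=16, stride=1)
    apply_fixed_architecture(cell, "architecture.json")  # assumed path
    print(cell(torch.randn(2, 16, 32, 32)).shape)  # expect torch.Size([2, 16, 32, 32])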
import torch
import torch.nn as nn
class ReLUConvBN(nn.Module):
"""
Parameters
---
C_in: int
the number of input channels
C_out: int
the number of output channels
stride: int
stride of the convolution
padding: int
zero-padding added to both sides of the input
dilation: int
spacing between kernel elements
bn_affine: bool
If set to ``True``, ``torch.nn.BatchNorm2d`` will have learnable affine parameters. Default: True
bn_momentum: float
the value used for the running_mean and running_var computation. Default: 0.1
bn_track_running_stats: bool
When set to ``True``, ``torch.nn.BatchNorm2d`` tracks the running mean and variance. Default: True
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation,
bn_affine=True, bn_momentum=0.1, bn_track_running_stats=True):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(C_out, affine=bn_affine, momentum=bn_momentum,
track_running_stats=bn_track_running_stats)
)
def forward(self, x):
"""
Parameters
---
x: torch.Tensor
input tensor
"""
return self.op(x)
class Pooling(nn.Module):
"""
Parameters
---
C_in: int
the number of input channels
C_out: int
the number of output channels
stride: int
stride of the pooling operation
bn_affine: bool
If set to ``True``, ``torch.nn.BatchNorm2d`` will have learnable affine parameters. Default: True
bn_momentum: float
the value used for the running_mean and running_var computation. Default: 0.1
bn_track_running_stats: bool
When set to ``True``, ``torch.nn.BatchNorm2d`` tracks the running mean and variance. Default: True
"""
def __init__(self, C_in, C_out, stride, bn_affine=True, bn_momentum=0.1, bn_track_running_stats=True):
super(Pooling, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 0,
bn_affine, bn_momentum, bn_track_running_stats)
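# average pooling only; normalization, if any, happens in the optional preprocess above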
self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
def forward(self, x):
"""
Parameters
---
x: torch.Tensor
input tensor
"""
if self.preprocess:
x = self.preprocess(x)
return self.op(x)
class Zero(nn.Module):
"""
Parameters
---
C_in: int
the number of input channels
C_out: int
the number of output channels
stride: int
stride used when subsampling the zeroed output
"""
def __init__(self, C_in, C_out, stride):
super(Zero, self).__init__()
self.C_in = C_in
self.C_out = C_out
self.stride = stride
self.is_zero = True
def forward(self, x):
"""
Parameters
---
x: torch.Tensor
input tensor
"""
if self.C_in == self.C_out:
if self.stride == 1:
return x.mul(0.)
else:
return x[:, :, ::self.stride, ::self.stride].mul(0.)
else:
shape = list(x.shape)
shape[1] = self.C_out
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
return zeros
class FactorizedReduce(nn.Module):
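"""
Reduce the spatial size (only ``stride=2`` is supported) while changing the number of
channels: two parallel stride-2 1x1 convolutions, the second applied to a one-pixel-shifted
copy of the input so that no spatial positions are dropped, are concatenated channel-wise
and batch-normalized.
"""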
def __init__(self, C_in, C_out, stride, bn_affine=True, bn_momentum=0.1,
bn_track_running_stats=True):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
self.C_out = C_out
self.relu = nn.ReLU(inplace=False)
if stride == 2:
C_outs = [C_out // 2, C_out - C_out // 2]
self.convs = nn.ModuleList()
for i in range(2):
self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False))
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
else:
raise ValueError("Invalid stride : {:}".format(stride))
self.bn = nn.BatchNorm2d(C_out, affine=bn_affine, momentum=bn_momentum,
track_running_stats=bn_track_running_stats)
def forward(self, x):
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out