OpenDAS / nni, commit 542a660d (unverified)

[Retiarii] NAS-Bench-201 (#3920)

Authored Jul 14, 2021 by Yuge Zhang, committed by GitHub on Jul 14, 2021
Parent: 7eedec46

Showing 4 changed files with 397 additions and 1 deletion:

  examples/nas/multi-trial/nasbench201/base_ops.py    +138  -0
  examples/nas/multi-trial/nasbench201/network.py     +162  -0
  nni/retiarii/nn/pytorch/component.py                  +76  -1
  test/ut/retiarii/test_highlevel_apis.py               +21  -0
examples/nas/multi-trial/nasbench201/base_ops.py  (new file, +138 -0)

import torch
import torch.nn as nn


OPS_WITH_STRIDE = {
    'none': lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
    'avg_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'avg'),
    'max_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'max'),
    'conv_3x3': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (3, 3), (stride, stride), (1, 1), (1, 1)),
    'conv_1x1': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (1, 1), (stride, stride), (0, 0), (1, 1)),
    'skip_connect': lambda C_in, C_out, stride: nn.Identity() if stride == 1 and C_in == C_out
        else FactorizedReduce(C_in, C_out, stride),
}

PRIMITIVES = ['none', 'skip_connect', 'conv_1x1', 'conv_3x3', 'avg_pool_3x3']


class ReLUConvBN(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super(ReLUConvBN, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding,
                      dilation=dilation, bias=False),
            nn.BatchNorm2d(C_out)
        )

    def forward(self, x):
        return self.op(x)


class SepConv(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding,
                      dilation=dilation, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out),
        )

    def forward(self, x):
        return self.op(x)


class Pooling(nn.Module):
    def __init__(self, C_in, C_out, stride, mode):
        super(Pooling, self).__init__()
        if C_in == C_out:
            self.preprocess = None
        else:
            self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1)
        if mode == 'avg':
            self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
        elif mode == 'max':
            self.op = nn.MaxPool2d(3, stride=stride, padding=1)
        else:
            raise ValueError('Invalid mode={:} in Pooling'.format(mode))

    def forward(self, x):
        if self.preprocess:
            x = self.preprocess(x)
        return self.op(x)


class Zero(nn.Module):
    def __init__(self, C_in, C_out, stride):
        super(Zero, self).__init__()
        self.C_in = C_in
        self.C_out = C_out
        self.stride = stride
        self.is_zero = True

    def forward(self, x):
        if self.C_in == self.C_out:
            if self.stride == 1:
                return x.mul(0.)
            else:
                return x[:, :, ::self.stride, ::self.stride].mul(0.)
        else:
            shape = list(x.shape)
            shape[1] = self.C_out
            zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
            return zeros


class FactorizedReduce(nn.Module):
    def __init__(self, C_in, C_out, stride):
        super(FactorizedReduce, self).__init__()
        self.stride = stride
        self.C_in = C_in
        self.C_out = C_out
        self.relu = nn.ReLU(inplace=False)
        if stride == 2:
            C_outs = [C_out // 2, C_out - C_out // 2]
            self.convs = nn.ModuleList()
            for i in range(2):
                self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False))
            self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
        else:
            raise ValueError('Invalid stride : {:}'.format(stride))
        self.bn = nn.BatchNorm2d(C_out)

    def forward(self, x):
        x = self.relu(x)
        y = self.pad(x)
        out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out


class ResNetBasicblock(nn.Module):
    def __init__(self, inplanes, planes, stride):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
        self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
        self.conv_b = ReLUConvBN(planes, planes, 3, 1, 1, 1)
        if stride == 2:
            self.downsample = nn.Sequential(
                nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
        elif inplanes != planes:
            self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
        else:
            self.downsample = None
        self.in_dim = inplanes
        self.out_dim = planes
        self.stride = stride
        self.num_conv = 2

    def forward(self, inputs):
        basicblock = self.conv_a(inputs)
        basicblock = self.conv_b(basicblock)
        if self.downsample is not None:
            inputs = self.downsample(inputs)
        # residual
        return inputs + basicblock
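For orientation (an illustration added here, not part of the commit): each entry of OPS_WITH_STRIDE is a factory taking (C_in, C_out, stride) and returning an nn.Module, and PRIMITIVES lists the five candidate names that make up the NAS-Bench-201 operation set. A minimal sketch of using the factory directly:

# Illustration only (not part of the commit): build one NAS-Bench-201 primitive
# from the factory above and run a dummy forward pass.
import torch

from base_ops import OPS_WITH_STRIDE, PRIMITIVES

op_name = 'conv_3x3'                        # any entry of PRIMITIVES
assert op_name in PRIMITIVES
op = OPS_WITH_STRIDE[op_name](16, 16, 1)    # C_in=16, C_out=16, stride=1

x = torch.randn(2, 16, 32, 32)              # [N, C, H, W]
y = op(x)
print(y.shape)                              # torch.Size([2, 16, 32, 32]); stride 1 keeps spatial size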
examples/nas/multi-trial/nasbench201/network.py  (new file, +162 -0)

import click
import nni
import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn as nn
import torchmetrics
from nni.retiarii import model_wrapper, serialize, serialize_cls
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.nn.pytorch import NasBench201Cell
from nni.retiarii.strategy import Random
from pytorch_lightning.callbacks import LearningRateMonitor
from timm.optim import RMSpropTF
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from torchvision.datasets import CIFAR100

from base_ops import ResNetBasicblock, PRIMITIVES, OPS_WITH_STRIDE


@model_wrapper
class NasBench201(nn.Module):
    def __init__(self,
                 stem_out_channels: int = 16,
                 num_modules_per_stack: int = 5,
                 num_labels: int = 100):
        super().__init__()
        self.channels = C = stem_out_channels
        self.num_modules = N = num_modules_per_stack
        self.num_labels = num_labels

        self.stem = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C)
        )

        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev = C
        self.cells = nn.ModuleList()
        for C_curr, reduction in zip(layer_channels, layer_reductions):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                # `prim=prim` binds the loop variable at definition time; without it,
                # every lambda would resolve to the last primitive when called later.
                cell = NasBench201Cell({prim: (lambda C_in, C_out, prim=prim: OPS_WITH_STRIDE[prim](C_in, C_out, 1))
                                        for prim in PRIMITIVES},
                                       C_prev, C_curr, label='cell')
            self.cells.append(cell)
            C_prev = C_curr

        self.lastact = nn.Sequential(nn.BatchNorm2d(C_prev), nn.ReLU(inplace=True))
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, self.num_labels)

    def forward(self, inputs):
        feature = self.stem(inputs)
        for cell in self.cells:
            feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return logits


class AccuracyWithLogits(torchmetrics.Accuracy):
    def update(self, pred, target):
        return super().update(nn.functional.softmax(pred), target)


@serialize_cls
class NasBench201TrainingModule(pl.LightningModule):
    def __init__(self, max_epochs=200, learning_rate=0.1, weight_decay=5e-4):
        super().__init__()
        self.save_hyperparameters('learning_rate', 'weight_decay', 'max_epochs')
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = AccuracyWithLogits()

    def forward(self, x):
        y_hat = self.model(x)
        return y_hat

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_accuracy', self.accuracy(y_hat, y), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
        self.log('val_accuracy', self.accuracy(y_hat, y), prog_bar=True)

    def configure_optimizers(self):
        optimizer = RMSpropTF(self.parameters(), lr=self.hparams.learning_rate,
                              weight_decay=self.hparams.weight_decay,
                              momentum=0.9, alpha=0.9, eps=1.0)
        return {
            'optimizer': optimizer,
            'scheduler': CosineAnnealingLR(optimizer, self.hparams.max_epochs)
        }

    def on_validation_epoch_end(self):
        nni.report_intermediate_result(self.trainer.callback_metrics['val_accuracy'].item())

    def teardown(self, stage):
        if stage == 'fit':
            nni.report_final_result(self.trainer.callback_metrics['val_accuracy'].item())


@click.command()
@click.option('--epochs', default=12, help='Training length.')
@click.option('--batch_size', default=256, help='Batch size.')
@click.option('--port', default=8081, help='On which port the experiment is run.')
def _multi_trial_test(epochs, batch_size, port):
    # initialize dataset. Note that 50k+10k is used. It's a little different from the paper.
    transf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize([x / 255 for x in [129.3, 124.1, 112.4]],
                             [x / 255 for x in [68.2, 65.4, 70.4]])
    ]
    train_dataset = serialize(CIFAR100, 'data', train=True, download=True,
                              transform=transforms.Compose(transf + normalize))
    test_dataset = serialize(CIFAR100, 'data', train=False, transform=transforms.Compose(normalize))

    # specify training hyper-parameters
    training_module = NasBench201TrainingModule(max_epochs=epochs)
    # FIXME: need to fix a bug in serializer for this to work
    # lr_monitor = serialize(LearningRateMonitor, logging_interval='step')
    trainer = pl.Trainer(max_epochs=epochs, gpus=1)
    lightning = pl.Lightning(
        lightning_module=training_module,
        trainer=trainer,
        train_dataloader=pl.DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
        val_dataloaders=pl.DataLoader(test_dataset, batch_size=batch_size),
    )

    strategy = Random()

    model = NasBench201()

    exp = RetiariiExperiment(model, lightning, [], strategy)

    exp_config = RetiariiExeConfig('local')
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 20
    exp_config.trial_gpu_number = 1
    exp_config.training_service.use_active_gpu = False

    exp.run(exp_config, port)


if __name__ == '__main__':
    _multi_trial_test()
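Given the click options above, this example is presumably launched as an ordinary script, e.g. `python network.py --epochs 12 --batch_size 256 --port 8081` (the defaults shown in the options), after which the Retiarii experiment runs locally and is served on the chosen port.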
nni/retiarii/nn/pytorch/component.py  (+76 -1)

import copy
from collections import OrderedDict
from typing import Callable, List, Union, Tuple, Optional

import torch
...

@@ -12,7 +13,7 @@ from .utils import generate_new_label, get_fixed_value
from ...utils import NoContextError

-__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator']
+__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell']


class Repeat(nn.Module):
...

@@ -147,3 +148,77 @@ class Cell(nn.Module):
            current_state = torch.sum(torch.stack(current_state), 0)
            states.append(current_state)
        return torch.cat(states[self.num_predecessors:], 1)


class NasBench201Cell(nn.Module):
    """
    Cell structure that is proposed in NAS-Bench-201 [nasbench201]_ .

    This cell is a densely connected DAG with ``num_tensors`` nodes, where each node is a tensor.
    For every i < j, there is an edge from the i-th node to the j-th node.
    Each edge in this DAG is associated with an operation transforming the hidden state from the source node
    to the target node. All possible operations are selected from a predefined operation set, defined in ``op_candidates``.
    Each of the ``op_candidates`` should be a callable that accepts an input dimension and an output dimension,
    and returns a ``Module``.

    Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`.

    The space size of this cell would be :math:`|op|^{N(N-1)/2}`, where :math:`|op|` is the number of operation candidates,
    and :math:`N` is defined by ``num_tensors``.

    Parameters
    ----------
    op_candidates : list of callable
        Operation candidates. Each should be a function that accepts an input dimension and an output dimension
        and returns an ``nn.Module``.
    in_features : int
        Input dimension of the cell.
    out_features : int
        Output dimension of the cell.
    num_tensors : int
        Number of tensors in the cell (input included). Default: 4
    label : str
        Identifier of the cell. Cells sharing the same label will semantically share the same choice.

    References
    ----------
    .. [nasbench201] Dong, X. and Yang, Y., 2020. NAS-Bench-201: Extending the scope of reproducible neural architecture search.
       arXiv preprint arXiv:2001.00326.
    """

    @staticmethod
    def _make_dict(x):
        if isinstance(x, list):
            return OrderedDict([(str(i), t) for i, t in enumerate(x)])
        return OrderedDict(x)

    def __init__(self, op_candidates: List[Callable[[int, int], nn.Module]],
                 in_features: int, out_features: int, num_tensors: int = 4,
                 label: Optional[str] = None):
        super().__init__()
        self._label = generate_new_label(label)

        self.layers = nn.ModuleList()
        self.in_features = in_features
        self.out_features = out_features
        self.num_tensors = num_tensors

        op_candidates = self._make_dict(op_candidates)

        for tid in range(1, num_tensors):
            node_ops = nn.ModuleList()
            for j in range(tid):
                inp = in_features if j == 0 else out_features
                op_choices = OrderedDict([(key, cls(inp, out_features))
                                          for key, cls in op_candidates.items()])
                node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))
            self.layers.append(node_ops)

    def forward(self, inputs):
        tensors = [inputs]
        for layer in self.layers:
            current_tensor = []
            for i, op in enumerate(layer):
                current_tensor.append(op(tensors[i]))
            current_tensor = torch.sum(torch.stack(current_tensor), 0)
            tensors.append(current_tensor)
        return tensors[-1]
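To make the docstring concrete (an illustration, not part of the diff): ``op_candidates`` may be a list or a dict of callables taking input and output dimensions, and with the default ``num_tensors=4`` the cell holds 4*3/2 = 6 LayerChoice edges, so two candidates give a space of 2^6 = 64 architectures. A rough sketch, assuming the cell is ultimately embedded in a ``@model_wrapper`` model as in network.py above:

# Illustration only: instantiating NasBench201Cell with two convolutional candidates.
# Channel sizes are arbitrary; each of the 6 edges becomes a LayerChoice over the candidates.
import torch.nn as nn

from nni.retiarii.nn.pytorch import NasBench201Cell

candidates = {
    'conv_3x3': lambda C_in, C_out: nn.Conv2d(C_in, C_out, 3, padding=1),
    'conv_1x1': lambda C_in, C_out: nn.Conv2d(C_in, C_out, 1),
}
cell = NasBench201Cell(candidates, in_features=16, out_features=16,
                       num_tensors=4, label='demo_cell')
# In practice the cell lives inside a @model_wrapper model (see network.py above),
# and a strategy such as Random() resolves each LayerChoice before the model is executed.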
test/ut/retiarii/test_highlevel_apis.py  (+21 -0)

...
@@ -493,6 +493,27 @@ class GraphIR(unittest.TestCase):
                model = mutator.bind_sampler(sampler).apply(model)
            self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))

    def test_nasbench201_cell(self):
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.cell = nn.NasBench201Cell([
                    lambda x, y: nn.Linear(x, y),
                    lambda x, y: nn.Linear(x, y, bias=False)
                ], 10, 16)

            def forward(self, x):
                return self.cell(x)

        raw_model, mutators = self._get_model_with_mutators(Net())

        for _ in range(10):
            sampler = EnumerateSampler()
            model = raw_model
            for mutator in mutators:
                model = mutator.bind_sampler(sampler).apply(model)
            self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 16]))


class Python(GraphIR):
    def _get_converted_pytorch_model(self, model_ir):
...