Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
542a660d
Unverified
Commit
542a660d
authored
Jul 14, 2021
by
Yuge Zhang
Committed by
GitHub
Jul 14, 2021
Browse files
[Retiarii] NAS-Bench-201 (#3920)
parent
7eedec46
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
397 additions
and
1 deletion
+397
-1
examples/nas/multi-trial/nasbench201/base_ops.py
examples/nas/multi-trial/nasbench201/base_ops.py
+138
-0
examples/nas/multi-trial/nasbench201/network.py
examples/nas/multi-trial/nasbench201/network.py
+162
-0
nni/retiarii/nn/pytorch/component.py
nni/retiarii/nn/pytorch/component.py
+76
-1
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+21
-0
No files found.
examples/nas/multi-trial/nasbench201/base_ops.py
0 → 100644
View file @
542a660d
import
torch
import
torch.nn
as
nn
# Factories for the candidate operations. Each one takes
# (C_in, C_out, stride) and returns a freshly constructed nn.Module.
# The referenced module classes are looked up at call time, so they may be
# defined further down in this file.


def _none_op(C_in, C_out, stride):
    # Edge that contributes an all-zero tensor.
    return Zero(C_in, C_out, stride)


def _avg_pool_3x3_op(C_in, C_out, stride):
    return Pooling(C_in, C_out, stride, 'avg')


def _max_pool_3x3_op(C_in, C_out, stride):
    return Pooling(C_in, C_out, stride, 'max')


def _conv_3x3_op(C_in, C_out, stride):
    # Arguments: kernel (3, 3), stride, padding (1, 1), dilation (1, 1).
    return ReLUConvBN(C_in, C_out, (3, 3), (stride, stride), (1, 1), (1, 1))


def _conv_1x1_op(C_in, C_out, stride):
    # Arguments: kernel (1, 1), stride, padding (0, 0), dilation (1, 1).
    return ReLUConvBN(C_in, C_out, (1, 1), (stride, stride), (0, 0), (1, 1))


def _skip_connect_op(C_in, C_out, stride):
    # Plain identity when neither resolution nor width changes.
    # NOTE(review): FactorizedReduce as written in this file rejects any
    # stride other than 2, so stride == 1 with C_in != C_out raises ValueError.
    if stride == 1 and C_in == C_out:
        return nn.Identity()
    return FactorizedReduce(C_in, C_out, stride)


# Operation name -> factory(C_in, C_out, stride).
OPS_WITH_STRIDE = {
    'none': _none_op,
    'avg_pool_3x3': _avg_pool_3x3_op,
    'max_pool_3x3': _max_pool_3x3_op,
    'conv_3x3': _conv_3x3_op,
    'conv_1x1': _conv_1x1_op,
    'skip_connect': _skip_connect_op,
}

# Names of the operations offered as candidates ('max_pool_3x3' is not listed).
PRIMITIVES = [
    'none',
    'skip_connect',
    'conv_1x1',
    'conv_3x3',
    'avg_pool_3x3',
]
class ReLUConvBN(nn.Module):
    """ReLU, then a bias-free 2-D convolution, then batch normalization.

    Parameters
    ----------
    C_in, C_out
        Input / output channel counts of the convolution.
    kernel_size, stride, padding, dilation
        Forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super().__init__()
        activation = nn.ReLU(inplace=False)
        convolution = nn.Conv2d(
            C_in, C_out, kernel_size,
            stride=stride, padding=padding, dilation=dilation, bias=False,
        )
        normalization = nn.BatchNorm2d(C_out)
        # Sub-modules keep the positional names op.0 / op.1 / op.2.
        self.op = nn.Sequential(activation, convolution, normalization)

    def forward(self, x):
        return self.op(x)
class SepConv(nn.Module):
    """ReLU, depthwise convolution, 1x1 pointwise convolution, batch norm.

    The depthwise convolution uses ``groups=C_in`` and keeps ``C_in``
    channels; the pointwise convolution maps ``C_in`` to ``C_out``.
    Neither convolution has a bias term.
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super().__init__()
        depthwise = nn.Conv2d(
            C_in, C_in,
            kernel_size=kernel_size, stride=stride, padding=padding,
            dilation=dilation, groups=C_in, bias=False,
        )
        pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False)
        stages = [nn.ReLU(inplace=False), depthwise, pointwise, nn.BatchNorm2d(C_out)]
        # Sub-modules keep the positional names op.0 .. op.3.
        self.op = nn.Sequential(*stages)

    def forward(self, x):
        return self.op(x)
class Pooling(nn.Module):
    """3x3 average or max pooling (padding 1) with optional channel matching.

    When ``C_in`` differs from ``C_out`` the input first goes through a
    1x1 ``ReLUConvBN``; otherwise no preprocessing module is created.

    Parameters
    ----------
    C_in, C_out
        Input / output channel counts.
    stride
        Stride of the pooling layer.
    mode
        ``'avg'`` (padding excluded from the average) or ``'max'``.

    Raises
    ------
    ValueError
        If ``mode`` is neither ``'avg'`` nor ``'max'``.
    """

    def __init__(self, C_in, C_out, stride, mode):
        super().__init__()
        # The 1x1 projection is only needed when the channel count changes.
        self.preprocess = None if C_in == C_out else ReLUConvBN(C_in, C_out, 1, 1, 0, 1)

        if mode == 'avg':
            pool = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
        elif mode == 'max':
            pool = nn.MaxPool2d(3, stride=stride, padding=1)
        else:
            raise ValueError('Invalid mode={:} in Pooling'.format(mode))
        self.op = pool

    def forward(self, x):
        if self.preprocess:
            x = self.preprocess(x)
        return self.op(x)
class Zero(nn.Module):
    """Operation whose output is always zero.

    With equal channel counts the input is (optionally strided in the two
    trailing dimensions and) multiplied by zero. With different channel
    counts a new zero tensor is allocated whose dimension 1 is ``C_out``;
    in that branch ``stride`` is not applied.
    """

    def __init__(self, C_in, C_out, stride):
        super().__init__()
        self.C_in, self.C_out, self.stride = C_in, C_out, stride
        # Marker attribute; always True for this operation.
        self.is_zero = True

    def forward(self, x):
        if self.C_in == self.C_out:
            if self.stride == 1:
                return x.mul(0.)
            else:
                # Subsample dimensions 2 and 3 before zeroing.
                return x[:, :, ::self.stride, ::self.stride].mul(0.)
        else:
            out_shape = list(x.shape)
            out_shape[1] = self.C_out
            result = x.new_zeros(out_shape, dtype=x.dtype, device=x.device)
            return result
class FactorizedReduce(nn.Module):
    """Stride-2 reduction built from two parallel 1x1 convolutions.

    After a ReLU, one convolution reads the input as is and the other
    reads a copy shifted by one position in dimensions 2 and 3 (obtained by
    zero-padding on the far side and dropping the first row/column). Their
    outputs (``C_out // 2`` and ``C_out - C_out // 2`` channels) are
    concatenated along dimension 1 and batch-normalized.

    Raises
    ------
    ValueError
        If ``stride`` is not 2.
    """

    def __init__(self, C_in, C_out, stride):
        super().__init__()
        self.stride = stride
        self.C_in = C_in
        self.C_out = C_out
        self.relu = nn.ReLU(inplace=False)

        if stride != 2:
            raise ValueError('Invalid stride : {:}'.format(stride))

        # Split the output channels between the two branches; the second
        # branch takes the extra channel when C_out is odd.
        first_width = C_out // 2
        branch_widths = [first_width, C_out - first_width]
        self.convs = nn.ModuleList([
            nn.Conv2d(C_in, width, 1, stride=stride, padding=0, bias=False)
            for width in branch_widths
        ])
        self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
        self.bn = nn.BatchNorm2d(C_out)

    def forward(self, x):
        x = self.relu(x)
        y = self.pad(x)
        out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out
class ResNetBasicblock(nn.Module):
    """Residual block made of two 3x3 ``ReLUConvBN`` layers.

    The shortcut branch is chosen as follows:

    * ``stride == 2``: 2x2 average pooling followed by a bias-free 1x1 conv;
    * ``stride == 1`` and ``inplanes != planes``: a 1x1 ``ReLUConvBN``;
    * otherwise: the input is added unchanged.

    Parameters
    ----------
    inplanes, planes
        Input / output channel counts.
    stride
        1 or 2; applied by the first convolution.
    """

    def __init__(self, inplanes, planes, stride):
        super().__init__()
        assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
        self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
        self.conv_b = ReLUConvBN(planes, planes, 3, 1, 1, 1)

        shortcut = None
        if stride == 2:
            shortcut = nn.Sequential(
                nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False),
            )
        elif inplanes != planes:
            shortcut = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
        self.downsample = shortcut

        self.in_dim = inplanes
        self.out_dim = planes
        self.stride = stride
        self.num_conv = 2

    def forward(self, inputs):
        basicblock = self.conv_a(inputs)
        basicblock = self.conv_b(basicblock)
        if self.downsample is not None:
            inputs = self.downsample(inputs)  # residual
        return inputs + basicblock
examples/nas/multi-trial/nasbench201/network.py
0 → 100644
View file @
542a660d
import
click
import
nni
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
import
torch.nn
as
nn
import
torchmetrics
from
nni.retiarii
import
model_wrapper
,
serialize
,
serialize_cls
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.nn.pytorch
import
NasBench201Cell
from
nni.retiarii.strategy
import
Random
from
pytorch_lightning.callbacks
import
LearningRateMonitor
from
timm.optim
import
RMSpropTF
from
torch.optim.lr_scheduler
import
CosineAnnealingLR
from
torchvision
import
transforms
from
torchvision.datasets
import
CIFAR100
from
base_ops
import
ResNetBasicblock
,
PRIMITIVES
,
OPS_WITH_STRIDE
@model_wrapper
class NasBench201(nn.Module):
    """Network skeleton with three stacks of searchable cells.

    Layout: a 3x3 conv + batch-norm stem, then ``N`` cells at ``C``
    channels, a stride-2 ``ResNetBasicblock`` to ``2C``, ``N`` cells at
    ``2C``, a stride-2 ``ResNetBasicblock`` to ``4C``, ``N`` cells at ``4C``,
    followed by batch norm, ReLU, global average pooling and a linear
    classifier.

    Every searchable cell is a ``NasBench201Cell`` created with the label
    ``'cell'`` and one candidate per entry of ``PRIMITIVES`` (built at
    stride 1).

    Parameters
    ----------
    stem_out_channels : int
        Channel count ``C`` produced by the stem.
    num_modules_per_stack : int
        Number ``N`` of searchable cells in each of the three stacks.
    num_labels : int
        Output size of the classifier.
    """

    def __init__(self,
                 stem_out_channels: int = 16,
                 num_modules_per_stack: int = 5,
                 num_labels: int = 100):
        super().__init__()
        self.channels = C = stem_out_channels
        self.num_modules = N = num_modules_per_stack
        self.num_labels = num_labels

        self.stem = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C)
        )

        # Per-layer output width and whether the layer is a reduction block.
        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev = C
        self.cells = nn.ModuleList()
        for C_curr, reduction in zip(layer_channels, layer_reductions):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                cell = NasBench201Cell({
                    # ``prim=prim`` freezes the current primitive name in each
                    # factory. Without it all lambdas would share the
                    # comprehension variable and, being called only later
                    # inside NasBench201Cell, would all build the last
                    # primitive in PRIMITIVES.
                    prim: lambda C_in, C_out, prim=prim: OPS_WITH_STRIDE[prim](C_in, C_out, 1)
                    for prim in PRIMITIVES
                }, C_prev, C_curr, label='cell')
            self.cells.append(cell)
            C_prev = C_curr

        self.lastact = nn.Sequential(
            nn.BatchNorm2d(C_prev),
            nn.ReLU(inplace=True)
        )
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, self.num_labels)

    def forward(self, inputs):
        feature = self.stem(inputs)
        for cell in self.cells:
            feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        # Flatten everything but the batch dimension for the classifier.
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)

        return logits
class AccuracyWithLogits(torchmetrics.Accuracy):
    """Accuracy metric that applies a softmax to the predictions first."""

    def update(self, pred, target):
        # Normalize over dimension 1 explicitly. Calling softmax without
        # ``dim`` relies on a deprecated implicit choice (and warns); for
        # 2-D ``pred`` that implicit choice is also dimension 1.
        # NOTE(review): assumes ``pred`` is laid out as (batch, classes, ...).
        return super().update(nn.functional.softmax(pred, dim=1), target)
@serialize_cls
class NasBench201TrainingModule(pl.LightningModule):
    """Lightning module that trains a model with cross-entropy loss.

    Logs loss and accuracy for training and validation, and reports the
    logged ``val_accuracy`` to NNI after every validation epoch and once
    more when fitting finishes.

    Parameters
    ----------
    max_epochs
        Period (``T_max``) handed to the cosine learning-rate schedule.
    learning_rate
        Learning rate passed to the optimizer.
    weight_decay
        Weight decay passed to the optimizer.
    """

    def __init__(self, max_epochs=200, learning_rate=0.1, weight_decay=5e-4):
        super().__init__()
        self.save_hyperparameters('learning_rate', 'weight_decay', 'max_epochs')
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = AccuracyWithLogits()

    def forward(self, x):
        # ``self.model`` is not assigned in this class; presumably the NNI
        # Lightning base class / evaluator attaches it -- TODO confirm.
        y_hat = self.model(x)
        return y_hat

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_accuracy', self.accuracy(y_hat, y), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
        self.log('val_accuracy', self.accuracy(y_hat, y), prog_bar=True)

    def configure_optimizers(self):
        optimizer = RMSpropTF(self.parameters(), lr=self.hparams.learning_rate,
                              weight_decay=self.hparams.weight_decay,
                              momentum=0.9, alpha=0.9, eps=1.0)
        return {
            'optimizer': optimizer,
            # Lightning reads the scheduler from the 'lr_scheduler' key; under
            # any other key the schedule is not picked up.
            'lr_scheduler': CosineAnnealingLR(optimizer, self.hparams.max_epochs)
        }

    def on_validation_epoch_end(self):
        nni.report_intermediate_result(self.trainer.callback_metrics['val_accuracy'].item())

    def teardown(self, stage):
        # Only the end of fitting produces the final result.
        if stage == 'fit':
            nni.report_final_result(self.trainer.callback_metrics['val_accuracy'].item())
@click.command()
@click.option('--epochs', default=12, help='Training length.')
@click.option('--batch_size', default=256, help='Batch size.')
@click.option('--port', default=8081, help='On which port the experiment is run.')
def _multi_trial_test(epochs, batch_size, port):
    # Initialize datasets. Note that 50k+10k is used. It's a little different from paper.
    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    # Per-channel statistics are given on the 0-255 scale and rescaled to 0-1.
    channel_mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    channel_std = [x / 255 for x in [68.2, 65.4, 70.4]]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    train_transform = transforms.Compose(augmentation + normalize)
    test_transform = transforms.Compose(normalize)
    train_dataset = serialize(CIFAR100, 'data', train=True, download=True, transform=train_transform)
    test_dataset = serialize(CIFAR100, 'data', train=False, transform=test_transform)

    # Training hyper-parameters.
    training_module = NasBench201TrainingModule(max_epochs=epochs)

    # FIXME: need to fix a bug in serializer for this to work
    # lr_monitor = serialize(LearningRateMonitor, logging_interval='step')
    trainer = pl.Trainer(max_epochs=epochs, gpus=1)

    train_loader = pl.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = pl.DataLoader(test_dataset, batch_size=batch_size)
    lightning = pl.Lightning(
        lightning_module=training_module,
        trainer=trainer,
        train_dataloader=train_loader,
        val_dataloaders=val_loader,
    )

    strategy = Random()
    model = NasBench201()
    exp = RetiariiExperiment(model, lightning, [], strategy)

    # Local experiment: 2 concurrent trials, 20 trials in total, 1 GPU each.
    exp_config = RetiariiExeConfig('local')
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 20
    exp_config.trial_gpu_number = 1
    exp_config.training_service.use_active_gpu = False

    exp.run(exp_config, port)


if __name__ == '__main__':
    _multi_trial_test()
nni/retiarii/nn/pytorch/component.py
View file @
542a660d
import
copy
import
copy
from
collections
import
OrderedDict
from
typing
import
Callable
,
List
,
Union
,
Tuple
,
Optional
from
typing
import
Callable
,
List
,
Union
,
Tuple
,
Optional
import
torch
import
torch
...
@@ -12,7 +13,7 @@ from .utils import generate_new_label, get_fixed_value
...
@@ -12,7 +13,7 @@ from .utils import generate_new_label, get_fixed_value
from
...utils
import
NoContextError
from
...utils
import
NoContextError
__all__
=
[
'Repeat'
,
'Cell'
,
'NasBench101Cell'
,
'NasBench101Mutator'
]
__all__
=
[
'Repeat'
,
'Cell'
,
'NasBench101Cell'
,
'NasBench101Mutator'
,
'NasBench201Cell'
]
class
Repeat
(
nn
.
Module
):
class
Repeat
(
nn
.
Module
):
...
@@ -147,3 +148,77 @@ class Cell(nn.Module):
...
@@ -147,3 +148,77 @@ class Cell(nn.Module):
current_state
=
torch
.
sum
(
torch
.
stack
(
current_state
),
0
)
current_state
=
torch
.
sum
(
torch
.
stack
(
current_state
),
0
)
states
.
append
(
current_state
)
states
.
append
(
current_state
)
return
torch
.
cat
(
states
[
self
.
num_predecessors
:],
1
)
return
torch
.
cat
(
states
[
self
.
num_predecessors
:],
1
)
class NasBench201Cell(nn.Module):
    """
    Cell structure that is proposed in NAS-Bench-201 [nasbench201]_ .

    The cell is a densely connected DAG over ``num_tensors`` nodes, each node being a tensor.
    For every pair i < j there is an edge from the i-th node to the j-th node, and every edge carries
    one operation chosen (through a ``LayerChoice``) from ``op_candidates``.
    A node's value is the sum of the outputs of all its incoming edges; the cell returns the last node.

    Each candidate is a callable that receives the input dimension and the output dimension of its edge
    and returns a ``Module``. Edges leaving node 0 (the cell input) are built with ``in_features`` as
    input dimension; all other edges use ``out_features``.

    Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`.

    The space size of this cell would be :math:`|op|^{N(N-1)/2}`, where :math:`|op|` is the number of operation candidates,
    and :math:`N` is defined by ``num_tensors``.

    Parameters
    ----------
    op_candidates : list of callable, or dict-like of str to callable
        Operation candidates. Each should be a function that accepts input feature and output feature,
        returning nn.Module. A list is keyed by the position of each entry (as a string);
        anything else is passed to ``OrderedDict`` as is.
    in_features : int
        Input dimension of cell.
    out_features : int
        Output dimension of cell.
    num_tensors : int
        Number of tensors in the cell (input included). Default: 4
    label : str
        Identifier of the cell. Cell sharing the same label will semantically share the same choice.

    References
    ----------
    .. [nasbench201] Dong, X. and Yang, Y., 2020. Nas-bench-201: Extending the scope of reproducible neural architecture search.
        arXiv preprint arXiv:2001.00326.
    """

    @staticmethod
    def _make_dict(x):
        # Lists are keyed by stringified position; everything else must
        # already be acceptable to OrderedDict.
        if not isinstance(x, list):
            return OrderedDict(x)
        return OrderedDict((str(position), candidate) for position, candidate in enumerate(x))

    def __init__(self, op_candidates: List[Callable[[int, int], nn.Module]],
                 in_features: int, out_features: int, num_tensors: int = 4,
                 label: Optional[str] = None):
        super().__init__()
        self._label = generate_new_label(label)

        self.layers = nn.ModuleList()
        self.in_features = in_features
        self.out_features = out_features
        self.num_tensors = num_tensors

        op_candidates = self._make_dict(op_candidates)

        # One ModuleList per non-input node, holding one LayerChoice per
        # incoming edge (from node j to node tid).
        for tid in range(1, num_tensors):
            incoming_edges = nn.ModuleList()
            for j in range(tid):
                # Only edges leaving the cell input see ``in_features``.
                edge_in_features = in_features if j == 0 else out_features
                instantiated = OrderedDict()
                for name, factory in op_candidates.items():
                    instantiated[name] = factory(edge_in_features, out_features)
                incoming_edges.append(LayerChoice(instantiated, label=f'{self._label}__{j}_{tid}'))
            self.layers.append(incoming_edges)

    def forward(self, inputs):
        tensors = [inputs]
        for layer in self.layers:
            edge_outputs = []
            for i, op in enumerate(layer):
                edge_outputs.append(op(tensors[i]))
            # A node is the element-wise sum of its incoming edges.
            node_value = torch.sum(torch.stack(edge_outputs), 0)
            tensors.append(node_value)
        return tensors[-1]
test/ut/retiarii/test_highlevel_apis.py
View file @
542a660d
...
@@ -493,6 +493,27 @@ class GraphIR(unittest.TestCase):
...
@@ -493,6 +493,27 @@ class GraphIR(unittest.TestCase):
model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model
)
model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model
)
self
.
assertTrue
(
self
.
_get_converted_pytorch_model
(
model
)(
torch
.
randn
(
1
,
16
)).
size
()
==
torch
.
Size
([
1
,
64
]))
self
.
assertTrue
(
self
.
_get_converted_pytorch_model
(
model
)(
torch
.
randn
(
1
,
16
)).
size
()
==
torch
.
Size
([
1
,
64
]))
def test_nasbench201_cell(self):
    # A cell with two Linear candidates (with and without bias) mapping
    # 10 input features to 16 output features.
    @self.get_serializer()
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            candidates = [
                lambda x, y: nn.Linear(x, y),
                lambda x, y: nn.Linear(x, y, bias=False)
            ]
            self.cell = nn.NasBench201Cell(candidates, 10, 16)

        def forward(self, x):
            return self.cell(x)

    raw_model, mutators = self._get_model_with_mutators(Net())
    for _ in range(10):
        # Apply every mutator with one shared sampler, starting from the
        # unmutated model each round.
        sampler = EnumerateSampler()
        model = raw_model
        for mutator in mutators:
            model = mutator.bind_sampler(sampler).apply(model)
        converted = self._get_converted_pytorch_model(model)
        output_size = converted(torch.randn(2, 10)).size()
        self.assertTrue(output_size == torch.Size([2, 16]))
class
Python
(
GraphIR
):
class
Python
(
GraphIR
):
def
_get_converted_pytorch_model
(
self
,
model_ir
):
def
_get_converted_pytorch_model
(
self
,
model_ir
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment