Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
9f32a06f
Unverified
Commit
9f32a06f
authored
Jul 12, 2021
by
Yuge Zhang
Committed by
GitHub
Jul 12, 2021
Browse files
[Retiarii] NAS-Bench-101 (#3871)
parent
dde4d862
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
738 additions
and
29 deletions
+738
-29
examples/nas/multi-trial/nasbench101/base_ops.py
examples/nas/multi-trial/nasbench101/base_ops.py
+51
-0
examples/nas/multi-trial/nasbench101/network.py
examples/nas/multi-trial/nasbench101/network.py
+173
-0
nni/retiarii/mutator.py
nni/retiarii/mutator.py
+7
-3
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+33
-12
nni/retiarii/nn/pytorch/component.py
nni/retiarii/nn/pytorch/component.py
+4
-2
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+27
-6
nni/retiarii/nn/pytorch/nasbench101.py
nni/retiarii/nn/pytorch/nasbench101.py
+390
-0
nni/retiarii/nn/pytorch/utils.py
nni/retiarii/nn/pytorch/utils.py
+14
-2
nni/retiarii/strategy/bruteforce.py
nni/retiarii/strategy/bruteforce.py
+5
-2
nni/retiarii/utils.py
nni/retiarii/utils.py
+6
-1
test/.gitignore
test/.gitignore
+1
-0
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+27
-1
No files found.
examples/nas/multi-trial/nasbench101/base_ops.py
0 → 100644
View file @
9f32a06f
import
math
import
torch.nn
as
nn
def
truncated_normal_
(
tensor
,
mean
=
0
,
std
=
1
):
# https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
size
=
tensor
.
shape
tmp
=
tensor
.
new_empty
(
size
+
(
4
,)).
normal_
()
valid
=
(
tmp
<
2
)
&
(
tmp
>
-
2
)
ind
=
valid
.
max
(
-
1
,
keepdim
=
True
)[
1
]
tensor
.
data
.
copy_
(
tmp
.
gather
(
-
1
,
ind
).
squeeze
(
-
1
))
tensor
.
data
.
mul_
(
std
).
add_
(
mean
)
class
ConvBnRelu
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
):
super
(
ConvBnRelu
,
self
).
__init__
()
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
conv_bn_relu
=
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
,
stride
,
padding
,
bias
=
False
),
nn
.
BatchNorm2d
(
out_channels
),
nn
.
ReLU
(
inplace
=
True
)
)
self
.
reset_parameters
()
def
reset_parameters
(
self
):
for
m
in
self
.
modules
():
if
isinstance
(
m
,
nn
.
Conv2d
):
fan_in
=
m
.
kernel_size
[
0
]
*
m
.
kernel_size
[
1
]
*
m
.
in_channels
truncated_normal_
(
m
.
weight
.
data
,
mean
=
0.
,
std
=
math
.
sqrt
(
1.
/
fan_in
))
if
isinstance
(
m
,
nn
.
BatchNorm2d
):
m
.
weight
.
data
.
fill_
(
1
)
m
.
bias
.
data
.
zero_
()
def
forward
(
self
,
x
):
return
self
.
conv_bn_relu
(
x
)
class
Conv3x3BnRelu
(
ConvBnRelu
):
def
__init__
(
self
,
in_channels
,
out_channels
):
super
(
Conv3x3BnRelu
,
self
).
__init__
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
class
Conv1x1BnRelu
(
ConvBnRelu
):
def
__init__
(
self
,
in_channels
,
out_channels
):
super
(
Conv1x1BnRelu
,
self
).
__init__
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
Projection
=
Conv1x1BnRelu
examples/nas/multi-trial/nasbench101/network.py
0 → 100644
View file @
9f32a06f
import
click
import
nni
import
nni.retiarii.evaluator.pytorch.lightning
as
pl
import
torch.nn
as
nn
import
torchmetrics
from
nni.retiarii
import
model_wrapper
,
serialize
,
serialize_cls
from
nni.retiarii.experiment.pytorch
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.nn.pytorch
import
NasBench101Cell
from
nni.retiarii.strategy
import
Random
from
pytorch_lightning.callbacks
import
LearningRateMonitor
from
timm.optim
import
RMSpropTF
from
torch.optim.lr_scheduler
import
CosineAnnealingLR
from
torchvision
import
transforms
from
torchvision.datasets
import
CIFAR10
from
base_ops
import
Conv3x3BnRelu
,
Conv1x1BnRelu
,
Projection
@
model_wrapper
class
NasBench101
(
nn
.
Module
):
def
__init__
(
self
,
stem_out_channels
:
int
=
128
,
num_stacks
:
int
=
3
,
num_modules_per_stack
:
int
=
3
,
max_num_vertices
:
int
=
7
,
max_num_edges
:
int
=
9
,
num_labels
:
int
=
10
,
bn_eps
:
float
=
1e-5
,
bn_momentum
:
float
=
0.003
):
super
().
__init__
()
op_candidates
=
{
'conv3x3'
:
lambda
num_features
:
Conv3x3BnRelu
(
num_features
,
num_features
),
'conv1x1'
:
lambda
num_features
:
Conv1x1BnRelu
(
num_features
,
num_features
),
'maxpool'
:
lambda
num_features
:
nn
.
MaxPool2d
(
3
,
1
,
1
)
}
# initial stem convolution
self
.
stem_conv
=
Conv3x3BnRelu
(
3
,
stem_out_channels
)
layers
=
[]
in_channels
=
out_channels
=
stem_out_channels
for
stack_num
in
range
(
num_stacks
):
if
stack_num
>
0
:
downsample
=
nn
.
MaxPool2d
(
kernel_size
=
2
,
stride
=
2
)
layers
.
append
(
downsample
)
out_channels
*=
2
for
_
in
range
(
num_modules_per_stack
):
cell
=
NasBench101Cell
(
op_candidates
,
in_channels
,
out_channels
,
lambda
cin
,
cout
:
Projection
(
cin
,
cout
),
max_num_vertices
,
max_num_edges
,
label
=
'cell'
)
layers
.
append
(
cell
)
in_channels
=
out_channels
self
.
features
=
nn
.
ModuleList
(
layers
)
self
.
gap
=
nn
.
AdaptiveAvgPool2d
(
1
)
self
.
classifier
=
nn
.
Linear
(
out_channels
,
num_labels
)
for
module
in
self
.
modules
():
if
isinstance
(
module
,
nn
.
BatchNorm2d
):
module
.
eps
=
bn_eps
module
.
momentum
=
bn_momentum
def
forward
(
self
,
x
):
bs
=
x
.
size
(
0
)
out
=
self
.
stem_conv
(
x
)
for
layer
in
self
.
features
:
out
=
layer
(
out
)
out
=
self
.
gap
(
out
).
view
(
bs
,
-
1
)
out
=
self
.
classifier
(
out
)
return
out
def
reset_parameters
(
self
):
for
module
in
self
.
modules
():
if
isinstance
(
module
,
nn
.
BatchNorm2d
):
module
.
eps
=
self
.
config
.
bn_eps
module
.
momentum
=
self
.
config
.
bn_momentum
class
AccuracyWithLogits
(
torchmetrics
.
Accuracy
):
def
update
(
self
,
pred
,
target
):
return
super
().
update
(
nn
.
functional
.
softmax
(
pred
),
target
)
@
serialize_cls
class
NasBench101TrainingModule
(
pl
.
LightningModule
):
def
__init__
(
self
,
max_epochs
=
108
,
learning_rate
=
0.1
,
weight_decay
=
1e-4
):
super
().
__init__
()
self
.
save_hyperparameters
(
'learning_rate'
,
'weight_decay'
,
'max_epochs'
)
self
.
criterion
=
nn
.
CrossEntropyLoss
()
self
.
accuracy
=
AccuracyWithLogits
()
def
forward
(
self
,
x
):
y_hat
=
self
.
model
(
x
)
return
y_hat
def
training_step
(
self
,
batch
,
batch_idx
):
x
,
y
=
batch
y_hat
=
self
(
x
)
loss
=
self
.
criterion
(
y_hat
,
y
)
self
.
log
(
'train_loss'
,
loss
,
prog_bar
=
True
)
self
.
log
(
'train_accuracy'
,
self
.
accuracy
(
y_hat
,
y
),
prog_bar
=
True
)
return
loss
def
validation_step
(
self
,
batch
,
batch_idx
):
x
,
y
=
batch
y_hat
=
self
(
x
)
self
.
log
(
'val_loss'
,
self
.
criterion
(
y_hat
,
y
),
prog_bar
=
True
)
self
.
log
(
'val_accuracy'
,
self
.
accuracy
(
y_hat
,
y
),
prog_bar
=
True
)
def
configure_optimizers
(
self
):
optimizer
=
RMSpropTF
(
self
.
parameters
(),
lr
=
self
.
hparams
.
learning_rate
,
weight_decay
=
self
.
hparams
.
weight_decay
,
momentum
=
0.9
,
alpha
=
0.9
,
eps
=
1.0
)
return
{
'optimizer'
:
optimizer
,
'scheduler'
:
CosineAnnealingLR
(
optimizer
,
self
.
hparams
.
max_epochs
)
}
def
on_validation_epoch_end
(
self
):
nni
.
report_intermediate_result
(
self
.
trainer
.
callback_metrics
[
'val_accuracy'
].
item
())
def
teardown
(
self
,
stage
):
if
stage
==
'fit'
:
nni
.
report_final_result
(
self
.
trainer
.
callback_metrics
[
'val_accuracy'
].
item
())
@
click
.
command
()
@
click
.
option
(
'--epochs'
,
default
=
108
,
help
=
'Training length.'
)
@
click
.
option
(
'--batch_size'
,
default
=
256
,
help
=
'Batch size.'
)
@
click
.
option
(
'--port'
,
default
=
8081
,
help
=
'On which port the experiment is run.'
)
def
_multi_trial_test
(
epochs
,
batch_size
,
port
):
# initalize dataset. Note that 50k+10k is used. It's a little different from paper
transf
=
[
transforms
.
RandomCrop
(
32
,
padding
=
4
),
transforms
.
RandomHorizontalFlip
()
]
normalize
=
[
transforms
.
ToTensor
(),
transforms
.
Normalize
([
0.49139968
,
0.48215827
,
0.44653124
],
[
0.24703233
,
0.24348505
,
0.26158768
])
]
train_dataset
=
serialize
(
CIFAR10
,
'data'
,
train
=
True
,
download
=
True
,
transform
=
transforms
.
Compose
(
transf
+
normalize
))
test_dataset
=
serialize
(
CIFAR10
,
'data'
,
train
=
False
,
transform
=
transforms
.
Compose
(
normalize
))
# specify training hyper-parameters
training_module
=
NasBench101TrainingModule
(
max_epochs
=
epochs
)
# FIXME: need to fix a bug in serializer for this to work
# lr_monitor = serialize(LearningRateMonitor, logging_interval='step')
trainer
=
pl
.
Trainer
(
max_epochs
=
epochs
,
gpus
=
1
)
lightning
=
pl
.
Lightning
(
lightning_module
=
training_module
,
trainer
=
trainer
,
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
batch_size
,
shuffle
=
True
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
batch_size
),
)
strategy
=
Random
()
model
=
NasBench101
()
exp
=
RetiariiExperiment
(
model
,
lightning
,
[],
strategy
)
exp_config
=
RetiariiExeConfig
(
'local'
)
exp_config
.
trial_concurrency
=
2
exp_config
.
max_trial_number
=
20
exp_config
.
trial_gpu_number
=
1
exp_config
.
training_service
.
use_active_gpu
=
False
exp
.
run
(
exp_config
,
port
)
if
__name__
==
'__main__'
:
_multi_trial_test
()
nni/retiarii/mutator.py
View file @
9f32a06f
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
typing
import
(
Any
,
Iterable
,
List
,
Optional
)
from
typing
import
(
Any
,
Iterable
,
List
,
Optional
,
Tuple
)
from
.graph
import
Model
,
Mutation
,
ModelStatus
__all__
=
[
'Sampler'
,
'Mutator'
]
__all__
=
[
'Sampler'
,
'Mutator'
,
'InvalidMutation'
]
Choice
=
Any
...
...
@@ -77,7 +77,7 @@ class Mutator:
self
.
_cur_choice_idx
=
None
return
copy
def
dry_run
(
self
,
model
:
Model
)
->
List
[
List
[
Choice
]]:
def
dry_run
(
self
,
model
:
Model
)
->
Tuple
[
List
[
List
[
Choice
]]
,
Model
]
:
"""
Dry run mutator on a model to collect choice candidates.
If you invoke this method multiple times on same or different models,
...
...
@@ -115,3 +115,7 @@ class _RecorderSampler(Sampler):
def
choice
(
self
,
candidates
:
List
[
Choice
],
*
args
)
->
Choice
:
self
.
recorded_candidates
.
append
(
candidates
)
return
candidates
[
0
]
class
InvalidMutation
(
Exception
):
pass
nni/retiarii/nn/pytorch/api.py
View file @
9f32a06f
...
...
@@ -3,13 +3,13 @@
import
copy
import
warnings
from
collections
import
OrderedDict
from
typing
import
Any
,
List
,
Union
,
Dict
,
Optional
import
torch
import
torch.nn
as
nn
from
...serializer
import
Translatable
,
basic_unit
from
...utils
import
NoContextError
from
.utils
import
generate_new_label
,
get_fixed_value
...
...
@@ -26,6 +26,8 @@ class LayerChoice(nn.Module):
----------
candidates : list of nn.Module or OrderedDict
A module list to be selected from.
prior : list of float
Prior distribution used in random sampling.
label : str
Identifier of the layer choice.
...
...
@@ -55,17 +57,21 @@ class LayerChoice(nn.Module):
``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
"""
def
__new__
(
cls
,
candidates
:
Union
[
Dict
[
str
,
nn
.
Module
],
List
[
nn
.
Module
]],
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
# FIXME: prior is designed but not supported yet
def
__new__
(
cls
,
candidates
:
Union
[
Dict
[
str
,
nn
.
Module
],
List
[
nn
.
Module
]],
*
,
prior
:
Optional
[
List
[
float
]]
=
None
,
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
try
:
chosen
=
get_fixed_value
(
label
)
if
isinstance
(
candidates
,
list
):
return
candidates
[
int
(
chosen
)]
else
:
return
candidates
[
chosen
]
except
Assertion
Error
:
except
NoContext
Error
:
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
candidates
:
Union
[
Dict
[
str
,
nn
.
Module
],
List
[
nn
.
Module
]],
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
def
__init__
(
self
,
candidates
:
Union
[
Dict
[
str
,
nn
.
Module
],
List
[
nn
.
Module
]],
*
,
prior
:
Optional
[
List
[
float
]]
=
None
,
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
super
(
LayerChoice
,
self
).
__init__
()
if
'key'
in
kwargs
:
warnings
.
warn
(
f
'"key" is deprecated. Assuming label.'
)
...
...
@@ -75,10 +81,12 @@ class LayerChoice(nn.Module):
if
'reduction'
in
kwargs
:
warnings
.
warn
(
f
'"reduction" is deprecated. Ignoring...'
)
self
.
candidates
=
candidates
self
.
prior
=
prior
or
[
1
/
len
(
candidates
)
for
_
in
range
(
len
(
candidates
))]
assert
abs
(
sum
(
self
.
prior
)
-
1
)
<
1e-5
,
'Sum of prior distribution is not 1.'
self
.
_label
=
generate_new_label
(
label
)
self
.
names
=
[]
if
isinstance
(
candidates
,
OrderedD
ict
):
if
isinstance
(
candidates
,
d
ict
):
for
name
,
module
in
candidates
.
items
():
assert
name
not
in
[
"length"
,
"reduction"
,
"return_mask"
,
"_key"
,
"key"
,
"names"
],
\
"Please don't use a reserved name '{}' for your module."
.
format
(
name
)
...
...
@@ -169,17 +177,23 @@ class InputChoice(nn.Module):
Recommended inputs to choose. If None, mutator is instructed to select any.
reduction : str
``mean``, ``concat``, ``sum`` or ``none``.
prior : list of float
Prior distribution used in random sampling.
label : str
Identifier of the input choice.
"""
def
__new__
(
cls
,
n_candidates
:
int
,
n_chosen
:
int
=
1
,
reduction
:
str
=
'sum'
,
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
def
__new__
(
cls
,
n_candidates
:
int
,
n_chosen
:
Optional
[
int
]
=
1
,
reduction
:
str
=
'sum'
,
*
,
prior
:
Optional
[
List
[
float
]]
=
None
,
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
try
:
return
ChosenInputs
(
get_fixed_value
(
label
),
reduction
=
reduction
)
except
Assertion
Error
:
except
NoContext
Error
:
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
n_candidates
:
int
,
n_chosen
:
int
=
1
,
reduction
:
str
=
'sum'
,
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
def
__init__
(
self
,
n_candidates
:
int
,
n_chosen
:
Optional
[
int
]
=
1
,
reduction
:
str
=
'sum'
,
*
,
prior
:
Optional
[
List
[
float
]]
=
None
,
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
super
(
InputChoice
,
self
).
__init__
()
if
'key'
in
kwargs
:
warnings
.
warn
(
f
'"key" is deprecated. Assuming label.'
)
...
...
@@ -191,6 +205,7 @@ class InputChoice(nn.Module):
self
.
n_candidates
=
n_candidates
self
.
n_chosen
=
n_chosen
self
.
reduction
=
reduction
self
.
prior
=
prior
or
[
1
/
n_candidates
for
_
in
range
(
n_candidates
)]
assert
self
.
reduction
in
[
'mean'
,
'concat'
,
'sum'
,
'none'
]
self
.
_label
=
generate_new_label
(
label
)
...
...
@@ -277,19 +292,25 @@ class ValueChoice(Translatable, nn.Module):
----------
candidates : list
List of values to choose from.
prior : list of float
Prior distribution to sample from.
label : str
Identifier of the value choice.
"""
def
__new__
(
cls
,
candidates
:
List
[
Any
],
label
:
Optional
[
str
]
=
None
):
# FIXME: prior is designed but not supported yet
def
__new__
(
cls
,
candidates
:
List
[
Any
],
*
,
prior
:
Optional
[
List
[
float
]]
=
None
,
label
:
Optional
[
str
]
=
None
):
try
:
return
get_fixed_value
(
label
)
except
Assertion
Error
:
except
NoContext
Error
:
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
candidates
:
List
[
Any
],
label
:
Optional
[
str
]
=
None
):
def
__init__
(
self
,
candidates
:
List
[
Any
],
*
,
prior
:
Optional
[
List
[
float
]]
=
None
,
label
:
Optional
[
str
]
=
None
):
super
().
__init__
()
self
.
candidates
=
candidates
self
.
prior
=
prior
or
[
1
/
len
(
candidates
)
for
_
in
range
(
len
(
candidates
))]
assert
abs
(
sum
(
self
.
prior
)
-
1
)
<
1e-5
,
'Sum of prior distribution is not 1.'
self
.
_label
=
generate_new_label
(
label
)
self
.
_accessor
=
[]
...
...
@@ -323,7 +344,7 @@ class ValueChoice(Translatable, nn.Module):
return
self
def
__deepcopy__
(
self
,
memo
):
new_item
=
ValueChoice
(
self
.
candidates
,
self
.
label
)
new_item
=
ValueChoice
(
self
.
candidates
,
label
=
self
.
label
)
new_item
.
_accessor
=
[
*
self
.
_accessor
]
return
new_item
...
...
nni/retiarii/nn/pytorch/component.py
View file @
9f32a06f
...
...
@@ -7,10 +7,12 @@ import torch.nn as nn
from
.api
import
LayerChoice
,
InputChoice
from
.nn
import
ModuleList
from
.nasbench101
import
NasBench101Cell
,
NasBench101Mutator
from
.utils
import
generate_new_label
,
get_fixed_value
from
...utils
import
NoContextError
__all__
=
[
'Repeat'
,
'Cell'
]
__all__
=
[
'Repeat'
,
'Cell'
,
'NasBench101Cell'
,
'NasBench101Mutator'
]
class
Repeat
(
nn
.
Module
):
...
...
@@ -33,7 +35,7 @@ class Repeat(nn.Module):
try
:
repeat
=
get_fixed_value
(
label
)
return
nn
.
Sequential
(
*
cls
.
_replicate_and_instantiate
(
blocks
,
repeat
))
except
Assertion
Error
:
except
NoContext
Error
:
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
...
...
nni/retiarii/nn/pytorch/mutator.py
View file @
9f32a06f
...
...
@@ -9,7 +9,7 @@ import torch.nn as nn
from
...mutator
import
Mutator
from
...graph
import
Cell
,
Graph
,
Model
,
ModelStatus
,
Node
from
.api
import
LayerChoice
,
InputChoice
,
ValueChoice
,
Placeholder
from
.component
import
Repeat
from
.component
import
Repeat
,
NasBench101Cell
,
NasBench101Mutator
from
...utils
import
uid
...
...
@@ -47,7 +47,12 @@ class InputChoiceMutator(Mutator):
n_candidates
=
self
.
nodes
[
0
].
operation
.
parameters
[
'n_candidates'
]
n_chosen
=
self
.
nodes
[
0
].
operation
.
parameters
[
'n_chosen'
]
candidates
=
list
(
range
(
n_candidates
))
chosen
=
[
self
.
choice
(
candidates
)
for
_
in
range
(
n_chosen
)]
if
n_chosen
is
None
:
chosen
=
[
i
for
i
in
candidates
if
self
.
choice
([
False
,
True
])]
# FIXME This is a hack to make choice align with the previous format
self
.
_cur_samples
=
chosen
else
:
chosen
=
[
self
.
choice
(
candidates
)
for
_
in
range
(
n_chosen
)]
for
node
in
self
.
nodes
:
target
=
model
.
get_node_by_name
(
node
.
name
)
target
.
update_operation
(
'__torch__.nni.retiarii.nn.pytorch.ChosenInputs'
,
...
...
@@ -199,8 +204,15 @@ class ManyChooseManyMutator(Mutator):
def
mutate
(
self
,
model
:
Model
):
# this mutate does not have any effect, but it is recorded in the mutation history
for
node
in
model
.
get_nodes_by_label
(
self
.
label
):
for
_
in
range
(
self
.
number_of_chosen
(
node
)):
self
.
choice
(
self
.
candidates
(
node
))
n_chosen
=
self
.
number_of_chosen
(
node
)
if
n_chosen
is
None
:
candidates
=
[
i
for
i
in
self
.
candidates
(
node
)
if
self
.
choice
([
False
,
True
])]
# FIXME This is a hack to make choice align with the previous format
# For example, it will convert [False, True, True] into [1, 2].
self
.
_cur_samples
=
candidates
else
:
for
_
in
range
(
n_chosen
):
self
.
choice
(
self
.
candidates
(
node
))
break
...
...
@@ -242,6 +254,11 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
'candidates'
:
list
(
range
(
module
.
min_depth
,
module
.
max_depth
+
1
))
})
node
.
label
=
module
.
label
if
isinstance
(
module
,
NasBench101Cell
):
node
=
graph
.
add_node
(
name
,
'NasBench101Cell'
,
{
'max_num_edges'
:
module
.
max_num_edges
})
node
.
label
=
module
.
label
if
isinstance
(
module
,
Placeholder
):
raise
NotImplementedError
(
'Placeholder is not supported in python execution mode.'
)
...
...
@@ -250,13 +267,17 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
return
model
,
None
mutators
=
[]
mutators_final
=
[]
for
nodes
in
_group_by_label_and_type
(
graph
.
hidden_nodes
):
assert
_is_all_equal
(
map
(
lambda
n
:
n
.
operation
.
type
,
nodes
)),
\
f
'Node with label "
{
nodes
[
0
].
label
}
" does not all have the same type.'
assert
_is_all_equal
(
map
(
lambda
n
:
n
.
operation
.
parameters
,
nodes
)),
\
f
'Node with label "
{
nodes
[
0
].
label
}
" does not agree on parameters.'
mutators
.
append
(
ManyChooseManyMutator
(
nodes
[
0
].
label
))
return
model
,
mutators
if
nodes
[
0
].
operation
.
type
==
'NasBench101Cell'
:
mutators_final
.
append
(
NasBench101Mutator
(
nodes
[
0
].
label
))
else
:
mutators
.
append
(
ManyChooseManyMutator
(
nodes
[
0
].
label
))
return
model
,
mutators
+
mutators_final
# utility functions
...
...
nni/retiarii/nn/pytorch/nasbench101.py
0 → 100644
View file @
9f32a06f
import
logging
from
collections
import
OrderedDict
from
typing
import
Callable
,
List
,
Optional
,
Union
,
Dict
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
.api
import
InputChoice
,
ValueChoice
,
LayerChoice
from
.utils
import
generate_new_label
,
get_fixed_dict
from
...mutator
import
InvalidMutation
,
Mutator
from
...graph
import
Model
from
...utils
import
NoContextError
_logger
=
logging
.
getLogger
(
__name__
)
def
compute_vertex_channels
(
input_channels
,
output_channels
,
matrix
):
"""
This is (almost) copied from the original NAS-Bench-101 implementation.
Computes the number of channels at every vertex.
Given the input channels and output channels, this calculates the number of channels at each interior vertex.
Interior vertices have the same number of channels as the max of the channels of the vertices it feeds into.
The output channels are divided amongst the vertices that are directly connected to it.
When the division is not even, some vertices may receive an extra channel to compensate.
Parameters
----------
in_channels : int
input channels count.
output_channels : int
output channel count.
matrix : np.ndarray
adjacency matrix for the module (pruned by model_spec).
Returns
-------
list of int
list of channel counts, in order of the vertices.
"""
num_vertices
=
np
.
shape
(
matrix
)[
0
]
vertex_channels
=
[
0
]
*
num_vertices
vertex_channels
[
0
]
=
input_channels
vertex_channels
[
num_vertices
-
1
]
=
output_channels
if
num_vertices
==
2
:
# Edge case where module only has input and output vertices
return
vertex_channels
# Compute the in-degree ignoring input, axis 0 is the src vertex and axis 1 is
# the dst vertex. Summing over 0 gives the in-degree count of each vertex.
in_degree
=
np
.
sum
(
matrix
[
1
:],
axis
=
0
)
interior_channels
=
output_channels
//
in_degree
[
num_vertices
-
1
]
correction
=
output_channels
%
in_degree
[
num_vertices
-
1
]
# Remainder to add
# Set channels of vertices that flow directly to output
for
v
in
range
(
1
,
num_vertices
-
1
):
if
matrix
[
v
,
num_vertices
-
1
]:
vertex_channels
[
v
]
=
interior_channels
if
correction
:
vertex_channels
[
v
]
+=
1
correction
-=
1
# Set channels for all other vertices to the max of the out edges, going backwards.
# (num_vertices - 2) index skipped because it only connects to output.
for
v
in
range
(
num_vertices
-
3
,
0
,
-
1
):
if
not
matrix
[
v
,
num_vertices
-
1
]:
for
dst
in
range
(
v
+
1
,
num_vertices
-
1
):
if
matrix
[
v
,
dst
]:
vertex_channels
[
v
]
=
max
(
vertex_channels
[
v
],
vertex_channels
[
dst
])
assert
vertex_channels
[
v
]
>
0
_logger
.
debug
(
'vertex_channels: %s'
,
str
(
vertex_channels
))
# Sanity check, verify that channels never increase and final channels add up.
final_fan_in
=
0
for
v
in
range
(
1
,
num_vertices
-
1
):
if
matrix
[
v
,
num_vertices
-
1
]:
final_fan_in
+=
vertex_channels
[
v
]
for
dst
in
range
(
v
+
1
,
num_vertices
-
1
):
if
matrix
[
v
,
dst
]:
assert
vertex_channels
[
v
]
>=
vertex_channels
[
dst
]
assert
final_fan_in
==
output_channels
or
num_vertices
==
2
# num_vertices == 2 means only input/output nodes, so 0 fan-in
return
vertex_channels
def
prune
(
matrix
,
ops
):
"""
Prune the extraneous parts of the graph.
General procedure:
1. Remove parts of graph not connected to input.
2. Remove parts of graph not connected to output.
3. Reorder the vertices so that they are consecutive after steps 1 and 2.
These 3 steps can be combined by deleting the rows and columns of the
vertices that are not reachable from both the input and output (in reverse).
"""
num_vertices
=
np
.
shape
(
matrix
)[
0
]
# calculate the connection matrix within V number of steps.
connections
=
np
.
linalg
.
matrix_power
(
matrix
+
np
.
eye
(
num_vertices
),
num_vertices
)
visited_from_input
=
set
([
i
for
i
in
range
(
num_vertices
)
if
connections
[
0
,
i
]])
visited_from_output
=
set
([
i
for
i
in
range
(
num_vertices
)
if
connections
[
i
,
-
1
]])
# Any vertex that isn't connected to both input and output is extraneous to the computation graph.
extraneous
=
set
(
range
(
num_vertices
)).
difference
(
visited_from_input
.
intersection
(
visited_from_output
))
if
len
(
extraneous
)
>
num_vertices
-
2
:
raise
InvalidMutation
(
'Non-extraneous graph is less than 2 vertices, '
'the input is not connected to the output and the spec is invalid.'
)
matrix
=
np
.
delete
(
matrix
,
list
(
extraneous
),
axis
=
0
)
matrix
=
np
.
delete
(
matrix
,
list
(
extraneous
),
axis
=
1
)
for
index
in
sorted
(
extraneous
,
reverse
=
True
):
del
ops
[
index
]
return
matrix
,
ops
def
truncate
(
inputs
,
channels
):
input_channels
=
inputs
.
size
(
1
)
if
input_channels
<
channels
:
raise
ValueError
(
'input channel < output channels for truncate'
)
elif
input_channels
==
channels
:
return
inputs
# No truncation necessary
else
:
# Truncation should only be necessary when channel division leads to
# vertices with +1 channels. The input vertex should always be projected to
# the minimum channel count.
assert
input_channels
-
channels
==
1
return
inputs
[:,
:
channels
]
class
_NasBench101CellFixed
(
nn
.
Module
):
"""
The fixed version of NAS-Bench-101 Cell, used in python-version execution engine.
"""
def
__init__
(
self
,
operations
:
List
[
Callable
[[
int
],
nn
.
Module
]],
adjacency_list
:
List
[
List
[
int
]],
in_features
:
int
,
out_features
:
int
,
num_nodes
:
int
,
projection
:
Callable
[[
int
,
int
],
nn
.
Module
]):
super
().
__init__
()
assert
num_nodes
==
len
(
operations
)
+
2
==
len
(
adjacency_list
)
+
1
self
.
operations
=
[
'IN'
]
+
operations
+
[
'OUT'
]
# add psuedo nodes
self
.
connection_matrix
=
self
.
build_connection_matrix
(
adjacency_list
,
num_nodes
)
del
num_nodes
# raw number of nodes is no longer used
self
.
connection_matrix
,
self
.
operations
=
prune
(
self
.
connection_matrix
,
self
.
operations
)
self
.
hidden_features
=
compute_vertex_channels
(
in_features
,
out_features
,
self
.
connection_matrix
)
self
.
num_nodes
=
len
(
self
.
connection_matrix
)
self
.
in_features
=
in_features
self
.
out_features
=
out_features
_logger
.
info
(
'Prund number of nodes: %d'
,
self
.
num_nodes
)
_logger
.
info
(
'Pruned connection matrix: %s'
,
str
(
self
.
connection_matrix
))
self
.
projections
=
nn
.
ModuleList
([
nn
.
Identity
()])
self
.
ops
=
nn
.
ModuleList
([
nn
.
Identity
()])
for
i
in
range
(
1
,
self
.
num_nodes
):
self
.
projections
.
append
(
projection
(
in_features
,
self
.
hidden_features
[
i
]))
for
i
in
range
(
1
,
self
.
num_nodes
-
1
):
self
.
ops
.
append
(
operations
[
i
-
1
](
self
.
hidden_features
[
i
]))
@
staticmethod
def
build_connection_matrix
(
adjacency_list
,
num_nodes
):
adjacency_list
=
[[]]
+
adjacency_list
# add adjacency for first node
connections
=
np
.
zeros
((
num_nodes
,
num_nodes
),
dtype
=
'int'
)
for
i
,
lst
in
enumerate
(
adjacency_list
):
assert
all
([
0
<=
k
<
i
for
k
in
lst
])
for
k
in
lst
:
connections
[
k
,
i
]
=
1
return
connections
def
forward
(
self
,
inputs
):
tensors
=
[
inputs
]
for
t
in
range
(
1
,
self
.
num_nodes
-
1
):
# Create interior connections, truncating if necessary
add_in
=
[
truncate
(
tensors
[
src
],
self
.
hidden_features
[
t
])
for
src
in
range
(
1
,
t
)
if
self
.
connection_matrix
[
src
,
t
]]
# Create add connection from projected input
if
self
.
connection_matrix
[
0
,
t
]:
add_in
.
append
(
self
.
projections
[
t
](
tensors
[
0
]))
if
len
(
add_in
)
==
1
:
vertex_input
=
add_in
[
0
]
else
:
vertex_input
=
sum
(
add_in
)
# Perform op at vertex t
vertex_out
=
self
.
ops
[
t
](
vertex_input
)
tensors
.
append
(
vertex_out
)
# Construct final output tensor by concating all fan-in and adding input.
if
np
.
sum
(
self
.
connection_matrix
[:,
-
1
])
==
1
:
src
=
np
.
where
(
self
.
connection_matrix
[:,
-
1
]
==
1
)[
0
][
0
]
return
self
.
projections
[
-
1
](
tensors
[
0
])
if
src
==
0
else
tensors
[
src
]
outputs
=
torch
.
cat
([
tensors
[
src
]
for
src
in
range
(
1
,
self
.
num_nodes
-
1
)
if
self
.
connection_matrix
[
src
,
-
1
]],
1
)
if
self
.
connection_matrix
[
0
,
-
1
]:
outputs
+=
self
.
projections
[
-
1
](
tensors
[
0
])
assert
outputs
.
size
(
1
)
==
self
.
out_features
return
outputs
class
NasBench101Cell
(
nn
.
Module
):
"""
Cell structure that is proposed in NAS-Bench-101 [nasbench101]_ .
This cell is usually used in evaluation of NAS algorithms because there is a ``comprehensive analysis'' of this search space
available, which includes a full architecture-dataset that ``maps 423k unique architectures to metrics
including run time and accuracy''. You can also use the space in your own space design, in which scenario it should be possible
to leverage results in the benchmark to narrow the huge space down to a few efficient architectures.
The space of this cell architecture consists of all possible directed acyclic graphs on no more than ``max_num_nodes`` nodes,
where each possible node (other than IN and OUT) has one of ``op_candidates``, representing the corresponding operation.
Edges connecting the nodes can be no more than ``max_num_edges``.
To align with the paper settings, two vertices specially labeled as operation IN and OUT, are also counted into
``max_num_nodes`` in our implementaion, the default value of ``max_num_nodes`` is 7 and ``max_num_edges`` is 9.
Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be `[N, C_{out}, *]`. The shape
of each hidden nodes will be first automatically computed, depending on the cell structure. Each of the ``op_candidates``
should be a callable that accepts computed ``num_features`` and returns a ``Module``. For example,
.. code-block:: python
def conv_bn_relu(num_features):
return nn.Sequential(
nn.Conv2d(num_features, num_features, 1),
nn.BatchNorm2d(num_features),
nn.ReLU()
)
The output of each node is the sum of its input node feed into its operation, except for the last node (output node),
which is the concatenation of its input *hidden* nodes, adding the *IN* node (if IN and OUT are connected).
When input tensor is added with any other tensor, there could be shape mismatch. Therefore, a projection transformation
is needed to transform the input tensor. In paper, this is simply a Conv1x1 followed by BN and ReLU. The ``projection``
parameters accepts ``in_features`` and ``out_features``, returns a ``Module``. This parameter has no default value,
as we hold no assumption that users are dealing with images. An example for this parameter is,
.. code-block:: python
def projection_fn(in_features, out_features):
return nn.Conv2d(in_features, out_features, 1)
Parameters
----------
op_candidates : list of callable
Operation candidates. Each should be a function accepts number of feature, returning nn.Module.
in_features : int
Input dimension of cell.
out_features : int
Output dimension of cell.
projection : callable
Projection module that is used to preprocess the input tensor of the whole cell.
A callable that accept input feature and output feature, returning nn.Module.
max_num_nodes : int
Maximum number of nodes in the cell, input and output included. At least 2. Default: 7.
max_num_edges : int
Maximum number of edges in the cell. Default: 9.
label : str
Identifier of the cell. Cell sharing the same label will semantically share the same choice.
References
----------
.. [nasbench101] Ying, Chris, et al. "Nas-bench-101: Towards reproducible neural architecture search."
International Conference on Machine Learning. PMLR, 2019.
"""
@
staticmethod
def
_make_dict
(
x
):
if
isinstance
(
x
,
list
):
return
OrderedDict
([(
str
(
i
),
t
)
for
i
,
t
in
enumerate
(
x
)])
return
OrderedDict
(
x
)
def
__new__
(
cls
,
op_candidates
:
Union
[
Dict
[
str
,
Callable
[[
int
],
nn
.
Module
]],
List
[
Callable
[[
int
],
nn
.
Module
]]],
in_features
:
int
,
out_features
:
int
,
projection
:
Callable
[[
int
,
int
],
nn
.
Module
],
max_num_nodes
:
int
=
7
,
max_num_edges
:
int
=
9
,
label
:
Optional
[
str
]
=
None
):
def
make_list
(
x
):
return
x
if
isinstance
(
x
,
list
)
else
[
x
]
try
:
label
,
selected
=
get_fixed_dict
(
label
)
op_candidates
=
cls
.
_make_dict
(
op_candidates
)
num_nodes
=
selected
[
f
'
{
label
}
/num_nodes'
]
adjacency_list
=
[
make_list
(
selected
[
f
'
{
label
}
/input_
{
i
}
'
])
for
i
in
range
(
1
,
num_nodes
)]
if
sum
([
len
(
e
)
for
e
in
adjacency_list
])
>
max_num_edges
:
raise
InvalidMutation
(
f
'Expected
{
max_num_edges
}
edges, found:
{
adjacency_list
}
'
)
return
_NasBench101CellFixed
(
[
op_candidates
[
selected
[
f
'
{
label
}
/op_
{
i
}
'
]]
for
i
in
range
(
1
,
num_nodes
-
1
)],
adjacency_list
,
in_features
,
out_features
,
num_nodes
,
projection
)
except
NoContextError
:
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
op_candidates
:
Union
[
Dict
[
str
,
Callable
[[
int
],
nn
.
Module
]],
List
[
Callable
[[
int
],
nn
.
Module
]]],
in_features
:
int
,
out_features
:
int
,
projection
:
Callable
[[
int
,
int
],
nn
.
Module
],
max_num_nodes
:
int
=
7
,
max_num_edges
:
int
=
9
,
label
:
Optional
[
str
]
=
None
):
super
().
__init__
()
self
.
_label
=
generate_new_label
(
label
)
num_vertices_prior
=
[
2
**
i
for
i
in
range
(
2
,
max_num_nodes
+
1
)]
num_vertices_prior
=
(
np
.
array
(
num_vertices_prior
)
/
sum
(
num_vertices_prior
)).
tolist
()
self
.
num_nodes
=
ValueChoice
(
list
(
range
(
2
,
max_num_nodes
+
1
)),
prior
=
num_vertices_prior
,
label
=
f
'
{
self
.
_label
}
/num_nodes'
)
self
.
max_num_nodes
=
max_num_nodes
self
.
max_num_edges
=
max_num_edges
op_candidates
=
self
.
_make_dict
(
op_candidates
)
# this is only for input validation and instantiating enough layer choice and input choice
self
.
hidden_features
=
out_features
self
.
projections
=
nn
.
ModuleList
([
nn
.
Identity
()])
self
.
ops
=
nn
.
ModuleList
([
nn
.
Identity
()])
self
.
inputs
=
nn
.
ModuleList
([
nn
.
Identity
()])
for
_
in
range
(
1
,
max_num_nodes
):
self
.
projections
.
append
(
projection
(
in_features
,
self
.
hidden_features
))
for
i
in
range
(
1
,
max_num_nodes
):
if
i
<
max_num_nodes
-
1
:
self
.
ops
.
append
(
LayerChoice
(
OrderedDict
([(
k
,
op
(
self
.
hidden_features
))
for
k
,
op
in
op_candidates
.
items
()]),
label
=
f
'
{
self
.
_label
}
/op_
{
i
}
'
))
self
.
inputs
.
append
(
InputChoice
(
i
,
None
,
label
=
f
'
{
self
.
_label
}
/input_
{
i
}
'
))
@
property
def
label
(
self
):
return
self
.
_label
def
forward
(
self
,
x
):
# This is a dummy forward and actually not used
tensors
=
[
x
]
for
i
in
range
(
1
,
self
.
max_num_nodes
):
node_input
=
self
.
inputs
[
i
]([
self
.
projections
[
i
](
tensors
[
0
])]
+
[
t
for
t
in
tensors
[
1
:]])
if
i
<
self
.
max_num_nodes
-
1
:
node_output
=
self
.
ops
[
i
](
node_input
)
else
:
node_output
=
node_input
tensors
.
append
(
node_output
)
return
tensors
[
-
1
]
class
NasBench101Mutator
(
Mutator
):
# for validation purposes
# for python execution engine
def
__init__
(
self
,
label
:
Optional
[
str
]):
super
().
__init__
(
label
=
label
)
@
staticmethod
def
candidates
(
node
):
if
'n_candidates'
in
node
.
operation
.
parameters
:
return
list
(
range
(
node
.
operation
.
parameters
[
'n_candidates'
]))
else
:
return
node
.
operation
.
parameters
[
'candidates'
]
@
staticmethod
def
number_of_chosen
(
node
):
if
'n_chosen'
in
node
.
operation
.
parameters
:
return
node
.
operation
.
parameters
[
'n_chosen'
]
return
1
def
mutate
(
self
,
model
:
Model
):
for
node
in
model
.
get_nodes_by_label
(
self
.
label
):
max_num_edges
=
node
.
operation
.
parameters
[
'max_num_edges'
]
break
mutation_dict
=
{
mut
.
mutator
.
label
:
mut
.
samples
for
mut
in
model
.
history
}
num_nodes
=
mutation_dict
[
f
'
{
self
.
label
}
/num_nodes'
][
0
]
adjacency_list
=
[
mutation_dict
[
f
'
{
self
.
label
}
/input_
{
i
}
'
]
for
i
in
range
(
1
,
num_nodes
)]
if
sum
([
len
(
e
)
for
e
in
adjacency_list
])
>
max_num_edges
:
raise
InvalidMutation
(
f
'Expected
{
max_num_edges
}
edges, found:
{
adjacency_list
}
'
)
matrix
=
_NasBench101CellFixed
.
build_connection_matrix
(
adjacency_list
,
num_nodes
)
prune
(
matrix
,
[
None
]
*
len
(
matrix
))
# dummy ops, possible to raise InvalidMutation inside
def
dry_run
(
self
,
model
):
return
[],
model
nni/retiarii/nn/pytorch/utils.py
View file @
9f32a06f
from
typing
import
Optional
from
typing
import
Any
,
Optional
,
Tuple
from
...utils
import
uid
,
get_current_context
...
...
@@ -9,9 +9,21 @@ def generate_new_label(label: Optional[str]):
return
label
def
get_fixed_value
(
label
:
str
):
def
get_fixed_value
(
label
:
str
)
->
Any
:
ret
=
get_current_context
(
'fixed'
)
try
:
return
ret
[
generate_new_label
(
label
)]
except
KeyError
:
raise
KeyError
(
f
'Fixed context with
{
label
}
not found. Existing values are:
{
ret
}
'
)
def
get_fixed_dict
(
label_prefix
:
str
)
->
Tuple
[
str
,
Any
]:
ret
=
get_current_context
(
'fixed'
)
try
:
label_prefix
=
generate_new_label
(
label_prefix
)
ret
=
{
k
:
v
for
k
,
v
in
ret
.
items
()
if
k
.
startswith
(
label_prefix
+
'/'
)}
if
not
ret
:
raise
KeyError
return
label_prefix
,
ret
except
KeyError
:
raise
KeyError
(
f
'Fixed context with prefix
{
label_prefix
}
not found. Existing values are:
{
ret
}
'
)
nni/retiarii/strategy/bruteforce.py
View file @
9f32a06f
...
...
@@ -8,7 +8,7 @@ import random
import
time
from
typing
import
Any
,
Dict
,
List
from
..
import
Sampler
,
submit_models
,
query_available_resources
,
budget_exhausted
from
..
import
InvalidMutation
,
Sampler
,
submit_models
,
query_available_resources
,
budget_exhausted
from
.base
import
BaseStrategy
from
.utils
import
dry_run_for_search_space
,
get_targeted_model
...
...
@@ -121,4 +121,7 @@ class Random(BaseStrategy):
if
budget_exhausted
():
return
time
.
sleep
(
self
.
_polling_interval
)
submit_models
(
get_targeted_model
(
base_model
,
applied_mutators
,
sample
))
try
:
submit_models
(
get_targeted_model
(
base_model
,
applied_mutators
,
sample
))
except
InvalidMutation
as
e
:
_logger
.
warning
(
f
'Invalid mutation:
{
e
}
. Skip.'
)
nni/retiarii/utils.py
View file @
9f32a06f
...
...
@@ -67,6 +67,10 @@ def get_importable_name(cls, relocate_module=False):
return
module_name
+
'.'
+
cls
.
__name__
class
NoContextError
(
Exception
):
pass
class
ContextStack
:
"""
This is to maintain a globally-accessible context envinronment that is visible to everywhere.
...
...
@@ -98,7 +102,8 @@ class ContextStack:
@
classmethod
def
top
(
cls
,
key
:
str
)
->
Any
:
assert
cls
.
_stack
[
key
],
'Context is empty.'
if
not
cls
.
_stack
[
key
]:
raise
NoContextError
(
'Context is empty.'
)
return
cls
.
_stack
[
key
][
-
1
]
...
...
test/.gitignore
View file @
9f32a06f
...
...
@@ -10,3 +10,4 @@ _generated_model
data
generated
lightning_logs
model.onnx
test/ut/retiarii/test_highlevel_apis.py
View file @
9f32a06f
...
...
@@ -5,7 +5,7 @@ from collections import Counter
import
nni.retiarii.nn.pytorch
as
nn
import
torch
import
torch.nn.functional
as
F
from
nni.retiarii
import
Sampler
,
basic_unit
from
nni.retiarii
import
InvalidMutation
,
Sampler
,
basic_unit
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.execution.python
import
_unpack_if_only_one
...
...
@@ -518,3 +518,29 @@ class Python(GraphIR):
@
unittest
.
skip
def
test_valuechoice_access_functional_expression
(
self
):
...
def
test_nasbench101_cell
(
self
):
# this is only supported in python engine for now.
@
self
.
get_serializer
()
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
cell
=
nn
.
NasBench101Cell
([
lambda
x
:
nn
.
Linear
(
x
,
x
),
lambda
x
:
nn
.
Linear
(
x
,
x
,
bias
=
False
)],
10
,
16
,
lambda
x
,
y
:
nn
.
Linear
(
x
,
y
),
max_num_nodes
=
5
,
max_num_edges
=
7
)
def
forward
(
self
,
x
):
return
self
.
cell
(
x
)
raw_model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
succeeded
=
0
sampler
=
RandomSampler
()
while
succeeded
<=
10
:
try
:
model
=
raw_model
for
mutator
in
mutators
:
model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model
)
succeeded
+=
1
except
InvalidMutation
:
continue
self
.
assertTrue
(
self
.
_get_converted_pytorch_model
(
model
)(
torch
.
randn
(
2
,
10
)).
size
()
==
torch
.
Size
([
2
,
16
]))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment