OpenDAS / nni / Commits / 8b2eb425

Commit 8b2eb425 (unverified), authored Feb 18, 2022 by v-fangdong, committed by GitHub on Feb 18, 2022

Lightning implementation for retiarii oneshot nas (#4479)

parent 99818fba
Showing 7 changed files with 1147 additions and 3 deletions (+1147, -3)

nni/retiarii/oneshot/pytorch/__init__.py          +3    -1
nni/retiarii/oneshot/pytorch/base_lightning.py    +318  -0
nni/retiarii/oneshot/pytorch/differentiable.py    +379  -0
nni/retiarii/oneshot/pytorch/enas.py              +1    -1
nni/retiarii/oneshot/pytorch/sampling.py          +178  -0
nni/retiarii/oneshot/pytorch/utils.py             +117  -1
test/ut/retiarii/test_oneshot.py                  +151  -0
nni/retiarii/oneshot/pytorch/__init__.py

@@ -5,4 +5,6 @@ from .darts import DartsTrainer
 from .enas import EnasTrainer
 from .proxyless import ProxylessTrainer
 from .random import SinglePathTrainer, RandomTrainer
-from .utils import replace_input_choice, replace_layer_choice
+from .differentiable import DartsModule, ProxylessModule, SNASModule
+from .sampling import EnasModule, RandomSampleModule
+from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader
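
The lines added here make the new Lightning-based classes importable from the package root. As a quick orientation (an illustrative sketch only, not part of the diff; it assumes the package layout introduced by this commit), the new public names group as follows:

# Illustrative import of the names exported by this commit.
from nni.retiarii.oneshot.pytorch import (
    DartsModule, ProxylessModule, SNASModule,                      # differentiable one-shot methods
    EnasModule, RandomSampleModule,                                # sampling-based one-shot methods
    InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader,  # the dataloader wrappers they expect
)
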
nni/retiarii/oneshot/pytorch/base_lightning.py  (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import pytorch_lightning as pl
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler


def _replace_module_with_type(root_module, replace_dict, modules):
    """
    Replace xxxChoice in user's model with NAS modules.

    Parameters
    ----------
    root_module : nn.Module
        User-defined module with xxxChoice in it. In fact, since this method is called in the ``__init__`` of
        ``BaseOneShotLightningModule``, this will be a pl.LightningModule.
    replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]]
        Functions to replace xxxChoice modules. Keys should be xxxChoice types and values should be
        functions that return an nn.Module.
    modules : List[nn.Module]
        The replace result. This is also the return value of this function.

    Returns
    -------
    modules : List[nn.Module]
        The replace result.
    """
    if modules is None:
        modules = []

    def apply(m):
        for name, child in m.named_children():
            child_type = type(child)
            if child_type in replace_dict.keys():
                setattr(m, name, replace_dict[child_type](child))
                modules.append((child.key, getattr(m, name)))
            else:
                apply(child)

    apply(root_module)
    return modules


class BaseOneShotLightningModule(pl.LightningModule):
    """
    The base class for all one-shot NAS modules. Essential functions such as preprocessing the user's model, redirecting lightning
    hooks to the user's model, configuring optimizers and exporting NAS results are implemented in this class.

    Attributes
    ----------
    nas_modules : List[nn.Module]
        The replace result of a specific NAS method. xxxChoice will be replaced with other modules with respect to the
        NAS method.

    Parameters
    ----------
    base_model : pl.LightningModule
        The evaluator in ``nni.retiarii.evaluator.lightning``. The user-defined model is wrapped by base_model, and base_model will
        be wrapped by this model.
    custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
        The custom xxxChoice replace method. Keys should be xxxChoice types and values should return an ``nn.Module``. This custom
        replace dict will override the default replace dict of each NAS method.
    """
    automatic_optimization = False

    def __init__(self, base_model, custom_replace_dict=None):
        super().__init__()
        assert isinstance(base_model, pl.LightningModule)
        self.model = base_model

        # replace xxxChoice with respect to the NAS algorithm
        # replaced modules are stored in self.nas_modules
        self.nas_modules = []
        choice_replace_dict = self.default_replace_dict
        if custom_replace_dict is not None:
            for k, v in custom_replace_dict.items():
                assert isinstance(v, nn.Module)
                choice_replace_dict[k] = v
        _replace_module_with_type(self.model, choice_replace_dict, self.nas_modules)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        # You can use self.architecture_optimizers or self.user_optimizers to get optimizers in
        # your own training step.
        return self.model.training_step(batch, batch_idx)

    def configure_optimizers(self):
        """
        Combine architecture optimizers and the user's model optimizers.
        You can overwrite configure_architecture_optimizers if architecture optimizers are needed in your NAS algorithm.
        By now ``self.model`` is a :class:`nni.retiarii.evaluator.pytorch.lightning._SupervisedLearningModule`
        and it only returns 1 optimizer. But for extensibility, code for other return value types is also implemented.
        """
        # pylint: disable=assignment-from-none
        arc_optimizers = self.configure_architecture_optimizers()
        if arc_optimizers is None:
            return self.model.configure_optimizers()

        if isinstance(arc_optimizers, optim.Optimizer):
            arc_optimizers = [arc_optimizers]
        self.arc_optim_count = len(arc_optimizers)

        # The return values ``frequency`` and ``monitor`` are ignored because lightning requires
        # ``len(optimizers) == len(frequency)``, and gradient backward is handled manually.
        # For the data structure of the variables below, please see the pytorch lightning docs of ``configure_optimizers``.
        w_optimizers, lr_schedulers, self.frequencies, monitor = \
            self.trainer._configure_optimizers(self.model.configure_optimizers())
        lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)
        if any(sch["scheduler"].optimizer not in w_optimizers for sch in lr_schedulers):
            raise Exception(
                "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
            )

        # variables used to handle optimizer frequency
        self.cur_optimizer_step = 0
        self.cur_optimizer_index = 0

        return arc_optimizers + w_optimizers, lr_schedulers

    def on_train_start(self):
        self.model.trainer = self.trainer
        self.model.log = self.log
        return self.model.on_train_start()

    def on_train_end(self):
        return self.model.on_train_end()

    def on_fit_start(self):
        return self.model.on_train_start()

    def on_fit_end(self):
        return self.model.on_train_end()

    def on_train_batch_start(self, batch, batch_idx, unused=0):
        return self.model.on_train_batch_start(batch, batch_idx, unused)

    def on_train_batch_end(self, outputs, batch, batch_idx, unused=0):
        return self.model.on_train_batch_end(outputs, batch, batch_idx, unused)

    def on_epoch_start(self):
        return self.model.on_epoch_start()

    def on_epoch_end(self):
        return self.model.on_epoch_end()

    def on_train_epoch_start(self):
        return self.model.on_train_epoch_start()

    def on_train_epoch_end(self):
        return self.model.on_train_epoch_end()

    def on_before_backward(self, loss):
        return self.model.on_before_backward(loss)

    def on_after_backward(self):
        return self.model.on_after_backward()

    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None):
        return self.model.configure_gradient_clipping(optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm)

    def configure_architecture_optimizers(self):
        """
        Hook kept for subclasses. A specific NAS method inheriting this base class should return its architecture optimizers here
        if architecture parameters are needed. Note that lr schedulers are not supported now for architecture optimizers.

        Returns
        -------
        arc_optimizers : List[Optimizer], Optimizer
            Optimizers used by a specific NAS algorithm. Return None if no architecture optimizers are needed.
        """
        return None

    @property
    def default_replace_dict(self):
        """
        Default xxxChoice replace dict. This is called in ``__init__`` to get the default replace functions for your NAS algorithm.
        Note that your default replace functions may be overridden by a user-defined custom_replace_dict.

        Returns
        -------
        replace_dict : Dict[Type, Callable[[nn.Module], nn.Module]]
            Same as ``custom_replace_dict`` in ``__init__``, but this will be overridden if users define their own replace functions.
        """
        replace_dict = {}
        return replace_dict

    def call_lr_schedulers(self, batch_index):
        """
        Function that imitates the lightning trainer's behaviour of calling the user's lr schedulers. Since automatic optimization
        is turned off by this class, you can use this function to make schedulers behave as if they were automatically handled by
        the lightning trainer.

        Parameters
        ----------
        batch_index : int
            batch index
        """
        def apply(lr_scheduler):
            # a single scheduler is called every epoch
            if isinstance(lr_scheduler, _LRScheduler) and \
                    self.trainer.is_last_batch:
                lr_scheduler.step()
            # an lr_scheduler_config is called as configured
            elif isinstance(lr_scheduler, dict):
                interval = lr_scheduler['interval']
                frequency = lr_scheduler['frequency']
                if (interval == 'step' and batch_index % frequency == 0) or \
                        (interval == 'epoch' and self.trainer.is_last_batch and
                         (self.trainer.current_epoch + 1) % frequency == 0):
                    lr_scheduler.step()

        lr_schedulers = self.lr_schedulers()
        if isinstance(lr_schedulers, list):
            for lr_scheduler in lr_schedulers:
                apply(lr_scheduler)
        else:
            apply(lr_schedulers)

    def call_user_optimizers(self, method):
        """
        Function that imitates the lightning trainer's behaviour of calling the user's optimizers. Since automatic optimization is
        turned off by this class, you can use this function to make user optimizers behave as if they were automatically handled by
        the lightning trainer.

        Parameters
        ----------
        method : str
            Method to call. Only 'step' and 'zero_grad' are supported now.
        """
        def apply_method(optimizer, method):
            if method == 'step':
                optimizer.step()
            elif method == 'zero_grad':
                optimizer.zero_grad()

        optimizers = self.user_optimizers
        if optimizers is None:
            return
        if len(self.frequencies) > 0:
            self.cur_optimizer_step += 1
            if self.frequencies[self.cur_optimizer_index] == self.cur_optimizer_step:
                self.cur_optimizer_step = 0
                self.cur_optimizer_index = self.cur_optimizer_index + 1 \
                    if self.cur_optimizer_index + 1 < len(optimizers) \
                    else 0
            apply_method(optimizers[self.cur_optimizer_index], method)
        else:
            for optimizer in optimizers:
                apply_method(optimizer, method)

    @property
    def architecture_optimizers(self):
        """
        Get architecture optimizers from all optimizers. Use this to get your architecture optimizers in ``training_step``.

        Returns
        -------
        opts : List[Optimizer], Optimizer, None
            Architecture optimizers defined in ``configure_architecture_optimizers``. This will be None if there are no
            architecture optimizers.
        """
        opts = self.optimizers()
        if isinstance(opts, list):
            # pylint: disable=unsubscriptable-object
            arc_opts = opts[:self.arc_optim_count]
            if len(arc_opts) == 1:
                arc_opts = arc_opts[0]
            return arc_opts
        # If there is only 1 optimizer and it is the architecture optimizer
        if self.arc_optim_count == 1:
            return opts
        return None

    @property
    def user_optimizers(self):
        """
        Get user optimizers from all optimizers. Use this to get user optimizers in ``training_step``.

        Returns
        -------
        opts : List[Optimizer], Optimizer, None
            Optimizers defined by the user's model. This will be None if there are no user optimizers.
        """
        opts = self.optimizers()
        if isinstance(opts, list):
            # pylint: disable=unsubscriptable-object
            return opts[self.arc_optim_count:]
        # If there is only 1 optimizer and no architecture optimizer
        if self.arc_optim_count == 0:
            return opts
        return None

    def export(self):
        """
        Export the NAS result, ideally the best choice of each of the nas_modules.
        You may implement an ``export`` method for your customized nas_module.

        Returns
        -------
        result : Dict[str, int]
            Keys are names of nas_modules, and values are their choice indices.
        """
        result = {}
        for name, module in self.nas_modules:
            if name not in result:
                result[name] = module.export()
        return result
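
The docstrings above describe the intended extension points (``default_replace_dict``, ``configure_architecture_optimizers``, and the manual-optimization helpers). The sketch below shows how a subclass is meant to wire into them; it is a hypothetical example, not part of this commit: ``MyNASModule`` and ``MyLayerChoiceReplacement`` are made-up names, and the ``alpha`` attribute on the replacement module is an assumption that mirrors the DARTS-style modules later in this diff.

# Hypothetical sketch of a subclass that uses the hooks above; not part of this commit.
import torch.optim as optim
from nni.retiarii.nn.pytorch import LayerChoice

class MyNASModule(BaseOneShotLightningModule):
    @property
    def default_replace_dict(self):
        # MyLayerChoiceReplacement is a placeholder module that wraps a LayerChoice
        return {LayerChoice: MyLayerChoiceReplacement}

    def configure_architecture_optimizers(self):
        # assumes each replacement module exposes an ``alpha`` architecture parameter
        return optim.Adam([m.alpha for _, m in self.nas_modules], lr=3e-4)

    def training_step(self, batch, batch_idx):
        # architecture and weight optimizers are split out by the base class
        arc_optim = self.architecture_optimizers
        arc_optim.zero_grad()
        self.call_user_optimizers('zero_grad')
        loss = self.model.training_step(batch, batch_idx)
        self.manual_backward(loss['loss'] if isinstance(loss, dict) else loss)
        arc_optim.step()
        self.call_user_optimizers('step')
        self.call_lr_schedulers(batch_idx)
        return loss
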
nni/retiarii/oneshot/pytorch/differentiable.py  (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
from .base_lightning import BaseOneShotLightningModule


class DartsLayerChoice(nn.Module):
    def __init__(self, layer_choice):
        super(DartsLayerChoice, self).__init__()
        self.name = layer_choice.label
        self.op_choices = nn.ModuleDict(OrderedDict([(name, layer_choice[name]) for name in layer_choice.names]))
        self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)

    def forward(self, *args, **kwargs):
        op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
        alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
        return torch.sum(op_results * F.softmax(self.alpha, -1).view(*alpha_shape), 0)

    def parameters(self):
        for _, p in self.named_parameters():
            yield p

    def named_parameters(self, recurse=False):
        for name, p in super(DartsLayerChoice, self).named_parameters():
            if name == 'alpha':
                continue
            yield name, p

    def export(self):
        return list(self.op_choices.keys())[torch.argmax(self.alpha).item()]


class DartsInputChoice(nn.Module):
    def __init__(self, input_choice):
        super(DartsInputChoice, self).__init__()
        self.name = input_choice.label
        self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
        self.n_chosen = input_choice.n_chosen or 1

    def forward(self, inputs):
        inputs = torch.stack(inputs)
        alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
        return torch.sum(inputs * F.softmax(self.alpha, -1).view(*alpha_shape), 0)

    def parameters(self):
        for _, p in self.named_parameters():
            yield p

    def named_parameters(self, recurse=False):
        for name, p in super(DartsInputChoice, self).named_parameters():
            if name == 'alpha':
                continue
            yield name, p

    def export(self):
        return torch.argsort(-self.alpha).cpu().numpy().tolist()[:self.n_chosen]


class DartsModule(BaseOneShotLightningModule):
    """
    The DARTS module. Each iteration consists of two training phases. Phase 1 is the architecture step, in which model parameters
    are frozen and the architecture parameters are trained. Phase 2 is the model step, in which architecture parameters are frozen
    and model parameters are trained. See [darts] for details.

    The DARTS module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.

    Reference
    ---------
    .. [darts] H. Liu, K. Simonyan, and Y. Yang, "DARTS: Differentiable Architecture Search," presented at the
        International Conference on Learning Representations, Sep. 2018. Available: https://openreview.net/forum?id=S1eYHoC5FX
    """

    def training_step(self, batch, batch_idx):
        # gradients are handled manually
        arc_optim = self.architecture_optimizers
        # The InterleavedTrainValDataLoader yields both train and val data in a batch
        trn_batch, val_batch = batch

        # phase 1: architecture step
        # The _resample hook is kept for some DARTS-based NAS methods like ProxylessNAS.
        # See the code of those methods for details.
        self._resample()
        arc_optim.zero_grad()
        arc_step_loss = self.model.training_step(val_batch, 2 * batch_idx)
        if isinstance(arc_step_loss, dict):
            arc_step_loss = arc_step_loss['loss']
        self.manual_backward(arc_step_loss)
        self.finalize_grad()
        arc_optim.step()

        # phase 2: model step
        self._resample()
        self.call_user_optimizers('zero_grad')
        loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
        w_step_loss = loss_and_metrics['loss'] \
            if isinstance(loss_and_metrics, dict) else loss_and_metrics
        self.manual_backward(w_step_loss)
        self.call_user_optimizers('step')

        self.call_lr_schedulers(batch_idx)

        return loss_and_metrics

    def _resample(self):
        # Note: This hook is kept for the DARTS-based NAS algorithms that follow.
        pass

    def finalize_grad(self):
        # Note: This hook is currently kept for ProxylessNAS.
        pass

    @property
    def default_replace_dict(self):
        return {
            LayerChoice: DartsLayerChoice,
            InputChoice: DartsInputChoice
        }

    def configure_architecture_optimizers(self):
        # The alpha in DartsXXXChoices is the architecture parameter of DARTS. All alphas share one optimizer.
        ctrl_params = {}
        for _, m in self.nas_modules:
            if m.name in ctrl_params:
                assert m.alpha.size() == ctrl_params[m.name].size(), 'Size of parameters with the same label should be same.'
                m.alpha = ctrl_params[m.name]
            else:
                ctrl_params[m.name] = m.alpha
        ctrl_optim = torch.optim.Adam(list(ctrl_params.values()), 3.e-4, betas=(0.5, 0.999), weight_decay=1.0E-3)
        return ctrl_optim


class _ArchGradientFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, binary_gates, run_func, backward_func):
        ctx.run_func = run_func
        ctx.backward_func = backward_func

        detached_x = x.detach()
        detached_x.requires_grad = x.requires_grad
        with torch.enable_grad():
            output = run_func(detached_x)
        ctx.save_for_backward(detached_x, output)
        return output.data

    @staticmethod
    def backward(ctx, grad_output):
        detached_x, output = ctx.saved_tensors

        grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
        # compute gradients w.r.t. binary_gates
        binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)

        return grad_x[0], binary_grads, None, None


class ProxylessLayerChoice(nn.Module):
    def __init__(self, ops):
        super(ProxylessLayerChoice, self).__init__()
        self.ops = nn.ModuleList(ops)
        self.alpha = nn.Parameter(torch.randn(len(self.ops)) * 1E-3)
        self._binary_gates = nn.Parameter(torch.randn(len(self.ops)) * 1E-3)
        self.sampled = None

    def forward(self, *args, **kwargs):
        if self.training:
            def run_function(ops, active_id, **kwargs):
                def forward(_x):
                    return ops[active_id](_x, **kwargs)
                return forward

            def backward_function(ops, active_id, binary_gates, **kwargs):
                def backward(_x, _output, grad_output):
                    binary_grads = torch.zeros_like(binary_gates.data)
                    with torch.no_grad():
                        for k in range(len(ops)):
                            if k != active_id:
                                out_k = ops[k](_x.data, **kwargs)
                            else:
                                out_k = _output.data
                            grad_k = torch.sum(out_k * grad_output)
                            binary_grads[k] = grad_k
                    return binary_grads
                return backward

            assert len(args) == 1
            x = args[0]
            return _ArchGradientFunction.apply(
                x, self._binary_gates, run_function(self.ops, self.sampled, **kwargs),
                backward_function(self.ops, self.sampled, self._binary_gates, **kwargs)
            )
        return super().forward(*args, **kwargs)

    def resample(self):
        probs = F.softmax(self.alpha, dim=-1)
        sample = torch.multinomial(probs, 1)[0].item()
        self.sampled = sample
        with torch.no_grad():
            self._binary_gates.zero_()
            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
            self._binary_gates.data[sample] = 1.0

    def finalize_grad(self):
        binary_grads = self._binary_gates.grad
        with torch.no_grad():
            if self.alpha.grad is None:
                self.alpha.grad = torch.zeros_like(self.alpha.data)
            probs = F.softmax(self.alpha, dim=-1)
            for i in range(len(self.ops)):
                for j in range(len(self.ops)):
                    self.alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])

    def export(self):
        return torch.argmax(self.alpha).item()

    def export_prob(self):
        return F.softmax(self.alpha, dim=-1)


class ProxylessInputChoice(nn.Module):
    def __init__(self, input_choice):
        super().__init__()
        self.num_input_candidates = input_choice.n_candidates
        self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1E-3)
        self._binary_gates = nn.Parameter(torch.randn(input_choice.n_candidates) * 1E-3)
        self.sampled = None

    def forward(self, inputs):
        if self.training:
            def run_function(active_sample):
                return lambda x: x[active_sample]

            def backward_function(binary_gates):
                def backward(_x, _output, grad_output):
                    binary_grads = torch.zeros_like(binary_gates.data)
                    with torch.no_grad():
                        for k in range(self.num_input_candidates):
                            out_k = _x[k].data
                            grad_k = torch.sum(out_k * grad_output)
                            binary_grads[k] = grad_k
                    return binary_grads
                return backward

            inputs = torch.stack(inputs, 0)
            return _ArchGradientFunction.apply(
                inputs, self._binary_gates, run_function(self.sampled),
                backward_function(self._binary_gates)
            )
        return super().forward(inputs)

    def resample(self, sample=None):
        if sample is None:
            probs = F.softmax(self.alpha, dim=-1)
            sample = torch.multinomial(probs, 1)[0].item()
        self.sampled = sample
        with torch.no_grad():
            self._binary_gates.zero_()
            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
            self._binary_gates.data[sample] = 1.0
        return self.sampled

    def finalize_grad(self):
        binary_grads = self._binary_gates.grad
        with torch.no_grad():
            if self.alpha.grad is None:
                self.alpha.grad = torch.zeros_like(self.alpha.data)
            probs = F.softmax(self.alpha, dim=-1)
            for i in range(self.num_input_candidates):
                for j in range(self.num_input_candidates):
                    self.alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])


class ProxylessModule(DartsModule):
    """
    The ProxylessNAS module. This is a DARTS-based method that resamples the architecture to reduce memory consumption.

    The ProxylessNAS module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.

    Reference
    ---------
    .. [proxyless] H. Cai, L. Zhu, and S. Han, "ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware,"
        presented at the International Conference on Learning Representations, Sep. 2018.
        Available: https://openreview.net/forum?id=HylVB3AqYm
    """
    @property
    def default_replace_dict(self):
        return {
            LayerChoice: ProxylessLayerChoice,
            InputChoice: ProxylessInputChoice
        }

    def configure_architecture_optimizers(self):
        ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], 3.e-4,
                                      weight_decay=0, betas=(0, 0.999), eps=1e-8)
        return ctrl_optim

    def _resample(self):
        for _, m in self.nas_modules:
            m.resample()

    def finalize_grad(self):
        for _, m in self.nas_modules:
            m.finalize_grad()


class SNASLayerChoice(DartsLayerChoice):
    def forward(self, *args, **kwargs):
        self.one_hot = F.gumbel_softmax(self.alpha, self.temp)
        op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
        alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
        yhat = torch.sum(op_results * self.one_hot.view(*alpha_shape), 0)
        return yhat


class SNASInputChoice(DartsInputChoice):
    def forward(self, inputs):
        self.one_hot = F.gumbel_softmax(self.alpha, self.temp)
        inputs = torch.stack(inputs)
        alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
        yhat = torch.sum(inputs * self.one_hot.view(*alpha_shape), 0)
        return yhat


class SNASModule(DartsModule):
    """
    The SNAS module. This is a DARTS-based method that uses Gumbel-Softmax to approximate a one-hot distribution.

    The SNAS module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.

    Parameters
    ----------
    base_model : pl.LightningModule
        The evaluator in ``nni.retiarii.evaluator.lightning``. The user-defined model is wrapped by base_model, and base_model will
        be wrapped by this model.
    gumble_temperature : float
        The initial temperature used in Gumbel-Softmax.
    use_temp_anneal : bool
        If True, a linear annealing is applied to gumble_temperature. If False, run at a fixed temperature. See [snas] for details.
    min_temp : float
        The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` to False.
    custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
        The custom xxxChoice replace method. Keys should be xxxChoice types and values should return an ``nn.Module``. This custom
        replace dict will override the default replace dict of each NAS method.

    Reference
    ---------
    .. [snas] S. Xie, H. Zheng, C. Liu, and L. Lin, "SNAS: stochastic neural architecture search," presented at the
        International Conference on Learning Representations, Sep. 2018. Available: https://openreview.net/forum?id=rylqooRqK7
    """
    def __init__(self, base_model, gumble_temperature=1., use_temp_anneal=False, min_temp=.33,
                 custom_replace_dict=None):
        super().__init__(base_model, custom_replace_dict)
        self.temp = gumble_temperature
        self.init_temp = gumble_temperature
        self.use_temp_anneal = use_temp_anneal
        self.min_temp = min_temp

    def on_epoch_start(self):
        if self.use_temp_anneal:
            self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
            self.temp = max(self.temp, self.min_temp)

        for _, nas_module in self.nas_modules:
            nas_module.temp = self.temp

        return self.model.on_epoch_start()

    @property
    def default_replace_dict(self):
        return {
            LayerChoice: SNASLayerChoice,
            InputChoice: SNASInputChoice
        }
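
The DARTS-family modules above are meant to be fed by the interleaved dataloader defined in utils.py later in this commit. A usage sketch, modeled on the unit test at the end of this commit (the ``Classification`` evaluator, ``set_model`` call and loader wrapper are taken from that test; ``train_loader``, ``valid_loader`` and ``base_model`` are assumed to already exist):

from nni.retiarii.evaluator.pytorch.lightning import Classification
from nni.retiarii.oneshot.pytorch import DartsModule, InterleavedTrainValDataLoader

# train_loader / valid_loader: ordinary dataloaders; base_model: a module containing LayerChoice / InputChoice
cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, max_epochs=1)
cls.module.set_model(base_model)
darts_module = DartsModule(cls.module)
para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
cls.trainer.fit(darts_module, para_loader)
print(darts_module.export())    # best choice per label after training
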
nni/retiarii/oneshot/pytorch/enas.py

@@ -145,7 +145,7 @@ class ReinforceController(nn.Module):
             else:
                 self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)
-        sampled = sampled.detach().numpy().tolist()
+        sampled = sampled.detach().cpu().numpy().tolist()
         self.sample_log_prob += self.entropy_reduction(log_prob)
         entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
         self.sample_entropy += self.entropy_reduction(entropy)
nni/retiarii/oneshot/pytorch/sampling.py  (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import random

import torch
import torch.nn as nn
import torch.optim as optim

from nni.retiarii.nn.pytorch.api import LayerChoice, InputChoice
from .random import PathSamplingLayerChoice, PathSamplingInputChoice
from .base_lightning import BaseOneShotLightningModule
from .enas import ReinforceController, ReinforceField


class EnasModule(BaseOneShotLightningModule):
    """
    The ENAS module. There are two steps in an epoch: 1) training model parameters and 2) training the ENAS RL agent. The agent
    produces samples of the model architecture so as to maximize its reward.

    The ENAS module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.ConcatenateTrainValDataLoader`.

    Parameters
    ----------
    base_model : pl.LightningModule
        The evaluator in ``nni.retiarii.evaluator.lightning``. The user-defined model is wrapped by base_model, and base_model will
        be wrapped by this model.
    ctrl_kwargs : dict
        Optional kwargs that will be passed to :class:`ReinforceController`.
    entropy_weight : float
        Weight of sample entropy loss.
    skip_weight : float
        Weight of skip penalty loss.
    baseline_decay : float
        Decay factor of the baseline. The new baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    ctrl_steps_aggregate : int
        Number of steps that will be aggregated into one mini-batch for the RL controller.
    grad_clip : float
        Gradient clipping value.
    custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
        The custom xxxChoice replace method. Keys should be xxxChoice types and values should return an ``nn.Module``. This custom
        replace dict will override the default replace dict of each NAS method.

    Reference
    ---------
    .. [enas] H. Pham, M. Guan, B. Zoph, Q. Le, and J. Dean, "Efficient Neural Architecture Search via Parameters Sharing,"
        in Proceedings of the 35th International Conference on Machine Learning, Jul. 2018, pp. 4095-4104.
        Available: https://proceedings.mlr.press/v80/pham18a.html
    """
    def __init__(self, base_model, ctrl_kwargs=None,
                 entropy_weight=1e-4, skip_weight=.8, baseline_decay=.999,
                 ctrl_steps_aggregate=20, grad_clip=0, custom_replace_dict=None):
        super().__init__(base_model, custom_replace_dict)

        self.nas_fields = [ReinforceField(name, len(module),
                                          isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1)
                           for name, module in self.nas_modules]
        self.controller = ReinforceController(self.nas_fields, **(ctrl_kwargs or {}))

        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.baseline = 0.
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.grad_clip = grad_clip

    def configure_architecture_optimizers(self):
        return optim.Adam(self.controller.parameters(), lr=3.5e-4)

    @property
    def default_replace_dict(self):
        return {
            LayerChoice: PathSamplingLayerChoice,
            InputChoice: PathSamplingInputChoice
        }

    def training_step(self, batch, batch_idx):
        # The ConcatenateTrainValDataLoader yields both the data and which dataloader it comes from.
        batch, source = batch

        if source == 'train':
            # step 1: train model params
            self._resample()
            self.call_user_optimizers('zero_grad')
            loss_and_metrics = self.model.training_step(batch, batch_idx)
            w_step_loss = loss_and_metrics['loss'] \
                if isinstance(loss_and_metrics, dict) else loss_and_metrics
            self.manual_backward(w_step_loss)
            self.call_user_optimizers('step')
            return loss_and_metrics

        if source == 'val':
            # step 2: train the ENAS agent
            x, y = batch
            arc_opt = self.architecture_optimizers
            arc_opt.zero_grad()
            self._resample()
            with torch.no_grad():
                logits = self.model(x)
            # use the default metric of self.model as the reward function
            if len(self.model.metrics) == 1:
                _, metric = next(iter(self.model.metrics.items()))
            else:
                if 'default' not in self.model.metrics.keys():
                    raise KeyError('model.metrics should contain a ``default`` key when '
                                   'there are multiple metrics')
                metric = self.model.metrics['default']

            reward = metric(logits, y)
            if self.entropy_weight:
                reward = reward + self.entropy_weight * self.controller.sample_entropy.item()
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
            rnn_step_loss = self.controller.sample_log_prob * (reward - self.baseline)
            if self.skip_weight:
                rnn_step_loss = rnn_step_loss + self.skip_weight * self.controller.sample_skip_penalty

            rnn_step_loss = rnn_step_loss / self.ctrl_steps_aggregate
            self.manual_backward(rnn_step_loss)

            if (batch_idx + 1) % self.ctrl_steps_aggregate == 0:
                if self.grad_clip > 0:
                    nn.utils.clip_grad_norm_(self.controller.parameters(), self.grad_clip)
                arc_opt.step()
                arc_opt.zero_grad()

    def _resample(self):
        """
        Resample the architecture as the ENAS result. This doesn't require an ``export`` method in nas_modules to work.
        """
        result = self.controller.resample()
        for name, module in self.nas_modules:
            module.sampled = result[name]

    def export(self):
        self.controller.eval()
        with torch.no_grad():
            return self.controller.resample()


class RandomSampleModule(BaseOneShotLightningModule):
    """
    Random sampling NAS algorithm. In each epoch, model parameters are trained after a uniformly random sampling of each choice.
    The training result is also a random sample of the search space.

    Parameters
    ----------
    base_model : pl.LightningModule
        The evaluator in ``nni.retiarii.evaluator.lightning``. The user-defined model is wrapped by base_model, and base_model will
        be wrapped by this model.
    custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
        The custom xxxChoice replace method. Keys should be xxxChoice types and values should return an ``nn.Module``. This custom
        replace dict will override the default replace dict of each NAS method.
    """
    automatic_optimization = True

    def training_step(self, batch, batch_idx):
        self._resample()
        return self.model.training_step(batch, batch_idx)

    @property
    def default_replace_dict(self):
        return {
            LayerChoice: PathSamplingLayerChoice,
            InputChoice: PathSamplingInputChoice
        }

    def _resample(self):
        """
        Resample the architecture as the RandomSample result. This is simply uniform sampling that doesn't require an ``export``
        method in nas_modules to work.
        """
        result = {}
        for name, module in self.nas_modules:
            if name not in result:
                result[name] = random.randint(0, len(module) - 1)
            module.sampled = result[name]
        return result

    def export(self):
        return self._resample()
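
ENAS, in contrast to the DARTS family, consumes the concatenated dataloader so that controller updates follow the weight updates within an epoch. A sketch mirroring ``test_enas`` from the unit test later in this commit (the same assumed ``train_loader``, ``valid_loader`` and ``base_model`` as in the earlier DARTS sketch):

from nni.retiarii.evaluator.pytorch.lightning import Classification
from nni.retiarii.oneshot.pytorch import EnasModule, ConcatenateTrainValDataLoader

cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, max_epochs=1)
cls.module.set_model(base_model)
enas_module = EnasModule(cls.module)    # controller hyperparameters left at their defaults
concat_loader = ConcatenateTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
cls.trainer.fit(enas_module, concat_loader)
print(enas_module.export())             # architecture sampled by the trained controller
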
nni/retiarii/oneshot/pytorch/utils.py

@@ -6,6 +6,7 @@ from collections import OrderedDict
 import numpy as np
 import torch
+from torch.utils.data import DataLoader
 import nni.retiarii.nn.pytorch as nn
 from nni.nas.pytorch.mutables import InputChoice, LayerChoice

@@ -127,7 +128,6 @@ class AverageMeter:
 def _replace_module_with_type(root_module, init_fn, type_name, modules):
     if modules is None:
         modules = []
-
     def apply(m):
         for name, child in m.named_children():
             if isinstance(child, type_name):

@@ -180,3 +180,119 @@ def replace_input_choice(root_module, init_fn, modules=None):
         A list from layer choice keys (names) and replaced modules.
     """
     return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules)
+
+
+class InterleavedTrainValDataLoader(DataLoader):
+    """
+    Dataloader that yields both train data and validation data in a batch, in the order (train_batch, val_batch). The shorter
+    loader will be upsampled (repeated) to the length of the longer one, and the tail of the last repeat will be dropped. This
+    enables users to train both model parameters and architecture parameters in parallel within an epoch.
+    Some NAS algorithms, e.g. DARTS and ProxylessNAS, require this type of dataloader.
+
+    Parameters
+    ----------
+    train_data : DataLoader
+        training dataloader
+    val_data : DataLoader
+        validation dataloader
+
+    Example
+    -------
+    Fit your dataloaders into a parallel one.
+    >>> para_loader = InterleavedTrainValDataLoader(train_dataloader, val_dataloader)
+    Then you can use ``para_loader`` as a normal training loader.
+    """
+    def __init__(self, train_dataloader, val_dataloader):
+        self.train_loader = train_dataloader
+        self.val_loader = val_dataloader
+        self.equal_len = len(train_dataloader) == len(val_dataloader)
+        self.train_longer = len(train_dataloader) > len(val_dataloader)
+        super().__init__(None)
+
+    def __iter__(self):
+        self.train_iter = iter(self.train_loader)
+        self.val_iter = iter(self.val_loader)
+        return self
+
+    def __next__(self):
+        try:
+            train_batch = next(self.train_iter)
+        except StopIteration:
+            # training data is used up
+            if self.equal_len or self.train_longer:
+                # if training is the longer one, or they are equal, stop iteration
+                raise StopIteration()
+            # if training is the shorter one, upsample it
+            self.train_iter = iter(self.train_loader)
+            train_batch = next(self.train_iter)
+
+        try:
+            val_batch = next(self.val_iter)
+        except StopIteration:
+            # validation data is used up
+            if not self.train_longer:
+                # if validation is the longer one (the equal case is covered above), stop iteration
+                raise StopIteration()
+            # if validation is the shorter one, upsample it
+            self.val_iter = iter(self.val_loader)
+            val_batch = next(self.val_iter)
+
+        return train_batch, val_batch
+
+    def __len__(self) -> int:
+        return max(len(self.train_loader), len(self.val_loader))
+
+
+class ConcatenateTrainValDataLoader(DataLoader):
+    """
+    Dataloader that yields validation data after training data within an epoch. In the training step you will get a batch of the
+    form (batch, source), where ``source`` is a string, either 'train' or 'val', indicating which dataloader the batch comes from.
+    This enables users to train model parameters first in an epoch, and then train architecture parameters.
+    Some NAS algorithms, e.g. ENAS, may require this type of dataloader.
+
+    Parameters
+    ----------
+    train_data : DataLoader
+        training dataloader
+    val_data : DataLoader
+        validation dataloader
+
+    Warnings
+    --------
+    If you set ``limit_train_batches`` on the trainer, the validation batches may be skipped.
+    Consider downsampling the train dataset and the validation dataset instead if you want to shorten the length of data.
+
+    Example
+    -------
+    Fit your dataloaders into a concatenated one.
+    >>> concat_loader = ConcatenateTrainValDataLoader(train_dataloader, val_dataloader)
+    Then you can use ``concat_loader`` as a normal training loader.
+    """
+    def __init__(self, train_dataloader, val_dataloader):
+        self.train_loader = train_dataloader
+        self.val_loader = val_dataloader
+        super().__init__(None)
+
+    def __iter__(self):
+        self.cur_iter = iter(self.train_loader)
+        self.source = 'train'
+        return self
+
+    def __next__(self):
+        try:
+            batch = next(self.cur_iter)
+        except StopIteration:
+            # training data is used up, switch to validation data
+            if self.source == 'train':
+                self.cur_iter = iter(self.val_loader)
+                self.source = 'val'
+                return next(self)
+            raise StopIteration()
+        else:
+            return batch, self.source
+
+    def __len__(self):
+        return len(self.train_loader) + len(self.val_loader)
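
To make the two wrappers' batch formats concrete, here is a small toy run (a sketch only; the dataset sizes are chosen arbitrarily and it relies solely on the classes added above):

import torch
from torch.utils.data import DataLoader, TensorDataset

train = DataLoader(TensorDataset(torch.arange(8.)), batch_size=2)   # 4 batches
val = DataLoader(TensorDataset(torch.arange(4.)), batch_size=2)     # 2 batches

inter = InterleavedTrainValDataLoader(train, val)
print(len(inter))                         # 4 -- length of the longer loader; val is repeated to match
for train_batch, val_batch in inter:      # each step yields one train batch and one val batch
    pass

concat = ConcatenateTrainValDataLoader(train, val)
print(len(concat))                        # 6 -- all train batches first, then all val batches
for batch, source in concat:              # source is 'train' or 'val'
    assert source in ('train', 'val')
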
test/ut/retiarii/test_oneshot.py  (new file, 0 → 100644)

import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import pytest
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data.sampler import RandomSampler

from nni.retiarii.evaluator.pytorch.lightning import Classification, DataLoader
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
from nni.retiarii.oneshot.pytorch import (ConcatenateTrainValDataLoader, DartsModule, EnasModule, SNASModule,
                                          InterleavedTrainValDataLoader, ProxylessModule, RandomSampleModule)


class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=3, groups=in_ch)
        self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))


class Net(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = LayerChoice([
            nn.Conv2d(32, 64, 3, 1),
            DepthwiseSeparableConv(32, 64)
        ])
        self.dropout1 = nn.Dropout(.25)
        self.dropout2 = nn.Dropout(0.5)
        self.dropout_choice = InputChoice(2, 1)
        self.fc = LayerChoice([
            nn.Sequential(
                nn.Linear(9216, 64),
                nn.ReLU(),
                nn.Linear(64, 10),
            ),
            nn.Sequential(
                nn.Linear(9216, 128),
                nn.ReLU(),
                nn.Linear(128, 10),
            ),
            nn.Sequential(
                nn.Linear(9216, 256),
                nn.ReLU(),
                nn.Linear(256, 10),
            )
        ])
        self.rpfc = nn.Linear(10, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(self.conv2(x), 2)
        x1 = torch.flatten(self.dropout1(x), 1)
        x2 = torch.flatten(self.dropout2(x), 1)
        x = self.dropout_choice([x1, x2])
        x = self.fc(x)
        x = self.rpfc(x)
        output = F.log_softmax(x, dim=1)
        return output


@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def prepare_model_data():
    base_model = Net()
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = MNIST('data/mnist', train=True, download=True, transform=transform)
    train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 10))
    train_loader = DataLoader(train_dataset, 64, sampler=train_random_sampler)
    valid_dataset = MNIST('data/mnist', train=False, download=True, transform=transform)
    valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 10))
    valid_loader = DataLoader(valid_dataset, 64, sampler=valid_random_sampler)

    trainer_kwargs = {
        'max_epochs': 1
    }

    return base_model, train_loader, valid_loader, trainer_kwargs


@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_darts():
    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
    cls.module.set_model(base_model)
    darts_model = DartsModule(cls.module)
    para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
    cls.trainer.fit(darts_model, para_loader)


@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_proxyless():
    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
    cls.module.set_model(base_model)
    proxyless_model = ProxylessModule(cls.module)
    para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
    cls.trainer.fit(proxyless_model, para_loader)


@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_enas():
    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
    cls.module.set_model(base_model)
    enas_model = EnasModule(cls.module)
    concat_loader = ConcatenateTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
    cls.trainer.fit(enas_model, concat_loader)


@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_random():
    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
    cls.module.set_model(base_model)
    random_model = RandomSampleModule(cls.module)
    cls.trainer.fit(random_model, cls.train_dataloader, cls.val_dataloaders)


@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_snas():
    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
    cls.module.set_model(base_model)
    snas_model = SNASModule(cls.module, 1, use_temp_anneal=True)
    para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
    cls.trainer.fit(snas_model, para_loader)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp', type=str, default='all', metavar='E',
                        help='exp to run, default = all')
    args = parser.parse_args()

    if args.exp == 'all':
        test_darts()
        test_proxyless()
        test_enas()
        test_random()
        test_snas()
    else:
        globals()[f'test_{args.exp}']()
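
As the ``__main__`` block shows, the file doubles as a standalone script: running it with no arguments executes all five tests, while ``python test/ut/retiarii/test_oneshot.py --exp darts`` runs only ``test_darts`` (and likewise for ``proxyless``, ``enas``, ``random`` and ``snas``).
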