OpenDAS / nni / Commits / f77db747

Commit f77db747 (unverified), authored Aug 15, 2022 by Yuge Zhang, committed by GitHub on Aug 15, 2022

Enhancement of one-shot NAS (v2.9) (#5049)
parent 125ec21f

Showing 15 changed files with 730 additions and 399 deletions (+730 -399)
nni/nas/hub/pytorch/modules/nasbench201.py (+1 -1)
nni/nas/hub/pytorch/nasbench201.py (+4 -1)
nni/nas/oneshot/pytorch/base_lightning.py (+213 -203)
nni/nas/oneshot/pytorch/differentiable.py (+46 -27)
nni/nas/oneshot/pytorch/enas.py (+8 -3)
nni/nas/oneshot/pytorch/sampling.py (+62 -18)
nni/nas/oneshot/pytorch/supermodule/_valuechoice_utils.py (+2 -2)
nni/nas/oneshot/pytorch/supermodule/base.py (+11 -0)
nni/nas/oneshot/pytorch/supermodule/differentiable.py (+76 -10)
nni/nas/oneshot/pytorch/supermodule/operation.py (+8 -0)
nni/nas/oneshot/pytorch/supermodule/proxyless.py (+140 -117)
nni/nas/oneshot/pytorch/supermodule/sampling.py (+5 -1)
test/algo/nas/test_oneshot.py (+50 -13)
test/algo/nas/test_oneshot_proxyless.py (+77 -0)
test/algo/nas/test_oneshot_supermodules.py (+27 -3)
nni/nas/hub/pytorch/modules/nasbench201.py (view file @ f77db747)

@@ -70,7 +70,7 @@ class NasBench201Cell(nn.Module):
                inp = in_features if j == 0 else out_features
                op_choices = OrderedDict([(key, cls(inp, out_features)) for key, cls in op_candidates.items()])
-               node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))  # put __ here to be compatible with base engine
+               node_ops.append(LayerChoice(op_choices, label=f'{self._label}/{j}_{tid}'))
            self.layers.append(node_ops)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
nni/nas/hub/pytorch/nasbench201.py (view file @ f77db747)

@@ -179,7 +179,7 @@ class NasBench201(nn.Module):
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                ops: Dict[str, Callable[[int, int], nn.Module]] = {
-                   prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES
+                   prim: self._make_op_factory(prim) for prim in PRIMITIVES
                }
                cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
            self.cells.append(cell)

@@ -192,6 +192,9 @@ class NasBench201(nn.Module):
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, self.num_labels)

+   def _make_op_factory(self, prim):
+       return lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1)
+
    def forward(self, inputs):
        feature = self.stem(inputs)
        for cell in self.cells:
nni/nas/oneshot/pytorch/base_lightning.py (view file @ f77db747)

@@ -5,23 +5,21 @@ from __future__ import annotations
import warnings
from itertools import chain
-from typing import Callable, Any, Dict, Union, Tuple, List, cast
+from typing import Callable, Any, Dict, Union, Tuple, cast

import pytorch_lightning as pl
import torch.optim as optim
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

import nni.nas.nn.pytorch as nas_nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import is_traceable
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.typehint import Literal
from .supermodule.base import BaseSuperNetModule

__all__ = [
    'MANUAL_OPTIMIZATION_NOTE',
    'MutationHook',
    'BaseSuperNetModule',
    'BaseOneShotLightningModule',

@@ -30,6 +28,22 @@ __all__ = [
]

+MANUAL_OPTIMIZATION_NOTE = """
+    .. warning::
+
+        The strategy, under the hood, creates a Lightning module that wraps the Lightning module defined in evaluator,
+        and enables `Manual optimization <https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html>`_,
+        although we assume **the inner evaluator has enabled automatic optimization**.
+        We call the optimizers and schedulers configured in evaluator, following the definition in Lightning at best effort,
+        but we make no guarantee that the behaviors are exactly same as automatic optimization.
+        We call :meth:`~BaseSuperNetModule.advance_optimization` and :meth:`~BaseSuperNetModule.advance_lr_schedulers`
+        to invoke the optimizers and schedulers configured in evaluators.
+
+        Moreover, some advanced features like gradient clipping will not be supported.
+        If you encounter any issues, please contact us by `creating an issue <https://github.com/microsoft/nni/issues>`_.
+"""

MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]

@@ -122,7 +136,7 @@ def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any], mutate_k
        nas_nn.LayerChoice,
        nas_nn.InputChoice,
        nas_nn.Repeat,
-       # nas_nn.NasBench101Cell,
+       # FIXME: nasbench101 is moved to hub, can't check any more.
+       # nas_nn.NasBench101Cell,
        # nas_nn.ValueChoice,       # could be false positive
        # nas_nn.Cell,              # later
        # nas_nn.NasBench201Cell,   # forward = supernet

@@ -156,8 +170,8 @@ class BaseOneShotLightningModule(pl.LightningModule):
        Extra mutation hooks to support customized mutation on primitives other than built-ins.

        Mutation hooks are callable that inputs an Module and returns a
-       :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`.
-       They are invoked in :func:`~nni.nas.oneshot.pytorch.base_lightning.traverse_and_mutate_submodules`, on each submodules.
+       :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`.
+       They are invoked in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.traverse_and_mutate_submodules`, on each submodules.
        For each submodule, the hook list are invoked subsequently,
        the later hooks can see the result from previous hooks.
        The modules that are processed by ``mutation_hooks`` will be replaced by the returned module,

@@ -177,21 +191,21 @@ class BaseOneShotLightningModule(pl.LightningModule):
        The returned arguments can be also one of the three kinds:

-       1. tuple of: :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None, and boolean,
+       1. tuple of: :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None, and boolean,
        2. boolean,
-       3. :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None.
+       3. :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None.

        The boolean value is ``suppress`` indicates whether the following hooks should be called.
        When it's true, it suppresses the subsequent hooks, and they will never be invoked.
        Without boolean value specified, it's assumed to be false.

-       If a none value appears on the place of :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
+       If a none value appears on the place of :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
        it means the hook suggests to keep the module unchanged, and nothing will happen.

-       An example of mutation hook is given in :func:`~nni.nas.oneshot.pytorch.base_lightning.no_default_hook`.
+       An example of mutation hook is given in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.no_default_hook`.
        However it's recommended to implement mutation hooks by deriving
-       :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
+       :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
        and add its classmethod ``mutate`` to this list.
        """

@@ -295,236 +309,232 @@ class BaseOneShotLightningModule(pl.LightningModule):
            result.update(module.export(memo=result))
        return result

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """This is the implementation of what happens in training loops of one-shot algos.
        It usually calls ``self.model.training_step`` which implements the real training recipe of the users' model.
        """
        return self.model.training_step(batch, batch_idx)

    def export_probs(self) -> dict[str, Any]:
        """
        Export the probability of every choice in the search space got chosen.

        .. note:: If such method of some modules is not implemented, they will be simply ignored.

        Returns
        -------
        dict
            In most cases, keys are names of ``nas_modules`` suffixed with ``/`` and choice name.
            Values are the probability / logits depending on the implementation.
        """
        result = {}
        for module in self.nas_modules:
            try:
                result.update(module.export_probs(memo=result))
            except NotImplementedError:
                warnings.warn('Some super-modules you have used did not implement export_probs. You might find some logs are missing.',
                              UserWarning)
        return result

    def configure_optimizers(self):
        """
        Combine architecture optimizers and user's model optimizers.
        You can overwrite :meth:`configure_architecture_optimizers` if architecture optimizers are needed in your NAS algorithm.
        For now :attr:`model` is tested against evaluators in :mod:`nni.nas.evaluator.pytorch.lightning`
        and it only returns 1 optimizer.
        But for extendibility, codes for other return value types are also implemented.
        """
        # pylint: disable=assignment-from-none
        arc_optimizers = self.configure_architecture_optimizers()
        if arc_optimizers is None:
            return self.model.configure_optimizers()
        if isinstance(arc_optimizers, optim.Optimizer):
            arc_optimizers = [arc_optimizers]
        self.arc_optim_count = len(arc_optimizers)

        # FIXME: this part uses non-official lightning API.
        # The return values ``frequency`` and ``monitor`` are ignored because lightning requires
        # ``len(optimizers) == len(frequency)``, and gradient backword is handled manually.
        # For data structure of variables below, please see pytorch lightning docs of ``configure_optimizers``.
        try:
            # above v1.6
            from pytorch_lightning.core.optimizer import (  # pylint: disable=import-error
                _configure_optimizers,  # type: ignore
                _configure_schedulers_automatic_opt,  # type: ignore
                _configure_schedulers_manual_opt  # type: ignore
            )
            w_optimizers, lr_schedulers, self.frequencies, monitor = \
                _configure_optimizers(self.model.configure_optimizers())  # type: ignore
            lr_schedulers = (_configure_schedulers_automatic_opt(lr_schedulers, monitor)
                             if self.automatic_optimization
                             else _configure_schedulers_manual_opt(lr_schedulers))
        except ImportError:
            # under v1.5
            w_optimizers, lr_schedulers, self.frequencies, monitor = \
                self.trainer._configure_optimizers(self.model.configure_optimizers())  # type: ignore
            lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)  # type: ignore
        if any(sch["scheduler"].optimizer not in w_optimizers for sch in lr_schedulers):  # type: ignore
            raise Exception("Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`.")

        # variables used to handle optimizer frequency
        self.cur_optimizer_step = 0
        self.cur_optimizer_index = 0

        return arc_optimizers + w_optimizers, lr_schedulers

    def configure_optimizers(self) -> Any:
        """
        Transparently configure optimizers for the inner model,
        unless one-shot algorithm has its own optimizer (via :meth:`configure_architecture_optimizers`),
        in which case, the optimizer will be appended to the list.

        The return value is still one of the 6 types defined in PyTorch-Lightning.
        """
        arch_optimizers = self.configure_architecture_optimizers() or []
        if not arch_optimizers:
            # no architecture optimizer available
            return self.model.configure_optimizers()

        if isinstance(arch_optimizers, optim.Optimizer):
            arch_optimizers = [arch_optimizers]
        # Set the flag to True so that they can differ from other optimizers
        for optimizer in arch_optimizers:
            optimizer.is_arch_optimizer = True  # type: ignore

        optim_conf: Any = self.model.configure_optimizers()

        # 0. optimizer is none
        if optim_conf is None:
            return arch_optimizers
        # 1. single optimizer
        if isinstance(optim_conf, Optimizer):
            return [optim_conf] + arch_optimizers
        # 2. two lists, optimizer + lr schedulers
        if (isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2
                and isinstance(optim_conf[0], list)
                and all(isinstance(opt, Optimizer) for opt in optim_conf[0])):
            return list(optim_conf[0]) + arch_optimizers, optim_conf[1]
        # 3. single dictionary
        if isinstance(optim_conf, dict):
            return [optim_conf] + [{'optimizer': optimizer} for optimizer in arch_optimizers]
        # 4. multiple dictionaries
        if isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
            return list(optim_conf) + [{'optimizer': optimizer} for optimizer in arch_optimizers]
        # 5. single list or tuple, multiple optimizer
        if isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
            return list(optim_conf) + arch_optimizers
        # unknown configuration
        warnings.warn('Unknown optimizer configuration. Architecture optimizers will be ignored. Strategy might fail.', UserWarning)
        return optim_conf

    def setup(self, stage=None):
        # redirect the access to trainer/log to this module
        # but note that we might be missing other attributes,
        # which could potentially be a problem
        self.model.trainer = self.trainer  # type: ignore
        self.model.log = self.log

        # Reset the optimizer progress (only once at the very beginning)
        self._optimizer_progress = 0

        return self.model.setup(stage)

    def teardown(self, stage=None):
        return self.model.teardown(stage)

    def configure_architecture_optimizers(self) -> list[optim.Optimizer] | optim.Optimizer | None:
        """
        Hook kept for subclasses. A specific NAS method inheriting this base class should return its architecture optimizers here
        if architecture parameters are needed. Note that lr schedulers are not supported now for architecture_optimizers.

        Returns
        -------
        Optimizers used by a specific NAS algorithm. Return None if no architecture optimizers are needed.
        """
        return None

    def call_lr_schedulers(self, batch_index):
        """
        Function that imitates lightning trainer's behaviour of calling user's lr schedulers. Since auto_optimization is turned off
        by this class, you can use this function to make schedulers behave as they were automatically handled by the lightning trainer.

        Parameters
        ----------
        batch_idx : int
            batch index
        """
        def apply(lr_scheduler):
            # single scheduler is called every epoch
            if isinstance(lr_scheduler, _LRScheduler):
                if self.trainer.is_last_batch:
                    lr_scheduler.step()
            # lr_scheduler_config is called as configured
            elif isinstance(lr_scheduler, dict):
                interval = lr_scheduler['interval']
                frequency = lr_scheduler['frequency']
                if (interval == 'step' and batch_index % frequency == 0) or \
                        (interval == 'epoch' and self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % frequency == 0):
                    lr_scheduler['scheduler'].step()

        lr_schedulers = self.lr_schedulers()
        if isinstance(lr_schedulers, list):
            for lr_scheduler in lr_schedulers:
                apply(lr_scheduler)
        else:
            apply(lr_schedulers)

    def call_weight_optimizers(self, method: Literal['step', 'zero_grad']):
        """
        Function that imitates lightning trainer's behavior of calling user's optimizers. Since auto_optimization is turned off by this
        class, you can use this function to make user optimizers behave as they were automatically handled by the lightning trainer.

        Parameters
        ----------
        method : str
            Method to call. Only ``step`` and ``zero_grad`` are supported now.
        """
        def apply_method(optimizer, method):
            if method == 'step':
                optimizer.step()
            elif method == 'zero_grad':
                optimizer.zero_grad()

        optimizers = self.weight_optimizers()
        if optimizers is None:
            return
        assert isinstance(optimizers, list), 'Did you forget to set use_pl_optimizers to true?'
        if len(self.frequencies) > 0:
            self.cur_optimizer_step += 1
            if self.frequencies[self.cur_optimizer_index] == self.cur_optimizer_step:
                self.cur_optimizer_step = 0
                self.cur_optimizer_index = self.cur_optimizer_index + 1 \
                    if self.cur_optimizer_index + 1 < len(optimizers) \
                    else 0
            apply_method(optimizers[self.cur_optimizer_index], method)
        else:
            for optimizer in optimizers:
                apply_method(optimizer, method)

    def advance_optimization(self, loss: Any, batch_idx: int,
                             gradient_clip_val: int | float | None = None,
                             gradient_clip_algorithm: str | None = None):
        """
        Run the optimizer defined in evaluators, when manual optimization is turned on.

        Call this method when the model should be optimized.
        To keep it as neat as possible, we only implement the basic ``zero_grad``, ``backward``, ``grad_clip``, and ``step`` here.
        Many hooks and pre/post-processing are omitted.
        Inherit this method if you need more advanced behavior.

        The full optimizer step could be found
        `here <https://github.com/Lightning-AI/lightning/blob/0e531283/src/pytorch_lightning/loops/optimization/optimizer_loop.py>`__.
        We only implement part of the optimizer loop here.

        Parameters
        ----------
        batch_idx: int
            The current batch index.
        """
        if self.automatic_optimization:
            raise ValueError('This method should not be used when automatic optimization is turned on.')

        if self.trainer.optimizer_frequencies:
            warnings.warn('optimizer_frequencies is not supported in NAS. It will be ignored.', UserWarning)

        # Filter out optimizers for architecture parameters
        optimizers = [opt for opt in self.trainer.optimizers if not getattr(opt, 'is_arch_optimizer', False)]

        opt_idx = self._optimizer_progress % len(optimizers)
        optimizer = optimizers[opt_idx]

        # There should be many before/after hooks called here, but they are omitted in this implementation.
        # 1. zero gradient
        self.model.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
        # 2. backward
        self.manual_backward(loss)
        # 3. grad clip
        self.model.configure_gradient_clipping(optimizer, opt_idx, gradient_clip_val, gradient_clip_algorithm)
        # 4. optimizer step
        self.model.optimizer_step(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

        self._optimizer_progress += 1

    def advance_lr_schedulers(self, batch_idx: int):
        """
        Advance the learning rates, when manual optimization is turned on.

        The full implementation is
        `here <https://github.com/Lightning-AI/lightning/blob/0e531283/src/pytorch_lightning/loops/epoch/training_epoch_loop.py>`__.
        We only include a partial implementation here.
        Advanced features like Reduce-lr-on-plateau are not supported.
        """
        if self.automatic_optimization:
            raise ValueError('This method should not be used when automatic optimization is turned on.')

        self._advance_lr_schedulers_impl(batch_idx, 'step')
        if self.trainer.is_last_batch:
            self._advance_lr_schedulers_impl(batch_idx, 'epoch')

    def _advance_lr_schedulers_impl(self, batch_idx: int, interval: str):
        current_idx = batch_idx if interval == 'step' else self.trainer.current_epoch
        current_idx += 1  # account for both batch and epoch starts from 0

        try:
            # lightning >= 1.6
            for config in self.trainer.lr_scheduler_configs:
                scheduler, opt_idx = config.scheduler, config.opt_idx
                if config.reduce_on_plateau:
                    warnings.warn('Reduce-lr-on-plateau is not supported in NAS. It will be ignored.', UserWarning)
                if config.interval == interval and current_idx % config.frequency == 0:
                    self.model.lr_scheduler_step(cast(Any, scheduler), cast(int, opt_idx), None)
        except AttributeError:
            # lightning < 1.6
            for lr_scheduler in self.trainer.lr_schedulers:
                if lr_scheduler['reduce_on_plateau']:
                    warnings.warn('Reduce-lr-on-plateau is not supported in NAS. It will be ignored.', UserWarning)
                if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency']:
                    lr_scheduler['scheduler'].step()

    def architecture_optimizers(self) -> list[Optimizer] | Optimizer | None:
        """
        Get architecture optimizers from all optimizers. Use this to get your architecture optimizers in :meth:`training_step`.
        Get the optimizers configured in :meth:`configure_architecture_optimizers`.

        Returns
        -------
        opts : list[Optimizer], Optimizer, None
            Architecture optimizers defined in :meth:`configure_architecture_optimizers`. This will be None if there is no
            architecture optimizers.
        """
        opts = self.optimizers()
        if isinstance(opts, list):
            # pylint: disable=unsubscriptable-object
            arc_opts = opts[:self.arc_optim_count]
            if len(arc_opts) == 1:
                return cast(Optimizer, arc_opts[0])
            return cast(List[Optimizer], arc_opts)
        # If there is only 1 optimizer and it is the architecture optimizer
        if self.arc_optim_count == 1:
            return cast(Union[List[Optimizer], Optimizer], opts)
        return None

        optimizers = [opt for opt in self.trainer.optimizers if getattr(opt, 'is_arch_optimizer', False)]
        if not optimizers:
            return None
        if len(optimizers) == 1:
            return optimizers[0]
        return optimizers

    def weight_optimizers(self) -> list[Optimizer] | Optimizer | None:
        """
        Get user optimizers from all optimizers. Use this to get user optimizers in :meth:`training_step`.

        Returns
        -------
        opts : list[Optimizer], Optimizer, None
            Optimizers defined by user's model. This will be None if there is no user optimizers.
        """
        # Since use_pl_optimizer is set true (by default) here.
        # opts always return a list
        opts = self.optimizers()
        if isinstance(opts, list):
            # pylint: disable=unsubscriptable-object
            return cast(List[Optimizer], opts[self.arc_optim_count:])
        # FIXME: this case is actually not correctly handled
        # If there is only 1 optimizer and no architecture optimizer
        if self.arc_optim_count == 0:
            return cast(Union[List[Optimizer], Optimizer], opts)
        return None

    # The following methods redirects the callbacks to inner module.
    # It's not the complete list though.
    # More methods can be added if needed.

    def on_train_start(self):
        return self.model.on_train_start()

    def on_train_end(self):
        return self.model.on_train_end()

    def on_fit_start(self):
        return self.model.on_fit_start()

    def on_fit_end(self):
        return self.model.on_fit_end()

    def on_train_batch_start(self, batch, batch_idx, *args, **kwargs):
        return self.model.on_train_batch_start(batch, batch_idx, *args, **kwargs)

    def on_train_batch_end(self, outputs, batch, batch_idx, *args, **kwargs):
        return self.model.on_train_batch_end(outputs, batch, batch_idx, *args, **kwargs)

    # Deprecated hooks in pytorch-lightning
    def on_epoch_start(self):
        return self.model.on_epoch_start()

    def on_epoch_end(self):
        return self.model.on_epoch_end()

    def on_train_epoch_start(self):
        return self.model.on_train_epoch_start()

    def on_train_epoch_end(self):
        return self.model.on_train_epoch_end()

    def on_before_backward(self, loss):
        return self.model.on_before_backward(loss)

    def on_after_backward(self):
        return self.model.on_after_backward()

    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None):
        return self.model.configure_gradient_clipping(optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm)
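Aside (added for orientation; not part of the commit): the four numbered steps inside ``advance_optimization`` above follow the usual manual-optimization sequence. A minimal plain-PyTorch sketch of the same sequence, with ``model``, ``optimizer``, ``loss`` and ``clip_val`` as assumed placeholder names; the real method additionally routes each step through the inner LightningModule's hooks.

import torch

def manual_update(model, optimizer, loss, clip_val=None):
    # 1. zero gradient
    optimizer.zero_grad()
    # 2. backward
    loss.backward()
    # 3. grad clip (only when a clip value is given)
    if clip_val is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_val)
    # 4. optimizer step
    optimizer.step()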
nni/nas/oneshot/pytorch/differentiable.py (view file @ f77db747)

@@ -9,7 +9,7 @@ import pytorch_lightning as pl
import torch
import torch.optim as optim

-from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
+from .base_lightning import BaseOneShotLightningModule, MANUAL_OPTIMIZATION_NOTE, MutationHook, no_default_hook
from .supermodule.differentiable import (
    DifferentiableMixedLayer, DifferentiableMixedInput, MixedOpDifferentiablePolicy, GumbelSoftmax,

@@ -28,6 +28,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
    DARTS repeats iterations, where each iteration consists of 2 training phases.
    The phase 1 is architecture step, in which model parameters are frozen and the architecture parameters are trained.
    The phase 2 is model step, in which architecture parameters are frozen and model parameters are trained.
+   In both phases, ``training_step`` of the Lightning evaluator will be used.

    The current implementation corresponds to DARTS (1st order) in paper.
    Second order (unrolled 2nd-order derivatives) is not supported yet.

@@ -49,15 +50,20 @@ class DartsLightningModule(BaseOneShotLightningModule):
    {{module_notes}}

+   {optimization_note}
+
    Parameters
    ----------
    {{module_params}}
    {base_params}
    arc_learning_rate : float
        Learning rate for architecture optimizer. Default: 3.0e-4
+   gradient_clip_val : float
+       Clip gradients before optimizing models at each step. Default: None
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
-       supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
+       supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES),
+       optimization_note=MANUAL_OPTIMIZATION_NOTE
    )

    __doc__ = _darts_note.format(

@@ -85,8 +91,10 @@ class DartsLightningModule(BaseOneShotLightningModule):
    def __init__(self, inner_module: pl.LightningModule,
                 mutation_hooks: list[MutationHook] | None = None,
-                arc_learning_rate: float = 3.0E-4):
+                arc_learning_rate: float = 3.0E-4,
+                gradient_clip_val: float | None = None):
        self.arc_learning_rate = arc_learning_rate
+       self.gradient_clip_val = gradient_clip_val
        super().__init__(inner_module, mutation_hooks=mutation_hooks)

    def training_step(self, batch, batch_idx):

@@ -108,33 +116,32 @@ class DartsLightningModule(BaseOneShotLightningModule):
        if isinstance(arc_step_loss, dict):
            arc_step_loss = arc_step_loss['loss']
        self.manual_backward(arc_step_loss)
        self.finalize_grad()
        arc_optim.step()

        # phase 2: model step
        self.resample()
        self.call_weight_optimizers('zero_grad')
        loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
        w_step_loss = loss_and_metrics['loss'] \
            if isinstance(loss_and_metrics, dict) else loss_and_metrics
        self.manual_backward(w_step_loss)
        self.call_weight_optimizers('step')
        w_step_loss = loss_and_metrics['loss'] if isinstance(loss_and_metrics, dict) else loss_and_metrics
        self.advance_optimization(w_step_loss, batch_idx, self.gradient_clip_val)

        # Update learning rates
        self.call_lr_schedulers(batch_idx)
        self.advance_lr_schedulers(batch_idx)

        self.log_dict({'prob/' + k: v for k, v in self.export_probs().items()})

        return loss_and_metrics

    def finalize_grad(self):
        # Note: This hook is currently kept for Proxyless NAS.
        pass

    def configure_architecture_optimizers(self):
        # The alpha in DartsXXXChoices are the architecture parameters of DARTS. They share one optimizer.
        ctrl_params = []
        for m in self.nas_modules:
            ctrl_params += list(m.parameters(arch=True))  # type: ignore
        ctrl_optim = torch.optim.Adam(list(set(ctrl_params)), 3.e-4, betas=(0.5, 0.999), weight_decay=1.0E-3)
        # Follow the hyper-parameters used in
        # https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/architect.py#L17
        params = list(set(ctrl_params))
        if not params:
            raise ValueError('No architecture parameters found. Nothing to search.')
        ctrl_optim = torch.optim.Adam(params, 3.e-4, betas=(0.5, 0.999), weight_decay=1.0E-3)
        return ctrl_optim

@@ -153,13 +160,20 @@ class ProxylessLightningModule(DartsLightningModule):
    {{module_notes}}

+   {optimization_note}
+
    Parameters
    ----------
    {{module_params}}
    {base_params}
    arc_learning_rate : float
        Learning rate for architecture optimizer. Default: 3.0e-4
-   """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
+   gradient_clip_val : float
+       Clip gradients before optimizing models at each step. Default: None
+   """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note, optimization_note=MANUAL_OPTIMIZATION_NOTE)

    __doc__ = _proxyless_note.format(
        module_notes='This module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',

@@ -176,10 +190,6 @@ class ProxylessLightningModule(DartsLightningModule):
        # FIXME: no support for mixed operation currently
        return hooks

-   def finalize_grad(self):
-       for m in self.nas_modules:
-           m.finalize_grad()  # type: ignore


class GumbelDartsLightningModule(DartsLightningModule):
    _gumbel_darts_note = """

@@ -207,6 +217,8 @@ class GumbelDartsLightningModule(DartsLightningModule):
    {{module_notes}}

+   {optimization_note}
+
    Parameters
    ----------
    {{module_params}}

@@ -216,13 +228,17 @@ class GumbelDartsLightningModule(DartsLightningModule):
    use_temp_anneal : bool
        If true, a linear annealing will be applied to ``gumbel_temperature``.
        Otherwise, run at a fixed temperature. See `SNAS <https://arxiv.org/abs/1812.09926>`__ for details.
        Default is false.
    min_temp : float
        The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` False.
    arc_learning_rate : float
        Learning rate for architecture optimizer. Default: 3.0e-4
+   gradient_clip_val : float
+       Clip gradients before optimizing models at each step. Default: None
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
-       supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
+       supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES),
+       optimization_note=MANUAL_OPTIMIZATION_NOTE
    )

    def mutate_kwargs(self):

@@ -235,22 +251,25 @@ class GumbelDartsLightningModule(DartsLightningModule):
    def __init__(self, inner_module,
                 mutation_hooks: list[MutationHook] | None = None,
                 arc_learning_rate: float = 3.0e-4,
+                gradient_clip_val: float | None = None,
                 gumbel_temperature: float = 1.,
                 use_temp_anneal: bool = False,
                 min_temp: float = .33):
-       super().__init__(inner_module, mutation_hooks, arc_learning_rate=arc_learning_rate)
+       super().__init__(inner_module, mutation_hooks, arc_learning_rate=arc_learning_rate, gradient_clip_val=gradient_clip_val)
        self.temp = gumbel_temperature
        self.init_temp = gumbel_temperature
        self.use_temp_anneal = use_temp_anneal
        self.min_temp = min_temp

-   def on_train_epoch_end(self):
+   def on_train_epoch_start(self):
        if self.use_temp_anneal:
            self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
            self.temp = max(self.temp, self.min_temp)
        self.log('gumbel_temperature', self.temp)
        for module in self.nas_modules:
-           if hasattr(module, '_softmax'):
-               module._softmax.temp = self.temp  # type: ignore
+           if hasattr(module, '_softmax') and isinstance(module, GumbelSoftmax):
+               module._softmax.tau = self.temp  # type: ignore
-       return self.model.on_train_epoch_end()
+       return self.model.on_train_epoch_start()
nni/nas/oneshot/pytorch/enas.py (view file @ f77db747)

@@ -94,11 +94,11 @@ class ReinforceController(nn.Module):
            field.name: nn.Embedding(field.total, self.lstm_size)
            for field in fields
        })

-   def resample(self):
+   def resample(self, return_prob=False):
        self._initialize()
        result = dict()
        for field in self.fields:
-           result[field.name] = self._sample_single(field)
+           result[field.name] = self._sample_single(field, return_prob=return_prob)
        return result

    def _initialize(self):

@@ -116,7 +116,7 @@ class ReinforceController(nn.Module):
    def _lstm_next_step(self):
        self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

-   def _sample_single(self, field):
+   def _sample_single(self, field, return_prob):
        self._lstm_next_step()
        logit = self.soft[field.name](self._h[-1])
        if self.temperature is not None:

@@ -124,10 +124,12 @@ class ReinforceController(nn.Module):
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        if field.choose_one:
+           sampled_dist = F.softmax(logit, dim=-1)
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            log_prob = self.cross_entropy_loss(logit, sampled)
            self._inputs = self.embedding[field.name](sampled)
        else:
+           sampled_dist = torch.sigmoid(logit)
            logit = logit.view(-1, 1)
            logit = torch.cat([-logit, logit], 1)  # pylint: disable=invalid-unary-operand-type
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)

@@ -147,4 +149,7 @@ class ReinforceController(nn.Module):
        self.sample_entropy += self.entropy_reduction(entropy)
        if len(sampled) == 1:
            sampled = sampled[0]
+       if return_prob:
+           return sampled_dist.flatten().detach().cpu().numpy().tolist()
        return sampled
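Aside (added for orientation; not part of the commit): the ``return_prob`` flag added above lets the controller report the sampling distribution instead of a drawn index. A self-contained sketch of that behaviour for a single field, with ``sample_single`` as an invented stand-in for ``ReinforceController._sample_single``:

import torch
import torch.nn.functional as F

def sample_single(logit: torch.Tensor, return_prob: bool = False):
    # softmax over the candidate logits gives the sampling distribution
    dist = F.softmax(logit, dim=-1)
    if return_prob:
        return dist.flatten().detach().cpu().numpy().tolist()
    # otherwise draw one choice, as the controller does during training
    return torch.multinomial(dist, 1).view(-1)

print(sample_single(torch.tensor([0.1, 2.0, -1.0]), return_prob=True))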
nni/nas/oneshot/pytorch/sampling.py (view file @ f77db747)

@@ -5,14 +5,14 @@
from __future__ import annotations

import warnings
-from typing import Any
+from typing import Any, cast

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim

-from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
+from .base_lightning import MANUAL_OPTIMIZATION_NOTE, BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.operation import NATIVE_MIXED_OPERATIONS, NATIVE_SUPPORTED_OP_NAMES
from .supermodule.sampling import (
    PathSamplingInput, PathSamplingLayer, MixedOpPathSamplingPolicy,

@@ -37,6 +37,9 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
    * :class:`nni.retiarii.nn.pytorch.Cell`.
    * :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.

+   This strategy assumes inner evaluator has set
+   `automatic optimization <https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html>`__ to true.

    Parameters
    ----------
    {{module_params}}

@@ -73,9 +76,9 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
        'mixed_op_sampling': MixedOpPathSamplingPolicy
    }

-   def training_step(self, batch, batch_idx):
+   def training_step(self, *args, **kwargs):
        self.resample()
-       return self.model.training_step(batch, batch_idx)
+       return self.model.training_step(*args, **kwargs)

    def export(self) -> dict[str, Any]:
        """

@@ -115,6 +118,8 @@ class EnasLightningModule(RandomSamplingLightningModule):
    {{module_notes}}

+   {optimization_note}

    Parameters
    ----------
    {{module_params}}

@@ -133,6 +138,8 @@ class EnasLightningModule(RandomSamplingLightningModule):
        before updating the weights of RL controller.
    ctrl_grad_clip : float
        Gradient clipping value of controller.
+   log_prob_every_n_step : int
+       Log the probability of choices every N steps. Useful for visualization and debugging.
    reward_metric_name : str or None
        The name of the metric which is treated as reward.
        This will be not effective when there's only one metric returned from evaluator.

@@ -141,11 +148,12 @@ class EnasLightningModule(RandomSamplingLightningModule):
        Otherwise it raises an exception indicating multiple metrics are found.
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
-       supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
+       supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES),
+       optimization_note=MANUAL_OPTIMIZATION_NOTE
    )

    __doc__ = _enas_note.format(
-       module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`.',
+       module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.pytorch.dataloader.ConcatLoader`.',
        module_params=BaseOneShotLightningModule._inner_module_note,
    )

@@ -162,6 +170,7 @@ class EnasLightningModule(RandomSamplingLightningModule):
                 baseline_decay: float = .999,
                 ctrl_steps_aggregate: float = 20,
                 ctrl_grad_clip: float = 0,
+                log_prob_every_n_step: int = 10,
                 reward_metric_name: str | None = None,
                 mutation_hooks: list[MutationHook] | None = None):
        super().__init__(inner_module, mutation_hooks)

@@ -181,33 +190,29 @@ class EnasLightningModule(RandomSamplingLightningModule):
        self.baseline = 0.
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.ctrl_grad_clip = ctrl_grad_clip
+       self.log_prob_every_n_step = log_prob_every_n_step
        self.reward_metric_name = reward_metric_name

    def configure_architecture_optimizers(self):
        return optim.Adam(self.controller.parameters(), lr=3.5e-4)

    def training_step(self, batch_packed, batch_idx):
        # The received batch is a tuple of (data, "train" | "val")
        batch, mode = batch_packed

        if mode == 'train':
            # train model params
            with torch.no_grad():
                self.resample()
            self.call_weight_optimizers('zero_grad')
            step_output = self.model.training_step(batch, batch_idx)
            w_step_loss = step_output['loss'] \
                if isinstance(step_output, dict) else step_output
            self.manual_backward(w_step_loss)
            self.call_weight_optimizers('step')
            w_step_loss = step_output['loss'] if isinstance(step_output, dict) else step_output
            self.advance_optimization(w_step_loss, batch_idx)
        else:
            # train ENAS agent
            arc_opt = self.architecture_optimizers()
            if not isinstance(arc_opt, optim.Optimizer):
                raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
            arc_opt.zero_grad()
            # Run a sample to retrieve the reward
            self.resample()
            step_output = self.model.validation_step(batch, batch_idx)
            # use the default metric of self.model as reward function

@@ -218,11 +223,13 @@ class EnasLightningModule(RandomSamplingLightningModule):
            if metric_name not in self.trainer.callback_metrics:
                raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
                               f'found multiple (or zero) metrics without default: {list(self.trainer.callback_metrics.keys())}. '
-                              f'Try to use self.log to report metrics with the specified key ``{metric_name}`` in validation_step, '
-                              'and remember to set on_step=True.')
+                              f'Please try to set ``reward_metric_name`` to be one of the keys listed above. '
+                              f'If it is not working use self.log to report metrics with the specified key ``{metric_name}`` '
+                              'in validation_step, and remember to set on_step=True.')
            metric = self.trainer.callback_metrics[metric_name]
            reward: float = metric.item()
            # Compute the loss and run back propagation
            if self.entropy_weight:
                reward = reward + self.entropy_weight * self.controller.sample_entropy.item()  # type: ignore
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)

@@ -236,11 +243,29 @@ class EnasLightningModule(RandomSamplingLightningModule):
            if (batch_idx + 1) % self.ctrl_steps_aggregate == 0:
                if self.ctrl_grad_clip > 0:
                    nn.utils.clip_grad_norm_(self.controller.parameters(), self.ctrl_grad_clip)
+               # Update the controller and zero out its gradients
+               arc_opt = cast(optim.Optimizer, self.architecture_optimizers())
                arc_opt.step()
                arc_opt.zero_grad()

        self.advance_lr_schedulers(batch_idx)

+       if (batch_idx + 1) % self.log_prob_every_n_step == 0:
+           with torch.no_grad():
+               self.log_dict({'prob/' + k: v for k, v in self.export_probs().items()})

        return step_output

+   def on_train_epoch_start(self):
+       # Always zero out the gradients of ENAS controller at the beginning of epochs.
+       arc_opt = self.architecture_optimizers()
+       if not isinstance(arc_opt, optim.Optimizer):
+           raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
+       arc_opt.zero_grad()
+       return self.model.on_train_epoch_start()

    def resample(self):
        """Resample the architecture with ENAS controller."""
        sample = self.controller.resample()

@@ -249,6 +274,14 @@ class EnasLightningModule(RandomSamplingLightningModule):
            module.resample(memo=result)
        return result

+   def export_probs(self):
+       """Export the probability from ENAS controller directly."""
+       sample = self.controller.resample(return_prob=True)
+       result = self._interpret_controller_probability_result(sample)
+       for module in self.nas_modules:
+           module.resample(memo=result)
+       return result

    def export(self):
        """Run one more inference of ENAS controller."""
        self.controller.eval()

@@ -261,3 +294,14 @@ class EnasLightningModule(RandomSamplingLightningModule):
        for key in list(sample.keys()):
            sample[key] = space_spec[key].values[sample[key]]
        return sample

+   def _interpret_controller_probability_result(self, sample: dict[str, list[float]]) -> dict[str, Any]:
+       """Convert ``{label: [prob1, prob2, prob3]} to ``{label/choice: prob}``"""
+       space_spec = self.search_space_spec()
+       result = {}
+       for key in list(sample.keys()):
+           if len(space_spec[key].values) != len(sample[key]):
+               raise ValueError(f'Expect {space_spec[key].values} to be of the same length as {sample[key]}')
+           for value, weight in zip(space_spec[key].values, sample[key]):
+               result[f'{key}/{value}'] = weight
+       return result
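Aside (added for orientation; not part of the commit): the REINFORCE baseline in ``EnasLightningModule`` is an exponential moving average of rewards, i.e. the single line ``baseline = baseline * decay + reward * (1 - decay)`` shown above. A tiny sketch with invented values:

def update_baseline(baseline: float, reward: float, decay: float = 0.999) -> float:
    # exponential moving average of rewards
    return baseline * decay + reward * (1 - decay)

b = 0.0
for r in [0.6, 0.7, 0.65]:  # rewards slowly pull the baseline toward their mean
    b = update_baseline(b, r)
print(b)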
nni/nas/oneshot/pytorch/supermodule/_valuechoice_utils.py (view file @ f77db747)

@@ -168,11 +168,11 @@ def weighted_sum(items: list[T], weights: Sequence[float | None] = cast(Sequence
    assert len(items) == len(weights) > 0
    elem = items[0]
-   unsupported_msg = f'Unsupported element type in weighted sum: {type(elem)}. Value is: {elem}'
+   unsupported_msg = 'Unsupported element type in weighted sum: {}. Value is: {}'

    if isinstance(elem, str):
        # Need to check this first. Otherwise it goes into sequence and causes infinite recursion.
-       raise TypeError(unsupported_msg)
+       raise TypeError(unsupported_msg.format(type(elem), elem))

    try:
        if isinstance(elem, (torch.Tensor, np.ndarray, float, int, np.number)):
nni/nas/oneshot/pytorch/supermodule/base.py (view file @ f77db747)

@@ -56,6 +56,17 @@ class BaseSuperNetModule(nn.Module):
        """
        raise NotImplementedError()

+   def export_probs(self, memo: dict[str, Any]) -> dict[str, Any]:
+       """
+       Export the probability / logits of every choice got chosen.
+
+       Parameters
+       ----------
+       memo : dict[str, Any]
+           Use memo to avoid the same label gets exported multiple times.
+       """
+       raise NotImplementedError()

    def search_space_spec(self) -> dict[str, ParameterSpec]:
        """
        Space specification (sample points).
nni/nas/oneshot/pytorch/supermodule/differentiable.py (view file @ f77db747)

@@ -104,6 +104,13 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
            return {}  # nothing new to export
        return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}

+   def export_probs(self, memo):
+       if any(k.startswith(self.label + '/') for k in memo):
+           return {}  # nothing new
+       weights = self._softmax(self._arch_alpha).cpu().tolist()
+       ret = {f'{self.label}/{name}': value for name, value in zip(self.op_names, weights)}
+       return ret

    def search_space_spec(self):
        return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ), True, size=len(self.op_names))}

@@ -117,7 +124,8 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
            if len(alpha) != size:
                raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
        else:
-           alpha = nn.Parameter(torch.randn(size) * 1E-3)  # this can be reinitialized later
+           alpha = nn.Parameter(torch.randn(size) * 1E-3)  # the numbers in the parameter can be reinitialized later
+           memo[module.label] = alpha

        softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
        return cls(list(module.named_children()), alpha, softmax, module.label)

@@ -208,6 +216,13 @@ class DifferentiableMixedInput(BaseSuperNetModule):
            chosen = chosen[0]
        return {self.label: chosen}

+   def export_probs(self, memo):
+       if any(k.startswith(self.label + '/') for k in memo):
+           return {}  # nothing new
+       weights = self._softmax(self._arch_alpha).cpu().tolist()
+       ret = {f'{self.label}/{index}': value for index, value in enumerate(weights)}
+       return ret

    def search_space_spec(self):
        return {self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),

@@ -225,7 +240,8 @@ class DifferentiableMixedInput(BaseSuperNetModule):
            if len(alpha) != size:
                raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
        else:
-           alpha = nn.Parameter(torch.randn(size) * 1E-3)  # this can be reinitialized later
+           alpha = nn.Parameter(torch.randn(size) * 1E-3)  # the numbers in the parameter can be reinitialized later
+           memo[module.label] = alpha

        softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
        return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label)

@@ -284,6 +300,7 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
                raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
            else:
                alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
+               memo[name] = alpha
            operation._arch_alpha[name] = alpha

        operation.parameters = functools.partial(self.parameters, module=operation)  # bind self

@@ -321,6 +338,16 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
            result[name] = spec.values[chosen_index]
        return result

+   def export_probs(self, operation: MixedOperation, memo: dict[str, Any]):
+       """Export the weight for every leaf value choice."""
+       ret = {}
+       for name, spec in operation.search_space_spec().items():
+           if any(k.startswith(name + '/') for k in memo):
+               continue
+           weights = operation._softmax(operation._arch_alpha[name]).cpu().tolist()  # type: ignore
+           ret.update({f'{name}/{value}': weight for value, weight in zip(spec.values, weights)})
+       return ret

    def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any:
        if name in operation.mutable_arguments:
            weights: dict[str, torch.Tensor] = {

@@ -360,6 +387,7 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
                raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
            else:
                alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
+               memo[name] = alpha
            self._arch_alpha[name] = alpha

    def resample(self, memo):

@@ -376,6 +404,16 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
            result[name] = spec.values[chosen_index]
        return result

+   def export_probs(self, memo):
+       """Export the weight for every leaf value choice."""
+       ret = {}
+       for name, spec in self.search_space_spec().items():
+           if any(k.startswith(name + '/') for k in memo):
+               continue
+           weights = self._softmax(self._arch_alpha[name]).cpu().tolist()
+           ret.update({f'{name}/{value}': weight for value, weight in zip(spec.values, weights)})
+       return ret

    def search_space_spec(self):
        return self._space_spec

@@ -427,6 +465,8 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
class DifferentiableMixedCell(PathSamplingCell):
    """Implementation of Cell under differentiable context.

+   Similar to PathSamplingCell, this cell only handles cells of specific kinds (e.g., with loose end).
+   An architecture parameter is created on each edge of the full-connected graph.
    """

@@ -450,13 +490,21 @@ class DifferentiableMixedCell(PathSamplingCell):
                op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j]
                if edge_label in memo:
                    alpha = memo[edge_label]
-                   if len(alpha) != len(op):
-                       raise ValueError(f'Architecture parameter size of same label {edge_label} conflict: '
-                                        f'{len(alpha)} vs. {len(op)}')
+                   if len(alpha) != len(op) + 1:
+                       if len(alpha) != len(op):
+                           raise ValueError(f'Architecture parameter size of same label {edge_label} conflict: '
+                                            f'{len(alpha)} vs. {len(op)}')
+                       warnings.warn(f'Architecture parameter size {len(alpha)} is not same as expected: {len(op) + 1}. '
+                                     'This is likely due to the label being shared by a LayerChoice inside the cell and outside.',
+                                     UserWarning)
                else:
-                   alpha = nn.Parameter(torch.randn(len(op)) * 1E-3)
+                   # +1 to emulate the input choice.
+                   alpha = nn.Parameter(torch.randn(len(op) + 1) * 1E-3)
+                   memo[edge_label] = alpha
                self._arch_alpha[edge_label] = alpha

        self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))

@@ -465,18 +513,32 @@ class DifferentiableMixedCell(PathSamplingCell):
        """Differentiable doesn't need to resample."""
        return {}

    def export_probs(self, memo):
        """When export probability, we follow the structure in arch alpha."""
        ret = {}
        for name, parameter in self._arch_alpha.items():
            if any(k.startswith(name + '/') for k in memo):
                continue
            weights = self._softmax(parameter).cpu().tolist()
            ret.update({f'{name}/{value}': weight for value, weight in zip(self.op_names, weights)})
        return ret

    def export(self, memo):
        """Tricky export.

        Reference: https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model_search.py#L135
        We don't avoid selecting operations like ``none`` here, because it looks like a different search space.
        """
        exported = {}
        for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
            # If label already exists, no need to re-export.
            if all(f'{self.label}/op_{i}_{k}' in memo and f'{self.label}/input_{i}_{k}' in memo
                   for k in range(self.num_ops_per_node)):
                continue
            # Tuple of (weight, input_index, op_name)
            all_weights: list[tuple[float, int, str]] = []
            for j in range(i):
                for k, name in enumerate(self.op_names):
                    # The last appended weight is automatically skipped in export.
                    all_weights.append((
                        float(self._arch_alpha[f'{self.label}/{i}_{j}'][k].item()),
                        j, name,

@@ -497,7 +559,7 @@ class DifferentiableMixedCell(PathSamplingCell):
            all_weights = [all_weights[k] for k in first_occurrence_index] + \
                [w for j, w in enumerate(all_weights) if j not in first_occurrence_index]

-           _logger.info('Sorted weights in differentiable cell export (node %d): %s', i, all_weights)
+           _logger.info('Sorted weights in differentiable cell export (%s cell, node %d): %s', self.label, i, all_weights)

            for k in range(self.num_ops_per_node):
                # all_weights could be too short in case ``num_ops_per_node`` is too large.

@@ -515,7 +577,11 @@ class DifferentiableMixedCell(PathSamplingCell):
            for j in range(i):  # for every previous tensors
                op_results = torch.stack([op(states[j]) for op in ops[j].values()])
                alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)  # (-1, 1, 1, 1, 1, ...)
                op_weights = self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}'])
                if len(op_weights) == len(op_results) + 1:
                    # concatenate with a zero operation, indicating this path is not chosen at all.
                    op_results = torch.cat((op_results, torch.zeros_like(op_results[:1])), 0)
                edge_sum = torch.sum(op_results * self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']).view(*alpha_shape), 0)
                current_state.append(edge_sum)
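Aside (added for orientation; not part of the commit): the ``export_probs`` methods added above all produce a flat ``{'<label>/<choice>': probability}`` dictionary by softmaxing the architecture parameter. A standalone sketch of that convention with invented label and choice names:

import torch

def export_probs(label: str, names: list[str], alpha: torch.Tensor) -> dict[str, float]:
    # one '<label>/<choice>' key per candidate, valued by its softmax weight
    weights = torch.softmax(alpha, dim=-1).tolist()
    return {f'{label}/{name}': w for name, w in zip(names, weights)}

print(export_probs('conv_type', ['conv3x3', 'conv5x5'], torch.tensor([0.2, 1.1])))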
nni/nas/oneshot/pytorch/supermodule/operation.py (view file @ f77db747)

@@ -71,6 +71,10 @@ class MixedOperationSamplingPolicy:
        """The handler of :meth:`MixedOperation.export`."""
        raise NotImplementedError()

+   def export_probs(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
+       """The handler of :meth:`MixedOperation.export_probs`."""
+       raise NotImplementedError()

    def forward_argument(self, operation: 'MixedOperation', name: str) -> Any:
        """Computing the argument with ``name`` used in operation's forward.
        Usually a value, or a distribution of value.

@@ -162,6 +166,10 @@ class MixedOperation(BaseSuperNetModule):
        """Delegates to :meth:`MixedOperationSamplingPolicy.resample`."""
        return self.sampling_policy.resample(self, memo)

+   def export_probs(self, memo):
+       """Delegates to :meth:`MixedOperationSamplingPolicy.export_probs`."""
+       return self.sampling_policy.export_probs(self, memo)

    def export(self, memo):
        """Delegates to :meth:`MixedOperationSamplingPolicy.export`."""
        return self.sampling_policy.export(self, memo)
nni/nas/oneshot/pytorch/supermodule/proxyless.py
View file @
f77db747
...
...
@@ -11,7 +11,7 @@ The support remains limited. Known limitations include:
from
__future__
import
annotations
from
typing
import
cast
from
typing
import
Any
,
Tuple
,
Union
,
cast
import
torch
import
torch.nn
as
nn
...
...
@@ -21,28 +21,115 @@ from .differentiable import DifferentiableMixedLayer, DifferentiableMixedInput
__all__ = ['ProxylessMixedLayer', 'ProxylessMixedInput']


-class _ArchGradientFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, binary_gates, run_func, backward_func):
-        ctx.run_func = run_func
-        ctx.backward_func = backward_func
-
-        detached_x = x.detach()
-        detached_x.requires_grad = x.requires_grad
-        with torch.enable_grad():
-            output = run_func(detached_x)
-        ctx.save_for_backward(detached_x, output)
-        return output.data
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        detached_x, output = ctx.saved_tensors
-
-        grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
-        # compute gradients w.r.t. binary_gates
-        binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)
-
-        return grad_x[0], binary_grads, None, None
+def _detach_tensor(tensor: Any) -> Any:
+    """Recursively detach all the tensors."""
+    if isinstance(tensor, (list, tuple)):
+        return tuple(_detach_tensor(t) for t in tensor)
+    elif isinstance(tensor, dict):
+        return {k: _detach_tensor(v) for k, v in tensor.items()}
+    elif isinstance(tensor, torch.Tensor):
+        return tensor.detach()
+    else:
+        return tensor
+
+
+def _iter_tensors(tensor: Any) -> Any:
+    """Recursively iterate over all the tensors.
+
+    This is kept for complex outputs (like dicts / lists).
+    However, complex outputs are not supported by PyTorch backward hooks yet.
+    """
+    if isinstance(tensor, torch.Tensor):
+        yield tensor
+    elif isinstance(tensor, (list, tuple)):
+        for t in tensor:
+            yield from _iter_tensors(t)
+    elif isinstance(tensor, dict):
+        for t in tensor.values():
+            yield from _iter_tensors(t)
+
+
+def _pack_as_tuple(tensor: Any) -> tuple:
+    """Return a tuple of tensor with only one element if tensor it's not a tuple."""
+    if isinstance(tensor, (tuple, list)):
+        return tuple(tensor)
+    return (tensor,)
+
+
+def element_product_sum(tensor1: tuple[torch.Tensor, ...], tensor2: tuple[torch.Tensor, ...]) -> torch.Tensor:
+    """Compute the sum of all the element-wise product."""
+    assert len(tensor1) == len(tensor2), 'The number of tensors must be the same.'
+    # Skip zero gradients
+    ret = [torch.sum(t1 * t2) for t1, t2 in zip(tensor1, tensor2) if t1 is not None and t2 is not None]
+    if not ret:
+        return torch.tensor(0)
+    if len(ret) == 1:
+        return ret[0]
+    return cast(torch.Tensor, sum(ret))
+
+
+class ProxylessContext:
+
+    def __init__(self, arch_alpha: torch.Tensor, softmax: nn.Module) -> None:
+        self.arch_alpha = arch_alpha
+        self.softmax = softmax
+
+        # When a layer is called multiple times, the inputs and outputs are saved in order.
+        # In backward propagation, we assume that they are used in the reversed order.
+        self.layer_input: list[Any] = []
+        self.layer_output: list[Any] = []
+        self.layer_sample_idx: list[int] = []
+
+    def clear_context(self) -> None:
+        self.layer_input = []
+        self.layer_output = []
+        self.layer_sample_idx = []
+
+    def save_forward_context(self, layer_input: Any, layer_output: Any, layer_sample_idx: int):
+        self.layer_input.append(_detach_tensor(layer_input))
+        self.layer_output.append(_detach_tensor(layer_output))
+        self.layer_sample_idx.append(layer_sample_idx)
+
+    def backward_hook(self, module: nn.Module,
+                      grad_input: Union[Tuple[torch.Tensor, ...], torch.Tensor],
+                      grad_output: Union[Tuple[torch.Tensor, ...], torch.Tensor]) -> None:
+        # binary_grads is the gradient of binary gates.
+        # Binary gates is a one-hot tensor where 1 is on the sampled index, and others are 0.
+        # By chain rule, it's gradient is grad_output times the layer_output (of the corresponding path).
+        binary_grads = torch.zeros_like(self.arch_alpha)
+
+        # Retrieve the layer input/output in reverse order.
+        if not self.layer_input:
+            raise ValueError('Unexpected backward call. The saved context is empty.')
+        layer_input = self.layer_input.pop()
+        layer_output = self.layer_output.pop()
+        layer_sample_idx = self.layer_sample_idx.pop()
+
+        with torch.no_grad():
+            # Compute binary grads.
+            for k in range(len(binary_grads)):
+                if k != layer_sample_idx:
+                    args, kwargs = layer_input
+                    out_k = module.forward_path(k, *args, **kwargs)  # type: ignore
+                else:
+                    out_k = layer_output
+                # FIXME: One limitation here is that out_k can't be complex objects like dict.
+                # I think it's also a limitation of backward hook.
+                binary_grads[k] = element_product_sum(
+                    _pack_as_tuple(out_k),  # In case out_k is a single tensor
+                    _pack_as_tuple(grad_output)
+                )
+
+            # Compute the gradient of the arch_alpha, based on binary_grads.
+            if self.arch_alpha.grad is None:
+                self.arch_alpha.grad = torch.zeros_like(self.arch_alpha)
+            probs = self.softmax(self.arch_alpha)
+            for i in range(len(self.arch_alpha)):
+                for j in range(len(self.arch_alpha)):
+                    # Arch alpha's gradients are accumulated for all backwards through this layer.
+                    self.arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])


class ProxylessMixedLayer(DifferentiableMixedLayer):
...
...
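Note (editor illustration, not part of the commit): the nested loop at the end of `backward_hook` is the ProxylessNAS approximation of the architecture gradient. With p = softmax(α) and g the one-hot binary gates, it accumulates

\frac{\partial L}{\partial \alpha_i} \approx \sum_j \frac{\partial L}{\partial g_j}\, p_j\, \bigl(\mathbf{1}[i=j] - p_i\bigr),
\qquad
\frac{\partial L}{\partial g_j} \approx \Bigl\langle \text{output of path } j,\ \frac{\partial L}{\partial \text{output}} \Bigr\rangle,

where the inner product on the right is what `element_product_sum` computes from the replayed `forward_path(k, ...)` outputs and `grad_output`, and the left-hand expression is exactly `binary_grads[j] * probs[j] * (int(i == j) - probs[i])` accumulated over i.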
@@ -50,46 +137,32 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
    It resamples a single-path every time, rather than go through the softmax.
    """

-    _arch_parameter_names = ['_arch_alpha', '_binary_gates']
+    _arch_parameter_names = ['_arch_alpha']

    def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__(paths, alpha, softmax, label)
-        self._binary_gates = nn.Parameter(torch.randn(len(paths)) * 1E-3)
+
+        # Binary gates should be created here, but it's not because it's never used in the forward pass.
+        # self._binary_gates = nn.Parameter(torch.zeros(len(paths)))

        # like sampling-based methods, it has a ``_sampled``.
        self._sampled: str | None = None
        self._sample_idx: int | None = None

+        # arch_alpha could be shared by multiple layers,
+        # but binary_gates is owned by the current layer.
+        self.ctx = ProxylessContext(alpha, softmax)
+        self.register_full_backward_hook(self.ctx.backward_hook)
+
    def forward(self, *args, **kwargs):
-        def run_function(ops, active_id, **kwargs):
-            def forward(_x):
-                return ops[active_id](_x, **kwargs)
-            return forward
-
-        def backward_function(ops, active_id, binary_gates, **kwargs):
-            def backward(_x, _output, grad_output):
-                binary_grads = torch.zeros_like(binary_gates.data)
-                with torch.no_grad():
-                    for k in range(len(ops)):
-                        if k != active_id:
-                            out_k = ops[k](_x.data, **kwargs)
-                        else:
-                            out_k = _output.data
-                        grad_k = torch.sum(out_k * grad_output)
-                        binary_grads[k] = grad_k
-                return binary_grads
-            return backward
-
-        assert len(args) == 1, 'ProxylessMixedLayer only supports exactly one input argument.'
-        x = args[0]
-
-        assert self._sampled is not None, 'Need to call resample() before running fprop.'
-        list_ops = [getattr(self, op) for op in self.op_names]
-
-        return _ArchGradientFunction.apply(
-            x, self._binary_gates,
-            run_function(list_ops, self._sample_idx, **kwargs),
-            backward_function(list_ops, self._sample_idx, self._binary_gates, **kwargs)
-        )
+        """Forward pass of one single path."""
+        if self._sample_idx is None:
+            raise RuntimeError('resample() needs to be called before fprop.')
+        output = self.forward_path(self._sample_idx, *args, **kwargs)
+        self.ctx.save_forward_context((args, kwargs), output, self._sample_idx)
+        return output
+
+    def forward_path(self, index, *args, **kwargs):
+        return getattr(self, self.op_names[index])(*args, **kwargs)

    def resample(self, memo):
        """Sample one path based on alpha if label is not found in memo."""
...
...
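Note (editor illustration, not part of the commit; candidate names and shapes are made up): a minimal usage sketch of the new single-path forward, mirroring what the new test file further below does. `resample` picks one path, the forward pass runs only that path, and the registered full backward hook fills `_arch_alpha.grad`.

import torch
import torch.nn as nn
from nni.nas.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLayer

layer = ProxylessMixedLayer(
    [('conv', nn.Conv2d(3, 3, 3, padding=1)), ('skip', nn.Identity())],   # two paths with matching output shapes
    nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'demo')

layer.resample({})                                   # sample one path and clear the saved context
x = torch.randn(1, 3, 8, 8).requires_grad_()
loss = layer(x).sum()                                # only the sampled path is executed
loss.backward()                                      # the backward hook computes layer._arch_alpha.grad
assert layer._arch_alpha.grad is not None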
@@ -101,66 +174,37 @@
            self._sample_idx = int(torch.multinomial(probs, 1)[0].item())
            self._sampled = self.op_names[self._sample_idx]

-        # set binary gates
-        with torch.no_grad():
-            self._binary_gates.zero_()
-            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
-            self._binary_gates.data[self._sample_idx] = 1.0
+        self.ctx.clear_context()

        return {self.label: self._sampled}

    def export(self, memo):
        """Chose the argmax if label isn't found in memo."""
        if self.label in memo:
            return {}  # nothing new to export
        return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}

-    def finalize_grad(self):
-        binary_grads = self._binary_gates.grad
-        assert binary_grads is not None
-        with torch.no_grad():
-            if self._arch_alpha.grad is None:
-                self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
-            probs = self._softmax(self._arch_alpha)
-            for i in range(len(self._arch_alpha)):
-                for j in range(len(self._arch_alpha)):
-                    self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
-

class ProxylessMixedInput(DifferentiableMixedInput):
    """Proxyless version of differentiable input choice.

-    See :class:`ProxylessLayerChoice` for implementation details.
+    See :class:`ProxylessMixedLayer` for implementation details.
    """

-    _arch_parameter_names = ['_arch_alpha', '_binary_gates']

    def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__(n_candidates, n_chosen, alpha, softmax, label)
-        self._binary_gates = nn.Parameter(torch.randn(n_candidates) * 1E-3)

        # We only support choosing a particular one here.
        # Nevertheless, we rank the score and export the tops in export.
        self._sampled: int | None = None

+        self.ctx = ProxylessContext(alpha, softmax)
+        self.register_full_backward_hook(self.ctx.backward_hook)
+
    def forward(self, inputs):
-        def run_function(active_sample):
-            return lambda x: x[active_sample]
-
-        def backward_function(binary_gates):
-            def backward(_x, _output, grad_output):
-                binary_grads = torch.zeros_like(binary_gates.data)
-                with torch.no_grad():
-                    for k in range(self.n_candidates):
-                        out_k = _x[k].data
-                        grad_k = torch.sum(out_k * grad_output)
-                        binary_grads[k] = grad_k
-                return binary_grads
-            return backward
-
-        inputs = torch.stack(inputs, 0)
-        assert self._sampled is not None, 'Need to call resample() before running fprop.'
-
-        return _ArchGradientFunction.apply(
-            inputs, self._binary_gates,
-            run_function(self._sampled),
-            backward_function(self._binary_gates)
-        )
+        """Choose one single input."""
+        if self._sampled is None:
+            raise RuntimeError('resample() needs to be called before fprop.')
+        output = self.forward_path(self._sampled, inputs)
+        self.ctx.save_forward_context(((inputs,), {}), output, self._sampled)
+        return output
+
+    def forward_path(self, index, inputs):
+        return inputs[index]

    def resample(self, memo):
        """Sample one path based on alpha if label is not found in memo."""
...
...
@@ -171,27 +215,6 @@ class ProxylessMixedInput(DifferentiableMixedInput):
            sample = torch.multinomial(probs, 1)[0].item()
            self._sampled = int(sample)

-        # set binary gates
-        with torch.no_grad():
-            self._binary_gates.zero_()
-            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
-            self._binary_gates.data[cast(int, self._sampled)] = 1.0
+        self.ctx.clear_context()

        return {self.label: self._sampled}

-    def export(self, memo):
-        """Chose the argmax if label isn't found in memo."""
-        if self.label in memo:
-            return {}  # nothing new to export
-        return {self.label: torch.argmax(self._arch_alpha).item()}
-
-    def finalize_grad(self):
-        binary_grads = self._binary_gates.grad
-        assert binary_grads is not None
-        with torch.no_grad():
-            if self._arch_alpha.grad is None:
-                self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
-            probs = self._softmax(self._arch_alpha)
-            for i in range(self.n_candidates):
-                for j in range(self.n_candidates):
-                    self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])

nni/nas/oneshot/pytorch/supermodule/sampling.py  (View file @ f77db747)
...
...
@@ -169,7 +169,7 @@ class PathSamplingInput(BaseSuperNetModule):

class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
-    """Implementes the path sampling in mixed operation.
+    """Implements the path sampling in mixed operation.

    One mixed operation can have multiple value choices in its arguments.
    Each value choice can be further decomposed into "leaf value choices".
...
...
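Note on "leaf value choices" (editor illustration, not part of the commit; labels and numbers are made up): one operation argument may be an arithmetic composition of several value choices, and the sampling policy samples each leaf independently before re-evaluating the composed expression.

from nni.retiarii.nn.pytorch import Conv2d, ValueChoice

base = ValueChoice([8, 16], label='base_width')
extra = ValueChoice([0, 8], label='extra_width')
conv = Conv2d(3, base + extra, kernel_size=3)   # out_channels depends on two leaf value choices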
@@ -388,6 +388,10 @@ class PathSamplingCell(BaseSuperNetModule):

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
+        """
+        Mutate only handles cells of specific configurations (e.g., with loose end).
+        Fallback to the default mutate if the cell is not handled here.
+        """
        if isinstance(module, Cell):
            op_factory = None  # not all the cells need to be replaced
            if module.op_candidates_factory is not None:
...
...

test/algo/nas/test_oneshot.py  (View file @ f77db747)
...
...
@@ -5,6 +5,7 @@ import pytorch_lightning as pl
import pytest
from torchvision import transforms
from torchvision.datasets import MNIST
from torch import nn
from torch.utils.data import Dataset, RandomSampler

import nni
...
...
@@ -13,7 +14,11 @@ from nni.retiarii import strategy, model_wrapper, basic_unit
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.evaluator.pytorch.lightning import Classification, Regression, DataLoader
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice
+from nni.retiarii.oneshot.pytorch import DartsLightningModule
from nni.retiarii.strategy import BaseStrategy
+from pytorch_lightning import LightningModule, Trainer
+
+from .test_oneshot_utils import RandomDataset

pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
...
...
@@ -338,17 +343,49 @@
    _test_strategy(strategy.GumbelDARTS())


-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--exp', type=str, default='all', metavar='E',
-                        help='experiment to run, default = all')
-    args = parser.parse_args()
-
-    if args.exp == 'all':
-        test_darts()
-        test_proxyless()
-        test_enas()
-        test_random()
-        test_gumbel_darts()
-    else:
-        globals()[f'test_{args.exp}']()
+def test_optimizer_lr_scheduler():
+    learning_rates = []
+
+    class CustomLightningModule(LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.layer1 = nn.Linear(32, 2)
+            self.layer2 = nn.LayerChoice([nn.Linear(2, 2), nn.Linear(2, 2, bias=False)])
+
+        def forward(self, x):
+            return self.layer2(self.layer1(x))
+
+        def configure_optimizers(self):
+            opt1 = torch.optim.SGD(self.layer1.parameters(), lr=0.1)
+            opt2 = torch.optim.Adam(self.layer2.parameters(), lr=0.2)
+            return [opt1, opt2], [torch.optim.lr_scheduler.StepLR(opt1, step_size=2, gamma=0.1)]
+
+        def training_step(self, batch, batch_idx):
+            loss = self(batch).sum()
+            self.log('train_loss', loss)
+            return {'loss': loss}
+
+        def on_train_epoch_start(self) -> None:
+            learning_rates.append(self.optimizers()[0].param_groups[0]['lr'])
+
+        def validation_step(self, batch, batch_idx):
+            loss = self(batch).sum()
+            self.log('valid_loss', loss)
+
+        def test_step(self, batch, batch_idx):
+            loss = self(batch).sum()
+            self.log('test_loss', loss)
+
+    train_data = RandomDataset(32, 32)
+    valid_data = RandomDataset(32, 16)
+
+    model = CustomLightningModule()
+    darts_module = DartsLightningModule(model, gradient_clip_val=5)
+    trainer = Trainer(max_epochs=10)
+    trainer.fit(
+        darts_module,
+        dict(train=DataLoader(train_data, batch_size=8), val=DataLoader(valid_data, batch_size=8))
+    )
+
+    assert len(learning_rates) == 10 and abs(learning_rates[0] - 0.1) < 1e-5 and \
+        abs(learning_rates[2] - 0.01) < 1e-5 and abs(learning_rates[-1] - 1e-5) < 1e-6
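Note (editor illustration, not part of the commit): the final assertion can be read directly off the StepLR schedule. With an initial lr of 0.1, step_size=2 and gamma=0.1, the value recorded at the start of each of the 10 epochs is expected to be

# epoch:  0     1     2     3     4     5     6     7     8     9
# lr:     0.1   0.1   0.01  0.01  1e-3  1e-3  1e-4  1e-4  1e-5  1e-5

hence learning_rates[0] == 0.1, learning_rates[2] == 0.01 and learning_rates[-1] == 1e-5.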
test/algo/nas/test_oneshot_proxyless.py  (new file, 0 → 100644)  (View file @ f77db747)
import torch
import torch.nn as nn

from nni.nas.hub.pytorch.nasbench201 import OPS_WITH_STRIDE
from nni.nas.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLayer, ProxylessMixedInput, _iter_tensors


def test_proxyless_bp():
    op = ProxylessMixedLayer(
        [(name, value(3, 3, 1)) for name, value in OPS_WITH_STRIDE.items()],
        nn.Parameter(torch.randn(len(OPS_WITH_STRIDE))),
        nn.Softmax(-1), 'proxyless'
    )

    optimizer = torch.optim.SGD(op.parameters(arch=True), 0.1)

    for _ in range(10):
        x = torch.randn(1, 3, 9, 9).requires_grad_()
        op.resample({})
        y = op(x).sum()
        optimizer.zero_grad()
        y.backward()
        assert op._arch_alpha.grad.abs().sum().item() != 0


def test_proxyless_input():
    inp = ProxylessMixedInput(6, 2, nn.Parameter(torch.zeros(6)), nn.Softmax(-1), 'proxyless')

    optimizer = torch.optim.SGD(inp.parameters(arch=True), 0.1)

    for _ in range(10):
        x = [torch.randn(1, 3, 9, 9).requires_grad_() for _ in range(6)]
        inp.resample({})
        y = inp(x).sum()
        optimizer.zero_grad()
        y.backward()


def test_iter_tensors():
    a = (torch.zeros(3, 1), {'a': torch.zeros(5, 1), 'b': torch.zeros(6, 1)}, [torch.zeros(7, 1)])
    ret = []
    for x in _iter_tensors(a):
        ret.append(x.shape[0])
    assert ret == [3, 5, 6, 7]


class MultiInputLayer(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.d = d

    def forward(self, q, k, v=None, mask=None):
        return q + self.d, 2 * k - 2 * self.d, v, mask


def test_proxyless_multi_input():
    op = ProxylessMixedLayer(
        [
            ('a', MultiInputLayer(1)),
            ('b', MultiInputLayer(3))
        ],
        nn.Parameter(torch.randn(2)),
        nn.Softmax(-1), 'proxyless'
    )

    optimizer = torch.optim.SGD(op.parameters(arch=True), 0.1)

    for retry in range(10):
        q = torch.randn(1, 3, 9, 9).requires_grad_()
        k = torch.randn(1, 3, 9, 8).requires_grad_()
        v = None if retry < 5 else torch.randn(1, 3, 9, 7).requires_grad_()
        mask = None if retry % 5 < 2 else torch.randn(1, 3, 9, 6).requires_grad_()

        op.resample({})
        y = op(q, k, v, mask=mask)
        y = y[0].sum() + y[1].sum()
        optimizer.zero_grad()
        y.backward()
        assert op._arch_alpha.grad.abs().sum().item() != 0, op._arch_alpha.grad

test/algo/nas/test_oneshot_supermodules.py  (View file @ f77db747)
...
...
@@ -3,7 +3,7 @@ import pytest
import numpy as np
import torch
import torch.nn as nn
-from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, LayerNorm, Linear, MultiheadAttention
+from nni.retiarii.nn.pytorch import ValueChoice, LayerChoice, Conv2d, BatchNorm2d, LayerNorm, Linear, MultiheadAttention
from nni.retiarii.oneshot.pytorch.base_lightning import traverse_and_mutate_submodules
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import (
    MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax,
...
...
@@ -144,6 +144,16 @@ def test_differentiable_valuechoice():
    assert set(conv.export({}).keys()) == {'123', '456'}


+def test_differentiable_layerchoice_dedup():
+    layerchoice1 = LayerChoice([Conv2d(3, 3, 3), Conv2d(3, 3, 3)], label='a')
+    layerchoice2 = LayerChoice([Conv2d(3, 3, 3), Conv2d(3, 3, 3)], label='a')
+
+    memo = {}
+    DifferentiableMixedLayer.mutate(layerchoice1, 'x', memo, {})
+    DifferentiableMixedLayer.mutate(layerchoice2, 'x', memo, {})
+    assert len(memo) == 1 and 'a' in memo
+
+
def _mixed_operation_sampling_sanity_check(operation, memo, *input):
    for native_op in NATIVE_MIXED_OPERATIONS:
        if native_op.bound_type == type(operation):
...
...
@@ -160,7 +170,9 @@ def _mixed_operation_differentiable_sanity_check(operation, *input):
            mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
            break

-    return mutate_op(*input)
+    mutate_op(*input)
+    mutate_op.export({})
+    mutate_op.export_probs({})


def test_mixed_linear():
...
...
@@ -319,6 +331,9 @@ def test_differentiable_layer_input():
    op = DifferentiableMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))],
                                  nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
    assert op(torch.randn(4, 2)).size(-1) == 3
    assert op.export({})['eee'] in ['a', 'b']
+    probs = op.export_probs({})
+    assert len(probs) == 2
+    assert abs(probs['eee/a'] + probs['eee/b'] - 1) < 1e-4
    assert len(list(op.parameters())) == 3

    with pytest.raises(ValueError):
...
...
@@ -328,6 +343,8 @@ def test_differentiable_layer_input():

    input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
    assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2
    assert len(input.export({})['ddd']) == 2
+    assert len(input.export_probs({})) == 5
+    assert 'ddd/3' in input.export_probs({})


def test_proxyless_layer_input():
...
...
@@ -341,7 +358,8 @@ def test_proxyless_layer_input():

    input = ProxylessMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
    assert input.resample({})['ddd'] in list(range(5))
    assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2])
-    assert input.export({})['ddd'] in list(range(5))
+    exported = input.export({})['ddd']
+    assert len(exported) == 2 and all(e in list(range(5)) for e in exported)


def test_pathsampling_repeat():
...
...
@@ -373,6 +391,7 @@ def test_differentiable_repeat():
    assert op(torch.randn(2, 8)).size() == torch.Size([2, 16])
    sample = op.export({})
    assert 'ccc' in sample and sample['ccc'] in [0, 1]
+    assert sorted(op.export_probs({}).keys()) == ['ccc/0', 'ccc/1']


class TupleModule(nn.Module):
    def __init__(self, num):
...
...
@@ -452,11 +471,16 @@ def test_differentiable_cell():
        result.update(module.export(memo=result))
    assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2

+    result_prob = {}
+    for module in nas_modules:
+        result_prob.update(module.export_probs(memo=result_prob))
+
    ctrl_params = []
    for m in nas_modules:
        ctrl_params += list(m.parameters(arch=True))
    if cell_cls in [CellLooseEnd, CellOpFactory]:
        assert len(ctrl_params) == model.cell.num_nodes * (model.cell.num_nodes + 3) // 2
+        assert len(result_prob) == len(ctrl_params) * 2  # len(op_names) == 2
        assert isinstance(model.cell, DifferentiableMixedCell)
    else:
        assert not isinstance(model.cell, DifferentiableMixedCell)
...
...
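Note (editor illustration, not part of the commit): the `num_nodes * (num_nodes + 3) // 2` count follows if each of the n nodes draws one parameterised edge from every earlier node plus the two cell inputs (the two-input assumption is inferred from the DARTS-style cells used in these tests, not stated here):

\#\text{edges} = \sum_{i=0}^{n-1} (i + 2) = \frac{n(n-1)}{2} + 2n = \frac{n(n+3)}{2},

and each edge contributes one architecture parameter to `ctrl_params`, while `export_probs` reports two entries per edge because there are two candidate ops (hence the `* 2` in the assertion).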