OpenDAS / nni — commit f77db747 (unverified)

Enhancement of one-shot NAS (v2.9) (#5049)

Authored by Yuge Zhang, Aug 15, 2022; committed via GitHub, Aug 15, 2022.
Parent commit: 125ec21f

Showing 15 changed files with 730 additions and 399 deletions.
nni/nas/hub/pytorch/modules/nasbench201.py                 +1    -1
nni/nas/hub/pytorch/nasbench201.py                         +4    -1
nni/nas/oneshot/pytorch/base_lightning.py                  +213  -203
nni/nas/oneshot/pytorch/differentiable.py                  +46   -27
nni/nas/oneshot/pytorch/enas.py                            +8    -3
nni/nas/oneshot/pytorch/sampling.py                        +62   -18
nni/nas/oneshot/pytorch/supermodule/_valuechoice_utils.py  +2    -2
nni/nas/oneshot/pytorch/supermodule/base.py                +11   -0
nni/nas/oneshot/pytorch/supermodule/differentiable.py      +76   -10
nni/nas/oneshot/pytorch/supermodule/operation.py           +8    -0
nni/nas/oneshot/pytorch/supermodule/proxyless.py           +140  -117
nni/nas/oneshot/pytorch/supermodule/sampling.py            +5    -1
test/algo/nas/test_oneshot.py                              +50   -13
test/algo/nas/test_oneshot_proxyless.py (new file)         +77   -0
test/algo/nas/test_oneshot_supermodules.py                 +27   -3
nni/nas/hub/pytorch/modules/nasbench201.py

```diff
@@ -70,7 +70,7 @@ class NasBench201Cell(nn.Module):
                 inp = in_features if j == 0 else out_features
                 op_choices = OrderedDict([(key, cls(inp, out_features)) for key, cls in op_candidates.items()])
-                node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))  # put __ here to be compatible with base engine
+                node_ops.append(LayerChoice(op_choices, label=f'{self._label}/{j}_{tid}'))
             self.layers.append(node_ops)

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
```
nni/nas/hub/pytorch/nasbench201.py

```diff
@@ -179,7 +179,7 @@ class NasBench201(nn.Module):
                 cell = ResNetBasicblock(C_prev, C_curr, 2)
             else:
                 ops: Dict[str, Callable[[int, int], nn.Module]] = {
-                    prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES
+                    prim: self._make_op_factory(prim) for prim in PRIMITIVES
                 }
                 cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
             self.cells.append(cell)
@@ -192,6 +192,9 @@ class NasBench201(nn.Module):
         self.global_pooling = nn.AdaptiveAvgPool2d(1)
         self.classifier = nn.Linear(C_prev, self.num_labels)

+    def _make_op_factory(self, prim):
+        return lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1)
+
     def forward(self, inputs):
         feature = self.stem(inputs)
         for cell in self.cells:
```
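The `_make_op_factory` refactor works around Python's late-binding closures: a lambda created inside a comprehension closes over the loop variable `prim`, not its value, so every factory in the old dict would build the last primitive. Routing the lambda through a helper function binds the value per call. A minimal sketch of the pitfall (names here are illustrative, not from the repository):

```python
# Late binding: both lambdas see prim == 'pool' once the comprehension finishes.
ops = {prim: (lambda: prim) for prim in ['conv', 'pool']}
assert ops['conv']() == 'pool'   # surprising, but correct Python

# A factory function freezes the current value in its own scope.
def make(prim):
    return lambda: prim

ops = {prim: make(prim) for prim in ['conv', 'pool']}
assert ops['conv']() == 'conv'
```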
nni/nas/oneshot/pytorch/base_lightning.py (+213, -203)

This diff is collapsed in the commit view and not reproduced here.
nni/nas/oneshot/pytorch/differentiable.py

```diff
@@ -9,7 +9,7 @@ import pytorch_lightning as pl
 import torch
 import torch.optim as optim

-from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
+from .base_lightning import BaseOneShotLightningModule, MANUAL_OPTIMIZATION_NOTE, MutationHook, no_default_hook
 from .supermodule.differentiable import (
     DifferentiableMixedLayer,
     DifferentiableMixedInput,
     MixedOpDifferentiablePolicy,
     GumbelSoftmax,
@@ -28,6 +28,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
     DARTS repeats iterations, where each iteration consists of 2 training phases.
     The phase 1 is architecture step, in which model parameters are frozen and the architecture parameters are trained.
     The phase 2 is model step, in which architecture parameters are frozen and model parameters are trained.
+    In both phases, ``training_step`` of the Lightning evaluator will be used.

     The current implementation corresponds to DARTS (1st order) in paper.
     Second order (unrolled 2nd-order derivatives) is not supported yet.
@@ -49,15 +50,20 @@ class DartsLightningModule(BaseOneShotLightningModule):
     {{module_notes}}

+    {optimization_note}
+
     Parameters
     ----------
     {{module_params}}
     {base_params}
     arc_learning_rate : float
         Learning rate for architecture optimizer. Default: 3.0e-4
+    gradient_clip_val : float
+        Clip gradients before optimizing models at each step. Default: None
     """.format(
         base_params=BaseOneShotLightningModule._mutation_hooks_note,
-        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
+        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES),
+        optimization_note=MANUAL_OPTIMIZATION_NOTE
     )

     __doc__ = _darts_note.format(
@@ -85,8 +91,10 @@ class DartsLightningModule(BaseOneShotLightningModule):
     def __init__(self, inner_module: pl.LightningModule,
                  mutation_hooks: list[MutationHook] | None = None,
-                 arc_learning_rate: float = 3.0E-4):
+                 arc_learning_rate: float = 3.0E-4,
+                 gradient_clip_val: float | None = None):
         self.arc_learning_rate = arc_learning_rate
+        self.gradient_clip_val = gradient_clip_val
         super().__init__(inner_module, mutation_hooks=mutation_hooks)

     def training_step(self, batch, batch_idx):
@@ -108,33 +116,32 @@ class DartsLightningModule(BaseOneShotLightningModule):
         if isinstance(arc_step_loss, dict):
             arc_step_loss = arc_step_loss['loss']
         self.manual_backward(arc_step_loss)
-        self.finalize_grad()
         arc_optim.step()

         # phase 2: model step
         self.resample()
-        self.call_weight_optimizers('zero_grad')
         loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
-        w_step_loss = loss_and_metrics['loss'] \
-            if isinstance(loss_and_metrics, dict) else loss_and_metrics
-        self.manual_backward(w_step_loss)
-        self.call_weight_optimizers('step')
+        w_step_loss = loss_and_metrics['loss'] if isinstance(loss_and_metrics, dict) else loss_and_metrics
+        self.advance_optimization(w_step_loss, batch_idx, self.gradient_clip_val)

-        self.call_lr_schedulers(batch_idx)
+        # Update learning rates
+        self.advance_lr_schedulers(batch_idx)

-        return loss_and_metrics
+        self.log_dict({'prob/' + k: v for k, v in self.export_probs().items()})

-    def finalize_grad(self):
-        # Note: This hook is currently kept for Proxyless NAS.
-        pass
+        return loss_and_metrics

     def configure_architecture_optimizers(self):
         # The alpha in DartsXXXChoices are the architecture parameters of DARTS. They share one optimizer.
         ctrl_params = []
         for m in self.nas_modules:
             ctrl_params += list(m.parameters(arch=True))  # type: ignore
-        ctrl_optim = torch.optim.Adam(list(set(ctrl_params)), 3.e-4, betas=(0.5, 0.999), weight_decay=1.0E-3)
+        # Follow the hyper-parameters used in
+        # https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/architect.py#L17
+        params = list(set(ctrl_params))
+        if not params:
+            raise ValueError('No architecture parameters found. Nothing to search.')
+        ctrl_optim = torch.optim.Adam(params, 3.e-4, betas=(0.5, 0.999), weight_decay=1.0E-3)
         return ctrl_optim
```
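The hunk above moves the two-phase DARTS step onto the new `advance_optimization`/`advance_lr_schedulers` helpers from `base_lightning`, replacing the hand-rolled `call_weight_optimizers`/`finalize_grad` sequence. A self-contained sketch of the control flow this implements, with tiny hypothetical stand-ins (not the commit's actual API) for the evaluator loss and the two optimizers:

```python
import torch

# Hypothetical stand-ins: one weight parameter, one architecture parameter.
w = torch.nn.Parameter(torch.randn(3))
alpha = torch.nn.Parameter(torch.randn(3))
weight_optim = torch.optim.SGD([w], lr=0.1)
arc_optim = torch.optim.Adam([alpha], lr=3e-4, betas=(0.5, 0.999), weight_decay=1e-3)

def loss(batch):
    return ((batch * torch.softmax(alpha, -1) * w) ** 2).sum()

trn_batch, val_batch = torch.randn(3), torch.randn(3)

# Phase 1: architecture step on the validation split.
arc_optim.zero_grad()
loss(val_batch).backward()
arc_optim.step()

# Phase 2: model (weight) step on the training split.
weight_optim.zero_grad()
loss(trn_batch).backward()   # gradient clipping would go here
weight_optim.step()
```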
nni/nas/oneshot/pytorch/differentiable.py (continued)

```diff
@@ -153,13 +160,20 @@ class ProxylessLightningModule(DartsLightningModule):
     {{module_notes}}

+    {optimization_note}
+
     Parameters
     ----------
     {{module_params}}
     {base_params}
     arc_learning_rate : float
         Learning rate for architecture optimizer. Default: 3.0e-4
-    """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
+    gradient_clip_val : float
+        Clip gradients before optimizing models at each step. Default: None
+    """.format(
+        base_params=BaseOneShotLightningModule._mutation_hooks_note,
+        optimization_note=MANUAL_OPTIMIZATION_NOTE
+    )

     __doc__ = _proxyless_note.format(
         module_notes='This module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
@@ -176,10 +190,6 @@ class ProxylessLightningModule(DartsLightningModule):
         # FIXME: no support for mixed operation currently
         return hooks

-    def finalize_grad(self):
-        for m in self.nas_modules:
-            m.finalize_grad()  # type: ignore
-

 class GumbelDartsLightningModule(DartsLightningModule):
     _gumbel_darts_note = """
@@ -207,6 +217,8 @@ class GumbelDartsLightningModule(DartsLightningModule):
     {{module_notes}}

+    {optimization_note}
+
     Parameters
     ----------
     {{module_params}}
@@ -216,13 +228,17 @@ class GumbelDartsLightningModule(DartsLightningModule):
     use_temp_anneal : bool
         If true, a linear annealing will be applied to ``gumbel_temperature``.
         Otherwise, run at a fixed temperature. See `SNAS <https://arxiv.org/abs/1812.09926>`__ for details.
         Default is false.
     min_temp : float
         The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` False.
     arc_learning_rate : float
         Learning rate for architecture optimizer. Default: 3.0e-4
+    gradient_clip_val : float
+        Clip gradients before optimizing models at each step. Default: None
     """.format(
         base_params=BaseOneShotLightningModule._mutation_hooks_note,
-        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
+        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES),
+        optimization_note=MANUAL_OPTIMIZATION_NOTE
     )

     def mutate_kwargs(self):
@@ -235,22 +251,25 @@ class GumbelDartsLightningModule(DartsLightningModule):
     def __init__(self, inner_module,
                  mutation_hooks: list[MutationHook] | None = None,
                  arc_learning_rate: float = 3.0e-4,
+                 gradient_clip_val: float | None = None,
                  gumbel_temperature: float = 1.,
                  use_temp_anneal: bool = False,
                  min_temp: float = .33):
-        super().__init__(inner_module, mutation_hooks, arc_learning_rate=arc_learning_rate)
+        super().__init__(inner_module, mutation_hooks, arc_learning_rate=arc_learning_rate, gradient_clip_val=gradient_clip_val)
         self.temp = gumbel_temperature
         self.init_temp = gumbel_temperature
         self.use_temp_anneal = use_temp_anneal
         self.min_temp = min_temp

-    def on_train_epoch_end(self):
+    def on_train_epoch_start(self):
         if self.use_temp_anneal:
             self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
             self.temp = max(self.temp, self.min_temp)
+        self.log('gumbel_temperature', self.temp)
         for module in self.nas_modules:
-            if hasattr(module, '_softmax'):
-                module._softmax.temp = self.temp  # type: ignore
+            if hasattr(module, '_softmax') and isinstance(module._softmax, GumbelSoftmax):
+                module._softmax.tau = self.temp  # type: ignore
-        return self.model.on_train_epoch_end()
+        return self.model.on_train_epoch_start()
```
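The renamed hook also corrects the attribute it sets: the temperature attribute is `tau` (matching `torch.nn.functional.gumbel_softmax`), not `temp`, so the old code silently never annealed the softmax. The linear schedule itself is easy to check numerically; a small sketch with assumed values (`init_temp=1.0`, `min_temp=0.33`, 10 epochs):

```python
# Linear Gumbel temperature annealing, as in on_train_epoch_start above.
init_temp, min_temp, max_epochs = 1.0, 0.33, 10
schedule = [
    max((1 - epoch / max_epochs) * (init_temp - min_temp) + min_temp, min_temp)
    for epoch in range(max_epochs)
]
# Epoch 0 starts at init_temp; the temperature decays linearly toward min_temp.
assert abs(schedule[0] - 1.0) < 1e-9 and schedule[-1] > min_temp
```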
nni/nas/oneshot/pytorch/enas.py

```diff
@@ -94,11 +94,11 @@ class ReinforceController(nn.Module):
             field.name: nn.Embedding(field.total, self.lstm_size)
             for field in fields
         })

-    def resample(self):
+    def resample(self, return_prob=False):
         self._initialize()
         result = dict()
         for field in self.fields:
-            result[field.name] = self._sample_single(field)
+            result[field.name] = self._sample_single(field, return_prob=return_prob)
         return result

     def _initialize(self):
@@ -116,7 +116,7 @@ class ReinforceController(nn.Module):
     def _lstm_next_step(self):
         self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

-    def _sample_single(self, field):
+    def _sample_single(self, field, return_prob):
         self._lstm_next_step()
         logit = self.soft[field.name](self._h[-1])
         if self.temperature is not None:
@@ -124,10 +124,12 @@ class ReinforceController(nn.Module):
         if self.tanh_constant is not None:
             logit = self.tanh_constant * torch.tanh(logit)
         if field.choose_one:
+            sampled_dist = F.softmax(logit, dim=-1)
             sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
             log_prob = self.cross_entropy_loss(logit, sampled)
             self._inputs = self.embedding[field.name](sampled)
         else:
+            sampled_dist = torch.sigmoid(logit)
             logit = logit.view(-1, 1)
             logit = torch.cat([-logit, logit], 1)  # pylint: disable=invalid-unary-operand-type
             sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
@@ -147,4 +149,7 @@ class ReinforceController(nn.Module):
         self.sample_entropy += self.entropy_reduction(entropy)
         if len(sampled) == 1:
             sampled = sampled[0]
+        if return_prob:
+            return sampled_dist.flatten().detach().cpu().numpy().tolist()
         return sampled
```
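With `return_prob=True` the controller returns the categorical distribution it sampled from instead of the sample itself, which is what the new `export_probs` logging consumes. A standalone sketch of the same idea (the values are made up):

```python
import torch
import torch.nn.functional as F

logit = torch.randn(4)
probs = F.softmax(logit, dim=-1)

sampled = torch.multinomial(probs, 1).view(-1)               # what resample() returns
as_probs = probs.flatten().detach().cpu().numpy().tolist()   # what return_prob=True returns

assert abs(sum(as_probs) - 1.0) < 1e-5 and 0 <= int(sampled) < 4
```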
nni/nas/oneshot/pytorch/sampling.py

```diff
@@ -5,14 +5,14 @@
 from __future__ import annotations

 import warnings
-from typing import Any
+from typing import Any, cast

 import pytorch_lightning as pl
 import torch
 import torch.nn as nn
 import torch.optim as optim

-from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
+from .base_lightning import MANUAL_OPTIMIZATION_NOTE, BaseOneShotLightningModule, MutationHook, no_default_hook
 from .supermodule.operation import NATIVE_MIXED_OPERATIONS, NATIVE_SUPPORTED_OP_NAMES
 from .supermodule.sampling import (
     PathSamplingInput,
     PathSamplingLayer,
     MixedOpPathSamplingPolicy,
@@ -37,6 +37,9 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
     * :class:`nni.retiarii.nn.pytorch.Cell`.
     * :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.

+    This strategy assumes inner evaluator has set
+    `automatic optimization <https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html>`__ to true.
+
     Parameters
     ----------
     {{module_params}}
@@ -73,9 +76,9 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
             'mixed_op_sampling': MixedOpPathSamplingPolicy
         }

-    def training_step(self, batch, batch_idx):
+    def training_step(self, *args, **kwargs):
         self.resample()
-        return self.model.training_step(batch, batch_idx)
+        return self.model.training_step(*args, **kwargs)

     def export(self) -> dict[str, Any]:
         """
@@ -115,6 +118,8 @@ class EnasLightningModule(RandomSamplingLightningModule):
     {{module_notes}}

+    {optimization_note}
+
     Parameters
     ----------
     {{module_params}}
@@ -133,6 +138,8 @@ class EnasLightningModule(RandomSamplingLightningModule):
         before updating the weights of RL controller.
     ctrl_grad_clip : float
         Gradient clipping value of controller.
+    log_prob_every_n_step : int
+        Log the probability of choices every N steps. Useful for visualization and debugging.
     reward_metric_name : str or None
         The name of the metric which is treated as reward.
         This will be not effective when there's only one metric returned from evaluator.
@@ -141,11 +148,12 @@ class EnasLightningModule(RandomSamplingLightningModule):
         Otherwise it raises an exception indicating multiple metrics are found.
     """.format(
         base_params=BaseOneShotLightningModule._mutation_hooks_note,
-        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
+        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES),
+        optimization_note=MANUAL_OPTIMIZATION_NOTE
     )

     __doc__ = _enas_note.format(
-        module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`.',
+        module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.pytorch.dataloader.ConcatLoader`.',
         module_params=BaseOneShotLightningModule._inner_module_note,
     )
@@ -162,6 +170,7 @@ class EnasLightningModule(RandomSamplingLightningModule):
                  baseline_decay: float = .999,
                  ctrl_steps_aggregate: float = 20,
                  ctrl_grad_clip: float = 0,
+                 log_prob_every_n_step: int = 10,
                  reward_metric_name: str | None = None,
                  mutation_hooks: list[MutationHook] | None = None):
         super().__init__(inner_module, mutation_hooks)
@@ -181,33 +190,29 @@ class EnasLightningModule(RandomSamplingLightningModule):
         self.baseline = 0.
         self.ctrl_steps_aggregate = ctrl_steps_aggregate
         self.ctrl_grad_clip = ctrl_grad_clip
+        self.log_prob_every_n_step = log_prob_every_n_step
         self.reward_metric_name = reward_metric_name

     def configure_architecture_optimizers(self):
         return optim.Adam(self.controller.parameters(), lr=3.5e-4)

     def training_step(self, batch_packed, batch_idx):
         # The received batch is a tuple of (data, "train" | "val")
         batch, mode = batch_packed

         if mode == 'train':
             # train model params
             with torch.no_grad():
                 self.resample()
-            self.call_weight_optimizers('zero_grad')
             step_output = self.model.training_step(batch, batch_idx)
-            w_step_loss = step_output['loss'] \
-                if isinstance(step_output, dict) else step_output
-            self.manual_backward(w_step_loss)
-            self.call_weight_optimizers('step')
+            w_step_loss = step_output['loss'] if isinstance(step_output, dict) else step_output
+            self.advance_optimization(w_step_loss, batch_idx)
         else:
             # train ENAS agent
-            arc_opt = self.architecture_optimizers()
-            if not isinstance(arc_opt, optim.Optimizer):
-                raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
-            arc_opt.zero_grad()
-            self.resample()
+            # Run a sample to retrieve the reward
+            self.resample()
             step_output = self.model.validation_step(batch, batch_idx)
             # use the default metric of self.model as reward function
@@ -218,11 +223,13 @@ class EnasLightningModule(RandomSamplingLightningModule):
         if metric_name not in self.trainer.callback_metrics:
             raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
                            f'found multiple (or zero) metrics without default: {list(self.trainer.callback_metrics.keys())}. '
-                           f'Try to use self.log to report metrics with the specified key ``{metric_name}`` in validation_step, '
-                           'and remember to set on_step=True.')
+                           f'Please try to set ``reward_metric_name`` to be one of the keys listed above. '
+                           f'If it is not working, use self.log to report metrics with the specified key ``{metric_name}`` '
+                           'in validation_step, and remember to set on_step=True.')
         metric = self.trainer.callback_metrics[metric_name]
         reward: float = metric.item()

+        # Compute the loss and run back propagation
         if self.entropy_weight:
             reward = reward + self.entropy_weight * self.controller.sample_entropy.item()  # type: ignore
         self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
@@ -236,11 +243,29 @@ class EnasLightningModule(RandomSamplingLightningModule):
         if (batch_idx + 1) % self.ctrl_steps_aggregate == 0:
             if self.ctrl_grad_clip > 0:
                 nn.utils.clip_grad_norm_(self.controller.parameters(), self.ctrl_grad_clip)
+            # Update the controller and zero out its gradients
+            arc_opt = cast(optim.Optimizer, self.architecture_optimizers())
+            arc_opt.step()
+            arc_opt.zero_grad()

+        self.advance_lr_schedulers(batch_idx)
+
+        if (batch_idx + 1) % self.log_prob_every_n_step == 0:
+            with torch.no_grad():
+                self.log_dict({'prob/' + k: v for k, v in self.export_probs().items()})

         return step_output

+    def on_train_epoch_start(self):
+        # Always zero out the gradients of ENAS controller at the beginning of epochs.
+        arc_opt = self.architecture_optimizers()
+        if not isinstance(arc_opt, optim.Optimizer):
+            raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
+        arc_opt.zero_grad()
+        return self.model.on_train_epoch_start()
+
     def resample(self):
         """Resample the architecture with ENAS controller."""
         sample = self.controller.resample()
@@ -249,6 +274,14 @@ class EnasLightningModule(RandomSamplingLightningModule):
             module.resample(memo=result)
         return result

+    def export_probs(self):
+        """Export the probability from ENAS controller directly."""
+        sample = self.controller.resample(return_prob=True)
+        result = self._interpret_controller_probability_result(sample)
+        for module in self.nas_modules:
+            module.resample(memo=result)
+        return result
+
     def export(self):
         """Run one more inference of ENAS controller."""
         self.controller.eval()
@@ -261,3 +294,14 @@ class EnasLightningModule(RandomSamplingLightningModule):
         for key in list(sample.keys()):
             sample[key] = space_spec[key].values[sample[key]]
         return sample
+
+    def _interpret_controller_probability_result(self, sample: dict[str, list[float]]) -> dict[str, Any]:
+        """Convert ``{label: [prob1, prob2, prob3]}`` to ``{label/choice: prob}``."""
+        space_spec = self.search_space_spec()
+        result = {}
+        for key in list(sample.keys()):
+            if len(space_spec[key].values) != len(sample[key]):
+                raise ValueError(f'Expect {space_spec[key].values} to be of the same length as {sample[key]}')
+            for value, weight in zip(space_spec[key].values, sample[key]):
+                result[f'{key}/{value}'] = weight
+        return result
```
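ENAS keeps an exponential moving average of the reward as the REINFORCE baseline (`baseline = baseline * decay + reward * (1 - decay)`), and the policy gradient is then weighted by `reward - baseline` to reduce variance. A small numeric sketch with an assumed decay of 0.999 and made-up rewards:

```python
# EMA baseline used by EnasLightningModule, with assumed rewards.
baseline, decay = 0.0, 0.999
for reward in [0.50, 0.52, 0.55, 0.54]:
    baseline = baseline * decay + reward * (1 - decay)
# With decay this close to 1, the baseline creeps slowly toward the mean reward.
print(round(baseline, 5))   # ~0.00211 after four updates
```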
nni/nas/oneshot/pytorch/supermodule/_valuechoice_utils.py

```diff
@@ -168,11 +168,11 @@ def weighted_sum(items: list[T], weights: Sequence[float | None] = cast(Sequence
     assert len(items) == len(weights) > 0
     elem = items[0]
-    unsupported_msg = f'Unsupported element type in weighted sum: {type(elem)}. Value is: {elem}'
+    unsupported_msg = 'Unsupported element type in weighted sum: {}. Value is: {}'

     if isinstance(elem, str):
         # Need to check this first. Otherwise it goes into sequence and causes infinite recursion.
-        raise TypeError(unsupported_msg)
+        raise TypeError(unsupported_msg.format(type(elem), elem))

     try:
         if isinstance(elem, (torch.Tensor, np.ndarray, float, int, np.number)):
```
nni/nas/oneshot/pytorch/supermodule/base.py

```diff
@@ -56,6 +56,17 @@ class BaseSuperNetModule(nn.Module):
         """
         raise NotImplementedError()

+    def export_probs(self, memo: dict[str, Any]) -> dict[str, Any]:
+        """
+        Export the probability / logits of every choice got chosen.
+
+        Parameters
+        ----------
+        memo : dict[str, Any]
+            Use memo to avoid the same label gets exported multiple times.
+        """
+        raise NotImplementedError()
+
     def search_space_spec(self) -> dict[str, ParameterSpec]:
         """
         Space specification (sample points).
```
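`export_probs` follows the same memo convention as `resample` and `export`: callers thread one dict through every supernet module, so a label shared by several modules is emitted only once. A toy sketch of the pattern (the `Toy` class is illustrative, not from the commit):

```python
# Toy modules standing in for mutated supernet modules sharing the label 'a'.
class Toy:
    def __init__(self, label, probs):
        self.label, self.probs = label, probs

    def export_probs(self, memo):
        if any(k.startswith(self.label + '/') for k in memo):
            return {}          # label already exported by a sibling module
        return {f'{self.label}/{i}': p for i, p in enumerate(self.probs)}

nas_modules = [Toy('a', [0.7, 0.3]), Toy('a', [0.7, 0.3]), Toy('b', [1.0])]
probs = {}
for module in nas_modules:
    probs.update(module.export_probs(memo=probs))
assert sorted(probs) == ['a/0', 'a/1', 'b/0']
```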
nni/nas/oneshot/pytorch/supermodule/differentiable.py

```diff
@@ -104,6 +104,13 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
             return {}  # nothing new to export
         return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}

+    def export_probs(self, memo):
+        if any(k.startswith(self.label + '/') for k in memo):
+            return {}  # nothing new
+        weights = self._softmax(self._arch_alpha).cpu().tolist()
+        ret = {f'{self.label}/{name}': value for name, value in zip(self.op_names, weights)}
+        return ret
+
     def search_space_spec(self):
         return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ), True, size=len(self.op_names))}
@@ -117,7 +124,8 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
             if len(alpha) != size:
                 raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
         else:
-            alpha = nn.Parameter(torch.randn(size) * 1E-3)  # this can be reinitialized later
+            alpha = nn.Parameter(torch.randn(size) * 1E-3)  # the numbers in the parameter can be reinitialized later
+            memo[module.label] = alpha

         softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
         return cls(list(module.named_children()), alpha, softmax, module.label)
@@ -208,6 +216,13 @@ class DifferentiableMixedInput(BaseSuperNetModule):
             chosen = chosen[0]
         return {self.label: chosen}

+    def export_probs(self, memo):
+        if any(k.startswith(self.label + '/') for k in memo):
+            return {}  # nothing new
+        weights = self._softmax(self._arch_alpha).cpu().tolist()
+        ret = {f'{self.label}/{index}': value for index, value in enumerate(weights)}
+        return ret
+
     def search_space_spec(self):
         return {self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
@@ -225,7 +240,8 @@ class DifferentiableMixedInput(BaseSuperNetModule):
             if len(alpha) != size:
                 raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
         else:
-            alpha = nn.Parameter(torch.randn(size) * 1E-3)  # this can be reinitialized later
+            alpha = nn.Parameter(torch.randn(size) * 1E-3)  # the numbers in the parameter can be reinitialized later
+            memo[module.label] = alpha

         softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
         return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label)
@@ -284,6 +300,7 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
                 raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
             else:
                 alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
+                memo[name] = alpha
             operation._arch_alpha[name] = alpha

         operation.parameters = functools.partial(self.parameters, module=operation)  # bind self
@@ -321,6 +338,16 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
             result[name] = spec.values[chosen_index]
         return result

+    def export_probs(self, operation: MixedOperation, memo: dict[str, Any]):
+        """Export the weight for every leaf value choice."""
+        ret = {}
+        for name, spec in operation.search_space_spec().items():
+            if any(k.startswith(name + '/') for k in memo):
+                continue
+            weights = operation._softmax(operation._arch_alpha[name]).cpu().tolist()  # type: ignore
+            ret.update({f'{name}/{value}': weight for value, weight in zip(spec.values, weights)})
+        return ret
+
     def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any:
         if name in operation.mutable_arguments:
             weights: dict[str, torch.Tensor] = {
@@ -360,6 +387,7 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
                 raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
             else:
                 alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
+                memo[name] = alpha
             self._arch_alpha[name] = alpha

     def resample(self, memo):
@@ -376,6 +404,16 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
             result[name] = spec.values[chosen_index]
         return result

+    def export_probs(self, memo):
+        """Export the weight for every leaf value choice."""
+        ret = {}
+        for name, spec in self.search_space_spec().items():
+            if any(k.startswith(name + '/') for k in memo):
+                continue
+            weights = self._softmax(self._arch_alpha[name]).cpu().tolist()
+            ret.update({f'{name}/{value}': weight for value, weight in zip(spec.values, weights)})
+        return ret
+
     def search_space_spec(self):
         return self._space_spec
@@ -427,6 +465,8 @@ class DifferentiableMixedCell(PathSamplingCell):
     """Implementation of Cell under differentiable context.

+    Similar to PathSamplingCell, this cell only handles cells of specific kinds (e.g., with loose end).
+
     An architecture parameter is created on each edge of the full-connected graph.
     """
@@ -450,13 +490,21 @@ class DifferentiableMixedCell(PathSamplingCell):
                 op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j]
                 if edge_label in memo:
                     alpha = memo[edge_label]
-                    if len(alpha) != len(op):
-                        raise ValueError(
-                            f'Architecture parameter size of same label {edge_label} conflict: '
-                            f'{len(alpha)} vs. {len(op)}'
-                        )
+                    if len(alpha) != len(op) + 1:
+                        warnings.warn(
+                            f'Architecture parameter size {len(alpha)} is not same as expected: {len(op) + 1}. '
+                            'This is likely due to the label being shared by a LayerChoice inside the cell and outside.',
+                            UserWarning
+                        )
                 else:
-                    alpha = nn.Parameter(torch.randn(len(op)) * 1E-3)
+                    # +1 to emulate the input choice.
+                    alpha = nn.Parameter(torch.randn(len(op) + 1) * 1E-3)
+                    memo[edge_label] = alpha
                 self._arch_alpha[edge_label] = alpha

         self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
@@ -465,18 +513,32 @@ class DifferentiableMixedCell(PathSamplingCell):
         """Differentiable doesn't need to resample."""
         return {}

+    def export_probs(self, memo):
+        """When export probability, we follow the structure in arch alpha."""
+        ret = {}
+        for name, parameter in self._arch_alpha.items():
+            if any(k.startswith(name + '/') for k in memo):
+                continue
+            weights = self._softmax(parameter).cpu().tolist()
+            ret.update({f'{name}/{value}': weight for value, weight in zip(self.op_names, weights)})
+        return ret
+
     def export(self, memo):
         """Tricky export.

         Reference: https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model_search.py#L135
+        We don't avoid selecting operations like ``none`` here, because it looks like a different search space.
         """
         exported = {}
         for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
+            # If label already exists, no need to re-export.
+            if all(
+                f'{self.label}/op_{i}_{k}' in memo and f'{self.label}/input_{i}_{k}' in memo
+                for k in range(self.num_ops_per_node)
+            ):
+                continue
             # Tuple of (weight, input_index, op_name)
             all_weights: list[tuple[float, int, str]] = []
             for j in range(i):
                 for k, name in enumerate(self.op_names):
+                    # The last appended weight is automatically skipped in export.
                     all_weights.append((
                         float(self._arch_alpha[f'{self.label}/{i}_{j}'][k].item()),
                         j, name,
@@ -497,7 +559,7 @@ class DifferentiableMixedCell(PathSamplingCell):
             all_weights = [all_weights[k] for k in first_occurrence_index] + \
                 [w for j, w in enumerate(all_weights) if j not in first_occurrence_index]

-            _logger.info('Sorted weights in differentiable cell export (node %d): %s', i, all_weights)
+            _logger.info('Sorted weights in differentiable cell export (%s cell, node %d): %s', self.label, i, all_weights)

             for k in range(self.num_ops_per_node):
                 # all_weights could be too short in case ``num_ops_per_node`` is too large.
@@ -515,7 +577,11 @@ class DifferentiableMixedCell(PathSamplingCell):
             for j in range(i):  # for every previous tensors
                 op_results = torch.stack([op(states[j]) for op in ops[j].values()])
-                alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
-                edge_sum = torch.sum(op_results * self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']).view(*alpha_shape), 0)
+                alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)  # (-1, 1, 1, 1, 1, ...)
+                op_weights = self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}'])
+                if len(op_weights) == len(op_results) + 1:
+                    # concatenate with a zero operation, indicating this path is not chosen at all.
+                    op_results = torch.cat((op_results, torch.zeros_like(op_results[:1])), 0)
+                edge_sum = torch.sum(op_results * op_weights.view(*alpha_shape), 0)
                 current_state.append(edge_sum)
```
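The extra `+1` slot in each edge's architecture parameter acts as a learnable "skip this edge" option: when the alpha vector is one longer than the number of ops, the forward pass appends an all-zero result so the last weight routes probability mass to "no connection". A minimal torch sketch of that mixing step (shapes are made up):

```python
import torch

op_results = torch.randn(3, 2, 4)   # 3 candidate ops, batch 2, feature 4
alpha = torch.randn(4)              # 3 ops + 1 "zero op" slot
weights = torch.softmax(alpha, -1)

# Append an all-zero path so weights[-1] means "edge not chosen".
op_results = torch.cat((op_results, torch.zeros_like(op_results[:1])), 0)
edge_sum = torch.sum(op_results * weights.view(-1, 1, 1), 0)
assert edge_sum.shape == (2, 4)
```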
nni/nas/oneshot/pytorch/supermodule/operation.py

```diff
@@ -71,6 +71,10 @@ class MixedOperationSamplingPolicy:
         """The handler of :meth:`MixedOperation.export`."""
         raise NotImplementedError()

+    def export_probs(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
+        """The handler of :meth:`MixedOperation.export_probs`."""
+        raise NotImplementedError()
+
     def forward_argument(self, operation: 'MixedOperation', name: str) -> Any:
         """Computing the argument with ``name`` used in operation's forward.
         Usually a value, or a distribution of value.
@@ -162,6 +166,10 @@ class MixedOperation(BaseSuperNetModule):
         """Delegates to :meth:`MixedOperationSamplingPolicy.resample`."""
         return self.sampling_policy.resample(self, memo)

+    def export_probs(self, memo):
+        """Delegates to :meth:`MixedOperationSamplingPolicy.export_probs`."""
+        return self.sampling_policy.export_probs(self, memo)
+
     def export(self, memo):
         """Delegates to :meth:`MixedOperationSamplingPolicy.export`."""
         return self.sampling_policy.export(self, memo)
```
nni/nas/oneshot/pytorch/supermodule/proxyless.py

```diff
@@ -11,7 +11,7 @@ The support remains limited. Known limitations include:
 from __future__ import annotations

-from typing import cast
+from typing import Any, Tuple, Union, cast

 import torch
 import torch.nn as nn
@@ -21,28 +21,115 @@ from .differentiable import DifferentiableMixedLayer, DifferentiableMixedInput
 __all__ = ['ProxylessMixedLayer', 'ProxylessMixedInput']

-class _ArchGradientFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, binary_gates, run_func, backward_func):
-        ctx.run_func = run_func
-        ctx.backward_func = backward_func
-
-        detached_x = x.detach()
-        detached_x.requires_grad = x.requires_grad
-        with torch.enable_grad():
-            output = run_func(detached_x)
-        ctx.save_for_backward(detached_x, output)
-        return output.data
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        detached_x, output = ctx.saved_tensors
-        with torch.no_grad():
-            grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
-            # compute gradients w.r.t. binary_gates
-            binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)
-        return grad_x[0], binary_grads, None, None
+def _detach_tensor(tensor: Any) -> Any:
+    """Recursively detach all the tensors."""
+    if isinstance(tensor, (list, tuple)):
+        return tuple(_detach_tensor(t) for t in tensor)
+    elif isinstance(tensor, dict):
+        return {k: _detach_tensor(v) for k, v in tensor.items()}
+    elif isinstance(tensor, torch.Tensor):
+        return tensor.detach()
+    else:
+        return tensor
+
+
+def _iter_tensors(tensor: Any) -> Any:
+    """Recursively iterate over all the tensors.
+
+    This is kept for complex outputs (like dicts / lists).
+    However, complex outputs are not supported by PyTorch backward hooks yet.
+    """
+    if isinstance(tensor, torch.Tensor):
+        yield tensor
+    elif isinstance(tensor, (list, tuple)):
+        for t in tensor:
+            yield from _iter_tensors(t)
+    elif isinstance(tensor, dict):
+        for t in tensor.values():
+            yield from _iter_tensors(t)
+
+
+def _pack_as_tuple(tensor: Any) -> tuple:
+    """Return a tuple of tensor with only one element if tensor it's not a tuple."""
+    if isinstance(tensor, (tuple, list)):
+        return tuple(tensor)
+    return (tensor,)
+
+
+def element_product_sum(tensor1: tuple[torch.Tensor, ...], tensor2: tuple[torch.Tensor, ...]) -> torch.Tensor:
+    """Compute the sum of all the element-wise product."""
+    assert len(tensor1) == len(tensor2), 'The number of tensors must be the same.'
+    # Skip zero gradients
+    ret = [torch.sum(t1 * t2) for t1, t2 in zip(tensor1, tensor2) if t1 is not None and t2 is not None]
+    if not ret:
+        return torch.tensor(0)
+    if len(ret) == 1:
+        return ret[0]
+    return cast(torch.Tensor, sum(ret))
+
+
+class ProxylessContext:
+
+    def __init__(self, arch_alpha: torch.Tensor, softmax: nn.Module) -> None:
+        self.arch_alpha = arch_alpha
+        self.softmax = softmax
+
+        # When a layer is called multiple times, the inputs and outputs are saved in order.
+        # In backward propagation, we assume that they are used in the reversed order.
+        self.layer_input: list[Any] = []
+        self.layer_output: list[Any] = []
+        self.layer_sample_idx: list[int] = []
+
+    def clear_context(self) -> None:
+        self.layer_input = []
+        self.layer_output = []
+        self.layer_sample_idx = []
+
+    def save_forward_context(self, layer_input: Any, layer_output: Any, layer_sample_idx: int):
+        self.layer_input.append(_detach_tensor(layer_input))
+        self.layer_output.append(_detach_tensor(layer_output))
+        self.layer_sample_idx.append(layer_sample_idx)
+
+    def backward_hook(self, module: nn.Module,
+                      grad_input: Union[Tuple[torch.Tensor, ...], torch.Tensor],
+                      grad_output: Union[Tuple[torch.Tensor, ...], torch.Tensor]) -> None:
+        # binary_grads is the gradient of binary gates.
+        # Binary gates is a one-hot tensor where 1 is on the sampled index, and others are 0.
+        # By chain rule, it's gradient is grad_output times the layer_output (of the corresponding path).
+        binary_grads = torch.zeros_like(self.arch_alpha)
+
+        # Retrieve the layer input/output in reverse order.
+        if not self.layer_input:
+            raise ValueError('Unexpected backward call. The saved context is empty.')
+        layer_input = self.layer_input.pop()
+        layer_output = self.layer_output.pop()
+        layer_sample_idx = self.layer_sample_idx.pop()
+
+        with torch.no_grad():
+            # Compute binary grads.
+            for k in range(len(binary_grads)):
+                if k != layer_sample_idx:
+                    args, kwargs = layer_input
+                    out_k = module.forward_path(k, *args, **kwargs)  # type: ignore
+                else:
+                    out_k = layer_output
+                # FIXME: One limitation here is that out_k can't be complex objects like dict.
+                # I think it's also a limitation of backward hook.
+                binary_grads[k] = element_product_sum(
+                    _pack_as_tuple(out_k),  # In case out_k is a single tensor
+                    _pack_as_tuple(grad_output)
+                )
+
+            # Compute the gradient of the arch_alpha, based on binary_grads.
+            if self.arch_alpha.grad is None:
+                self.arch_alpha.grad = torch.zeros_like(self.arch_alpha)
+            probs = self.softmax(self.arch_alpha)
+            for i in range(len(self.arch_alpha)):
+                for j in range(len(self.arch_alpha)):
+                    # Arch alpha's gradients are accumulated for all backwards through this layer.
+                    self.arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])


 class ProxylessMixedLayer(DifferentiableMixedLayer):
@@ -50,46 +137,32 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
     It resamples a single-path every time, rather than go through the softmax.
     """

-    _arch_parameter_names = ['_arch_alpha', '_binary_gates']
+    _arch_parameter_names = ['_arch_alpha']

     def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
         super().__init__(paths, alpha, softmax, label)
-        self._binary_gates = nn.Parameter(torch.randn(len(paths)) * 1E-3)
+        # Binary gates should be created here, but it's not because it's never used in the forward pass.
+        # self._binary_gates = nn.Parameter(torch.zeros(len(paths)))

         # like sampling-based methods, it has a ``_sampled``.
         self._sampled: str | None = None
         self._sample_idx: int | None = None

-    def forward(self, *args, **kwargs):
-        def run_function(ops, active_id, **kwargs):
-            def forward(_x):
-                return ops[active_id](_x, **kwargs)
-            return forward
-
-        def backward_function(ops, active_id, binary_gates, **kwargs):
-            def backward(_x, _output, grad_output):
-                binary_grads = torch.zeros_like(binary_gates.data)
-                with torch.no_grad():
-                    for k in range(len(ops)):
-                        if k != active_id:
-                            out_k = ops[k](_x.data, **kwargs)
-                        else:
-                            out_k = _output.data
-                        grad_k = torch.sum(out_k * grad_output)
-                        binary_grads[k] = grad_k
-                return binary_grads
-            return backward
-
-        assert len(args) == 1, 'ProxylessMixedLayer only supports exactly one input argument.'
-        x = args[0]
-
-        assert self._sampled is not None, 'Need to call resample() before running fprop.'
-        list_ops = [getattr(self, op) for op in self.op_names]
-
-        return _ArchGradientFunction.apply(
-            x, self._binary_gates, run_function(list_ops, self._sample_idx, **kwargs),
-            backward_function(list_ops, self._sample_idx, self._binary_gates, **kwargs)
-        )
+        # arch_alpha could be shared by multiple layers,
+        # but binary_gates is owned by the current layer.
+        self.ctx = ProxylessContext(alpha, softmax)
+        self.register_full_backward_hook(self.ctx.backward_hook)
+
+    def forward(self, *args, **kwargs):
+        """Forward pass of one single path."""
+        if self._sample_idx is None:
+            raise RuntimeError('resample() needs to be called before fprop.')
+        output = self.forward_path(self._sample_idx, *args, **kwargs)
+        self.ctx.save_forward_context((args, kwargs), output, self._sample_idx)
+        return output
+
+    def forward_path(self, index, *args, **kwargs):
+        return getattr(self, self.op_names[index])(*args, **kwargs)

     def resample(self, memo):
         """Sample one path based on alpha if label is not found in memo."""
@@ -101,66 +174,37 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
         self._sample_idx = int(torch.multinomial(probs, 1)[0].item())
         self._sampled = self.op_names[self._sample_idx]

-        # set binary gates
-        with torch.no_grad():
-            self._binary_gates.zero_()
-            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
-            self._binary_gates.data[self._sample_idx] = 1.0
+        self.ctx.clear_context()

         return {self.label: self._sampled}

     def export(self, memo):
         """Chose the argmax if label isn't found in memo."""
         if self.label in memo:
             return {}  # nothing new to export
         return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}

-    def finalize_grad(self):
-        binary_grads = self._binary_gates.grad
-        assert binary_grads is not None
-        with torch.no_grad():
-            if self._arch_alpha.grad is None:
-                self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
-            probs = self._softmax(self._arch_alpha)
-            for i in range(len(self._arch_alpha)):
-                for j in range(len(self._arch_alpha)):
-                    self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
-

 class ProxylessMixedInput(DifferentiableMixedInput):
     """Proxyless version of differentiable input choice.

-    See :class:`ProxylessLayerChoice` for implementation details.
+    See :class:`ProxylessMixedLayer` for implementation details.
     """

     _arch_parameter_names = ['_arch_alpha', '_binary_gates']

     def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
         super().__init__(n_candidates, n_chosen, alpha, softmax, label)
-        self._binary_gates = nn.Parameter(torch.randn(n_candidates) * 1E-3)
+
+        # We only support choosing a particular one here.
+        # Nevertheless, we rank the score and export the tops in export.
         self._sampled: int | None = None
+        self.ctx = ProxylessContext(alpha, softmax)
+        self.register_full_backward_hook(self.ctx.backward_hook)

-    def forward(self, inputs):
-        def run_function(active_sample):
-            return lambda x: x[active_sample]
-
-        def backward_function(binary_gates):
-            def backward(_x, _output, grad_output):
-                binary_grads = torch.zeros_like(binary_gates.data)
-                with torch.no_grad():
-                    for k in range(self.n_candidates):
-                        out_k = _x[k].data
-                        grad_k = torch.sum(out_k * grad_output)
-                        binary_grads[k] = grad_k
-                return binary_grads
-            return backward
-
-        inputs = torch.stack(inputs, 0)
-        assert self._sampled is not None, 'Need to call resample() before running fprop.'
-
-        return _ArchGradientFunction.apply(
-            inputs, self._binary_gates, run_function(self._sampled),
-            backward_function(self._binary_gates)
-        )
+    def forward(self, inputs):
+        """Choose one single input."""
+        if self._sampled is None:
+            raise RuntimeError('resample() needs to be called before fprop.')
+        output = self.forward_path(self._sampled, inputs)
+        self.ctx.save_forward_context(((inputs,), {}), output, self._sampled)
+        return output
+
+    def forward_path(self, index, inputs):
+        return inputs[index]

     def resample(self, memo):
         """Sample one path based on alpha if label is not found in memo."""
@@ -171,27 +215,6 @@ class ProxylessMixedInput(DifferentiableMixedInput):
         sample = torch.multinomial(probs, 1)[0].item()
         self._sampled = int(sample)

-        # set binary gates
-        with torch.no_grad():
-            self._binary_gates.zero_()
-            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
-            self._binary_gates.data[cast(int, self._sampled)] = 1.0
+        self.ctx.clear_context()

         return {self.label: self._sampled}

-    def export(self, memo):
-        """Chose the argmax if label isn't found in memo."""
-        if self.label in memo:
-            return {}  # nothing new to export
-        return {self.label: torch.argmax(self._arch_alpha).item()}
-
-    def finalize_grad(self):
-        binary_grads = self._binary_gates.grad
-        assert binary_grads is not None
-        with torch.no_grad():
-            if self._arch_alpha.grad is None:
-                self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
-            probs = self._softmax(self._arch_alpha)
-            for i in range(self.n_candidates):
-                for j in range(self.n_candidates):
-                    self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
```
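Both the removed `finalize_grad` and the new `ProxylessContext.backward_hook` implement the same ProxylessNAS estimator: with `p = softmax(alpha)` and `g` the binary-gate gradients, `dL/dalpha_i = sum_j g_j * p_j * ([i == j] - p_i)`, which is exactly the softmax Jacobian applied to `g`. A self-contained check against autograd (all values are made up):

```python
import torch

alpha = torch.randn(4, requires_grad=True)
binary_grads = torch.randn(4)            # stands in for dL/d(binary gates)
probs = torch.softmax(alpha, -1)

# Autograd reference: L = <binary_grads, softmax(alpha)>.
(probs * binary_grads).sum().backward()

# The double loop from backward_hook / finalize_grad.
with torch.no_grad():
    manual = torch.zeros(4)
    for i in range(4):
        for j in range(4):
            manual[i] += binary_grads[j] * probs[j] * (float(i == j) - probs[i])

assert torch.allclose(alpha.grad, manual, atol=1e-6)
```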
nni/nas/oneshot/pytorch/supermodule/sampling.py

```diff
@@ -169,7 +169,7 @@ class PathSamplingInput(BaseSuperNetModule):

 class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
-    """Implementes the path sampling in mixed operation.
+    """Implements the path sampling in mixed operation.

     One mixed operation can have multiple value choices in its arguments.
     Each value choice can be further decomposed into "leaf value choices".
@@ -388,6 +388,10 @@ class PathSamplingCell(BaseSuperNetModule):
     @classmethod
     def mutate(cls, module, name, memo, mutate_kwargs):
+        """
+        Mutate only handles cells of specific configurations (e.g., with loose end).
+        Fallback to the default mutate if the cell is not handled here.
+        """
         if isinstance(module, Cell):
             op_factory = None  # not all the cells need to be replaced
             if module.op_candidates_factory is not None:
```
test/algo/nas/test_oneshot.py

```diff
@@ -5,6 +5,7 @@ import pytorch_lightning as pl
 import pytest
 from torchvision import transforms
 from torchvision.datasets import MNIST
+from torch import nn
 from torch.utils.data import Dataset, RandomSampler

 import nni
@@ -13,7 +14,11 @@ from nni.retiarii import strategy, model_wrapper, basic_unit
 from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
 from nni.retiarii.evaluator.pytorch.lightning import Classification, Regression, DataLoader
 from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice
+from nni.retiarii.oneshot.pytorch import DartsLightningModule
 from nni.retiarii.strategy import BaseStrategy
+from pytorch_lightning import LightningModule, Trainer
+
+from .test_oneshot_utils import RandomDataset

 pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
@@ -338,17 +343,49 @@ def test_gumbel_darts():
     _test_strategy(strategy.GumbelDARTS())

-if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--exp', type=str, default='all', metavar='E',
-                        help='experiment to run, default = all')
-    args = parser.parse_args()
-
-    if args.exp == 'all':
-        test_darts()
-        test_proxyless()
-        test_enas()
-        test_random()
-        test_gumbel_darts()
-    else:
-        globals()[f'test_{args.exp}']()
+
+def test_optimizer_lr_scheduler():
+    learning_rates = []
+
+    class CustomLightningModule(LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.layer1 = nn.Linear(32, 2)
+            self.layer2 = nn.LayerChoice([nn.Linear(2, 2), nn.Linear(2, 2, bias=False)])
+
+        def forward(self, x):
+            return self.layer2(self.layer1(x))
+
+        def configure_optimizers(self):
+            opt1 = torch.optim.SGD(self.layer1.parameters(), lr=0.1)
+            opt2 = torch.optim.Adam(self.layer2.parameters(), lr=0.2)
+            return [opt1, opt2], [torch.optim.lr_scheduler.StepLR(opt1, step_size=2, gamma=0.1)]
+
+        def training_step(self, batch, batch_idx):
+            loss = self(batch).sum()
+            self.log('train_loss', loss)
+            return {'loss': loss}
+
+        def on_train_epoch_start(self) -> None:
+            learning_rates.append(self.optimizers()[0].param_groups[0]['lr'])
+
+        def validation_step(self, batch, batch_idx):
+            loss = self(batch).sum()
+            self.log('valid_loss', loss)
+
+        def test_step(self, batch, batch_idx):
+            loss = self(batch).sum()
+            self.log('test_loss', loss)
+
+    train_data = RandomDataset(32, 32)
+    valid_data = RandomDataset(32, 16)
+
+    model = CustomLightningModule()
+    darts_module = DartsLightningModule(model, gradient_clip_val=5)
+    trainer = Trainer(max_epochs=10)
+    trainer.fit(
+        darts_module,
+        dict(train=DataLoader(train_data, batch_size=8), val=DataLoader(valid_data, batch_size=8))
+    )
+
+    assert len(learning_rates) == 10 and abs(learning_rates[0] - 0.1) < 1e-5 and \
+        abs(learning_rates[2] - 0.01) < 1e-5 and abs(learning_rates[-1] - 1e-5) < 1e-6
```
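The assertions encode the StepLR trajectory: with base lr 0.1, `step_size=2`, and `gamma=0.1`, the learning rate at epoch `e` is `0.1 * 0.1 ** (e // 2)`. A quick sketch of the expected sequence over the 10 epochs:

```python
# Expected lr per epoch under StepLR(opt1, step_size=2, gamma=0.1).
expected = [0.1 * 0.1 ** (e // 2) for e in range(10)]
# [0.1, 0.1, 0.01, 0.01, 0.001, 0.001, 1e-4, 1e-4, 1e-5, 1e-5]
assert abs(expected[2] - 0.01) < 1e-12 and abs(expected[-1] - 1e-5) < 1e-16
```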
test/algo/nas/test_oneshot_proxyless.py (new file, mode 100644)

```python
import torch
import torch.nn as nn

from nni.nas.hub.pytorch.nasbench201 import OPS_WITH_STRIDE
from nni.nas.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLayer, ProxylessMixedInput, _iter_tensors


def test_proxyless_bp():
    op = ProxylessMixedLayer(
        [(name, value(3, 3, 1)) for name, value in OPS_WITH_STRIDE.items()],
        nn.Parameter(torch.randn(len(OPS_WITH_STRIDE))),
        nn.Softmax(-1), 'proxyless'
    )

    optimizer = torch.optim.SGD(op.parameters(arch=True), 0.1)

    for _ in range(10):
        x = torch.randn(1, 3, 9, 9).requires_grad_()
        op.resample({})
        y = op(x).sum()
        optimizer.zero_grad()
        y.backward()
        assert op._arch_alpha.grad.abs().sum().item() != 0


def test_proxyless_input():
    inp = ProxylessMixedInput(6, 2, nn.Parameter(torch.zeros(6)), nn.Softmax(-1), 'proxyless')

    optimizer = torch.optim.SGD(inp.parameters(arch=True), 0.1)

    for _ in range(10):
        x = [torch.randn(1, 3, 9, 9).requires_grad_() for _ in range(6)]
        inp.resample({})
        y = inp(x).sum()
        optimizer.zero_grad()
        y.backward()


def test_iter_tensors():
    a = (torch.zeros(3, 1), {'a': torch.zeros(5, 1), 'b': torch.zeros(6, 1)}, [torch.zeros(7, 1)])
    ret = []
    for x in _iter_tensors(a):
        ret.append(x.shape[0])
    assert ret == [3, 5, 6, 7]


class MultiInputLayer(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.d = d

    def forward(self, q, k, v=None, mask=None):
        return q + self.d, 2 * k - 2 * self.d, v, mask


def test_proxyless_multi_input():
    op = ProxylessMixedLayer(
        [
            ('a', MultiInputLayer(1)),
            ('b', MultiInputLayer(3))
        ],
        nn.Parameter(torch.randn(2)),
        nn.Softmax(-1), 'proxyless'
    )

    optimizer = torch.optim.SGD(op.parameters(arch=True), 0.1)

    for retry in range(10):
        q = torch.randn(1, 3, 9, 9).requires_grad_()
        k = torch.randn(1, 3, 9, 8).requires_grad_()
        v = None if retry < 5 else torch.randn(1, 3, 9, 7).requires_grad_()
        mask = None if retry % 5 < 2 else torch.randn(1, 3, 9, 6).requires_grad_()
        op.resample({})
        y = op(q, k, v, mask=mask)
        y = y[0].sum() + y[1].sum()
        optimizer.zero_grad()
        y.backward()
        assert op._arch_alpha.grad.abs().sum().item() != 0, op._arch_alpha.grad
```
test/algo/nas/test_oneshot_supermodules.py

```diff
@@ -3,7 +3,7 @@ import pytest
 import numpy as np
 import torch
 import torch.nn as nn

-from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, LayerNorm, Linear, MultiheadAttention
+from nni.retiarii.nn.pytorch import ValueChoice, LayerChoice, Conv2d, BatchNorm2d, LayerNorm, Linear, MultiheadAttention
 from nni.retiarii.oneshot.pytorch.base_lightning import traverse_and_mutate_submodules
 from nni.retiarii.oneshot.pytorch.supermodule.differentiable import (
     MixedOpDifferentiablePolicy,
     DifferentiableMixedLayer,
     DifferentiableMixedInput,
     GumbelSoftmax,
@@ -144,6 +144,16 @@ def test_differentiable_valuechoice():
     assert set(conv.export({}).keys()) == {'123', '456'}

+def test_differentiable_layerchoice_dedup():
+    layerchoice1 = LayerChoice([Conv2d(3, 3, 3), Conv2d(3, 3, 3)], label='a')
+    layerchoice2 = LayerChoice([Conv2d(3, 3, 3), Conv2d(3, 3, 3)], label='a')
+
+    memo = {}
+    DifferentiableMixedLayer.mutate(layerchoice1, 'x', memo, {})
+    DifferentiableMixedLayer.mutate(layerchoice2, 'x', memo, {})
+    assert len(memo) == 1 and 'a' in memo
+
 def _mixed_operation_sampling_sanity_check(operation, memo, *input):
     for native_op in NATIVE_MIXED_OPERATIONS:
         if native_op.bound_type == type(operation):
@@ -160,7 +170,9 @@ def _mixed_operation_differentiable_sanity_check(operation, *input):
             mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
             break

-    return mutate_op(*input)
+    mutate_op(*input)
+    mutate_op.export({})
+    mutate_op.export_probs({})

 def test_mixed_linear():
@@ -319,6 +331,9 @@ def test_differentiable_layer_input():
     op = DifferentiableMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
     assert op(torch.randn(4, 2)).size(-1) == 3
     assert op.export({})['eee'] in ['a', 'b']
+    probs = op.export_probs({})
+    assert len(probs) == 2
+    assert abs(probs['eee/a'] + probs['eee/b'] - 1) < 1e-4
     assert len(list(op.parameters())) == 3

     with pytest.raises(ValueError):
@@ -328,6 +343,8 @@ def test_differentiable_layer_input():
     input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
     assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2
     assert len(input.export({})['ddd']) == 2
+    assert len(input.export_probs({})) == 5
+    assert 'ddd/3' in input.export_probs({})

 def test_proxyless_layer_input():
@@ -341,7 +358,8 @@ def test_proxyless_layer_input():
     input = ProxylessMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
     assert input.resample({})['ddd'] in list(range(5))
     assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2])
-    assert input.export({})['ddd'] in list(range(5))
+    exported = input.export({})['ddd']
+    assert len(exported) == 2 and all(e in list(range(5)) for e in exported)

 def test_pathsampling_repeat():
@@ -373,6 +391,7 @@ def test_differentiable_repeat():
     assert op(torch.randn(2, 8)).size() == torch.Size([2, 16])
     sample = op.export({})
     assert 'ccc' in sample and sample['ccc'] in [0, 1]
+    assert sorted(op.export_probs({}).keys()) == ['ccc/0', 'ccc/1']

 class TupleModule(nn.Module):
     def __init__(self, num):
@@ -452,11 +471,16 @@ def test_differentiable_cell():
         result.update(module.export(memo=result))
     assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2

+    result_prob = {}
+    for module in nas_modules:
+        result_prob.update(module.export_probs(memo=result_prob))
+
     ctrl_params = []
     for m in nas_modules:
         ctrl_params += list(m.parameters(arch=True))
     if cell_cls in [CellLooseEnd, CellOpFactory]:
         assert len(ctrl_params) == model.cell.num_nodes * (model.cell.num_nodes + 3) // 2
+        assert len(result_prob) == len(ctrl_params) * 2  # len(op_names) == 2
         assert isinstance(model.cell, DifferentiableMixedCell)
     else:
         assert not isinstance(model.cell, DifferentiableMixedCell)
```
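The expected parameter count in the loose-end case follows from the cell's fully-connected DAG: with 2 predecessor inputs, node i (for i = 2, ..., num_nodes + 1) has i incoming edges, each carrying one alpha vector, so the total is sum over i of i = n(n+3)/2, matching `num_nodes * (num_nodes + 3) // 2` above; multiplying by the two candidate ops gives the expected number of exported probabilities. A one-line check of the arithmetic:

```python
# Edge/alpha count for a loose-end cell with 2 predecessors and n nodes:
# node i (i = 2 .. n + 1) has i incoming edges, so the total is n * (n + 3) // 2.
n = 4
assert sum(range(2, n + 2)) == n * (n + 3) // 2 == 14
```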