OpenDAS / nni · Commits · 05c7d6e9
Unverified commit 05c7d6e9, authored Apr 28, 2022 by Yuge Zhang, committed by GitHub on Apr 28, 2022

Typehint for oneshot NAS (#4811)

parent cbac2c5c
Showing 20 changed files with 379 additions and 254 deletions (+379 -254)

nni/retiarii/oneshot/pytorch/base_lightning.py                    +69  -41
nni/retiarii/oneshot/pytorch/darts.py                             +3   -0
nni/retiarii/oneshot/pytorch/differentiable.py                    +19  -14
nni/retiarii/oneshot/pytorch/enas.py                              +23  -17
nni/retiarii/oneshot/pytorch/proxyless.py                         +3   -0
nni/retiarii/oneshot/pytorch/random.py                            +5   -0
nni/retiarii/oneshot/pytorch/sampling.py                          +36  -22
nni/retiarii/oneshot/pytorch/strategy.py                          +12  -8
nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py      +23  -20
nni/retiarii/oneshot/pytorch/supermodule/_singlepathnas.py        +3   -2
nni/retiarii/oneshot/pytorch/supermodule/_valuechoice_utils.py    +22  -13
nni/retiarii/oneshot/pytorch/supermodule/base.py                  +11  -9
nni/retiarii/oneshot/pytorch/supermodule/differentiable.py        +27  -21
nni/retiarii/oneshot/pytorch/supermodule/operation.py             +59  -51
nni/retiarii/oneshot/pytorch/supermodule/proxyless.py             +8   -6
nni/retiarii/oneshot/pytorch/supermodule/sampling.py              +10  -8
nni/retiarii/oneshot/pytorch/utils.py                             +23  -11
pyrightconfig.json                                                +0   -1
test/ut/retiarii/test_lightning_trainer.py                        +10  -7
test/ut/retiarii/test_oneshot.py                                  +13  -3
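Note: the theme of this commit is visible throughout the diff below: `typing.List`/`Dict`/`Optional`/`Union` annotations are replaced with built-in generics and PEP 604 unions, made legal on older Pythons by `from __future__ import annotations`. A minimal sketch of the target style (the function names here are illustrative, not from the repository):

    from __future__ import annotations  # postpone evaluation so list[str] and X | None work on Python 3.7+

    from typing import Any

    def pick_first(options: list[str], fallback: str | None = None) -> str | None:
        # Return the first option, or the fallback when the list is empty.
        return options[0] if options else fallback

    def merge(left: dict[str, Any], right: dict[str, Any]) -> dict[str, Any]:
        # Shallow-merge two dicts; keys in `right` win.
        return {**left, **right}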
nni/retiarii/oneshot/pytorch/base_lightning.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
 import warnings
 from itertools import chain
-from typing import Dict, Callable, List, Union, Any, Tuple
+from typing import Callable, Any, Dict, Union, Tuple, List, cast

 import pytorch_lightning as pl
 import torch.optim as optim
 import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler

@@ -24,8 +27,8 @@ MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[
 def traverse_and_mutate_submodules(
-    root_module: nn.Module, hooks: List[MutationHook], mutate_kwargs: Dict[str, Any], topdown: bool = True
-) -> List[BaseSuperNetModule]:
+    root_module: nn.Module, hooks: list[MutationHook], mutate_kwargs: dict[str, Any], topdown: bool = True
+) -> list[BaseSuperNetModule]:
     """
     Traverse the module-tree of ``root_module``, and call ``hooks`` on every tree node.

@@ -36,7 +39,7 @@ def traverse_and_mutate_submodules(
         Since this method is called in the ``__init__`` of :class:`BaseOneShotLightningModule`,
         it's usually a ``pytorch_lightning.LightningModule``.
         The mutation will be in-place on ``root_module``.
-    hooks : List[MutationHook]
+    hooks : list[MutationHook]
         List of mutation hooks. See :class:`BaseOneShotLightningModule` for how to write hooks.
         When a hook returns an module, the module will be replaced (mutated) to the new module.
     mutate_kwargs : dict

@@ -47,7 +50,7 @@ def traverse_and_mutate_submodules(
     Returns
     ----------
-    modules : Dict[str, nn.Module]
+    modules : dict[str, nn.Module]
         The replace result.
     """
     memo = {}

@@ -101,7 +104,7 @@ def traverse_and_mutate_submodules(
     return module_list


-def no_default_hook(module: nn.Module, name: str, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> bool:
+def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> bool:
     """Add this hook at the end of your hook list to raise error for unsupported mutation primitives."""

     # Forward IS NOT supernet

@@ -125,7 +128,7 @@ def no_default_hook(module: nn.Module, name: str, memo: Dict[str, Any], mutate_k
     if is_traceable(module):
         # check whether there is a value-choice in its arguments
         has_valuechoice = False
-        for arg in chain(module.trace_args, module.trace_kwargs.values()):
+        for arg in chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()):
             if isinstance(arg, ValueChoiceX):
                 has_valuechoice = True
                 break

@@ -139,7 +142,7 @@ def no_default_hook(module: nn.Module, name: str, memo: Dict[str, Any], mutate_k
 class BaseOneShotLightningModule(pl.LightningModule):

-    _mutation_hooks_note = """mutation_hooks : List[MutationHook]
+    _mutation_hooks_note = """mutation_hooks : list[MutationHook]
         Mutation hooks are callable that inputs an Module and returns a :class:`BaseSuperNetModule`.
         They are invoked in :meth:`traverse_and_mutate_submodules`, on each submodules.
         For each submodule, the hook list are invoked subsequently,

@@ -194,36 +197,40 @@ class BaseOneShotLightningModule(pl.LightningModule):
     Attributes
     ----------
-    nas_modules : List[BaseSuperNetModule]
+    nas_modules : list[BaseSuperNetModule]
         Modules that have been mutated, which the search algorithms should care about.

     Parameters
     ----------
     """ + _inner_module_note + _mutation_hooks_note

-    automatic_optimization = False
+    trainer: pl.Trainer
+
+    @property
+    def automatic_optimization(self) -> bool:
+        return False

-    def default_mutation_hooks(self) -> List[MutationHook]:
+    def default_mutation_hooks(self) -> list[MutationHook]:
         """Override this to define class-default mutation hooks."""
         return [no_default_hook]

-    def mutate_kwargs(self) -> Dict[str, Any]:
+    def mutate_kwargs(self) -> dict[str, Any]:
         """Extra keyword arguments passed to mutation hooks. Usually algo-specific."""
         return {}

-    def __init__(self, base_model: pl.LightningModule, mutation_hooks: List[MutationHook] = None):
+    def __init__(self, model: pl.LightningModule, mutation_hooks: list[MutationHook] | None = None):
         super().__init__()
-        assert isinstance(base_model, pl.LightningModule)
-        self.model = base_model
+        assert isinstance(model, pl.LightningModule)
+        self.model = model

         # append the default hooks
         mutation_hooks = (mutation_hooks or []) + self.default_mutation_hooks()

         # traverse the model, calling hooks on every submodule
-        self.nas_modules: List[BaseSuperNetModule] = traverse_and_mutate_submodules(
+        self.nas_modules: list[BaseSuperNetModule] = traverse_and_mutate_submodules(
             self.model, mutation_hooks, self.mutate_kwargs(), topdown=True)

-    def search_space_spec(self) -> Dict[str, ParameterSpec]:
+    def search_space_spec(self) -> dict[str, ParameterSpec]:
         """Get the search space specification from ``nas_module``.

         Returns

@@ -236,7 +243,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
             result.update(module.search_space_spec())
         return result

-    def resample(self) -> Dict[str, Any]:
+    def resample(self) -> dict[str, Any]:
         """Trigger the resample for each ``nas_module``.

         Sometimes (e.g., in differentiable cases), it does nothing.

@@ -250,7 +257,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
             result.update(module.resample(memo=result))
         return result

-    def export(self) -> Dict[str, Any]:
+    def export(self) -> dict[str, Any]:
         """
         Export the NAS result, ideally the best choice of each ``nas_module``.
         You may implement an ``export`` method for your customized ``nas_module``.

@@ -291,12 +298,30 @@ class BaseOneShotLightningModule(pl.LightningModule):
             arc_optimizers = [arc_optimizers]
         self.arc_optim_count = len(arc_optimizers)

         # FIXME: this part uses non-official lightning API.
         # The return values ``frequency`` and ``monitor`` are ignored because lightning requires
         # ``len(optimizers) == len(frequency)``, and gradient backword is handled manually.
         # For data structure of variables below, please see pytorch lightning docs of ``configure_optimizers``.
-        w_optimizers, lr_schedulers, self.frequencies, monitor = \
-            self.trainer._configure_optimizers(self.model.configure_optimizers())
-        lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)
+        try:
+            # above v1.6
+            from pytorch_lightning.core.optimizer import (  # pylint: disable=import-error
+                _configure_optimizers,  # type: ignore
+                _configure_schedulers_automatic_opt,  # type: ignore
+                _configure_schedulers_manual_opt  # type: ignore
+            )
+            w_optimizers, lr_schedulers, self.frequencies, monitor = \
+                _configure_optimizers(self.model.configure_optimizers())  # type: ignore
+            lr_schedulers = (
+                _configure_schedulers_automatic_opt(lr_schedulers, monitor)
+                if self.automatic_optimization
+                else _configure_schedulers_manual_opt(lr_schedulers)
+            )
+        except ImportError:
+            # under v1.5
+            w_optimizers, lr_schedulers, self.frequencies, monitor = \
+                self.trainer._configure_optimizers(self.model.configure_optimizers())  # type: ignore
+            lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)

         if any(sch["scheduler"].optimizer not in w_optimizers for sch in lr_schedulers):
             raise Exception(
                 "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."

@@ -312,7 +337,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
         # redirect the access to trainer/log to this module
         # but note that we might be missing other attributes,
         # which could potentially be a problem
-        self.model.trainer = self.trainer
+        self.model.trainer = self.trainer  # type: ignore
         self.model.log = self.log
         return self.model.on_train_start()

@@ -359,7 +384,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
         Returns
         ----------
-        arc_optimizers : List[Optimizer], Optimizer
+        arc_optimizers : list[Optimizer], Optimizer
             Optimizers used by a specific NAS algorithm. Return None if no architecture optimizers are needed.
         """
         return None

@@ -376,9 +401,9 @@ class BaseOneShotLightningModule(pl.LightningModule):
         """
         def apply(lr_scheduler):
             # single scheduler is called every epoch
-            if isinstance(lr_scheduler, _LRScheduler) and \
-                    self.trainer.is_last_batch:
-                lr_schedulers.step()
+            if isinstance(lr_scheduler, _LRScheduler):
+                if self.trainer.is_last_batch:
+                    lr_scheduler.step()

             # lr_scheduler_config is called as configured
             elif isinstance(lr_scheduler, dict):
                 interval = lr_scheduler['interval']

@@ -392,7 +417,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
                     self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % frequency == 0):
-                lr_scheduler.step()
+                lr_scheduler['scheduler'].step()

         lr_schedulers = self.lr_schedulers()

@@ -402,7 +427,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
         else:
             apply(lr_schedulers)

-    def call_user_optimizers(self, method):
+    def call_weight_optimizers(self, method):
         """
         Function that imitates lightning trainer's behavior of calling user's optimizers. Since auto_optimization is turned off by this
         class, you can use this function to make user optimizers behave as they were automatically handled by the lightning trainer.

@@ -418,10 +443,12 @@ class BaseOneShotLightningModule(pl.LightningModule):
             elif method == 'zero_grad':
                 optimizer.zero_grad()

-        optimizers = self.user_optimizers
+        optimizers = self.weight_optimizers()
         if optimizers is None:
             return
+        assert isinstance(optimizers, list), 'Did you forget to set use_pl_optimizers to true?'

         if len(self.frequencies) > 0:
             self.cur_optimizer_step += 1
             if self.frequencies[self.cur_optimizer_index] == self.cur_optimizer_step:

@@ -434,14 +461,13 @@ class BaseOneShotLightningModule(pl.LightningModule):
         for optimizer in optimizers:
             apply_method(optimizer, method)

-    @property
-    def architecture_optimizers(self):
+    def architecture_optimizers(self) -> list[Optimizer] | Optimizer | None:
         """
         Get architecture optimizers from all optimizers. Use this to get your architecture optimizers in ``training_step``.

         Returns
         ----------
-        opts : List[Optimizer], Optimizer, None
+        opts : list[Optimizer], Optimizer, None
             Architecture optimizers defined in ``configure_architecture_optimizers``. This will be None if there is no
             architecture optimizers.
         """

@@ -450,28 +476,30 @@ class BaseOneShotLightningModule(pl.LightningModule):
             # pylint: disable=unsubscriptable-object
             arc_opts = opts[:self.arc_optim_count]
             if len(arc_opts) == 1:
-                arc_opts = arc_opts[0]
-            return arc_opts
+                return cast(Optimizer, arc_opts[0])
+            return cast(List[Optimizer], arc_opts)

+        # If there is only 1 optimizer and it is the architecture optimizer
         if self.arc_optim_count == 1:
-            return opts
+            return cast(Union[List[Optimizer], Optimizer], opts)
         return None

-    @property
-    def user_optimizers(self):
+    def weight_optimizers(self) -> list[Optimizer] | Optimizer | None:
         """
         Get user optimizers from all optimizers. Use this to get user optimizers in ``training_step``.

         Returns
         ----------
-        opts : List[Optimizer], Optimizer, None
+        opts : list[Optimizer], Optimizer, None
             Optimizers defined by user's model. This will be None if there is no user optimizers.
         """
         # Since use_pl_optimizer is set true (by default) here.
         # opts always return a list
         opts = self.optimizers()
         if isinstance(opts, list):
             # pylint: disable=unsubscriptable-object
-            return opts[self.arc_optim_count:]
+            return cast(List[Optimizer], opts[self.arc_optim_count:])

+        # FIXME: this case is actually not correctly handled
+        # If there is only 1 optimizer and no architecture optimizer
         if self.arc_optim_count == 0:
-            return opts
+            return cast(Union[List[Optimizer], Optimizer], opts)
         return None
nni/retiarii/oneshot/pytorch/darts.py

@@ -5,6 +5,7 @@
 import copy
 import logging
+import warnings
 from collections import OrderedDict

 import torch

@@ -111,6 +112,8 @@ class DartsTrainer(BaseOneShotTrainer):
                  learning_rate=2.5E-3, batch_size=64, workers=4,
                  device=None, log_frequency=None,
                  arc_learning_rate=3.0E-4, unrolled=False):
+        warnings.warn('DartsTrainer is deprecated. Please use strategy.DARTS instead.', DeprecationWarning)
+
         self.model = model
         self.loss = loss
         self.metrics = metrics
nni/retiarii/oneshot/pytorch/differentiable.py

@@ -3,9 +3,11 @@
 """Experimental version of differentiable one-shot implementation."""

-from typing import List
+from __future__ import annotations

 import pytorch_lightning as pl
 import torch
 import torch.optim as optim

 from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
 from .supermodule.differentiable import (

@@ -45,7 +47,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
         module_params=BaseOneShotLightningModule._inner_module_note,
     )

-    def default_mutation_hooks(self) -> List[MutationHook]:
+    def default_mutation_hooks(self) -> list[MutationHook]:
         """Replace modules with differentiable versions"""
         hooks = [
             DifferentiableMixedLayer.mutate,

@@ -62,14 +64,16 @@ class DartsLightningModule(BaseOneShotLightningModule):
     }

     def __init__(self, inner_module: pl.LightningModule,
-                 mutation_hooks: List[MutationHook] = None,
+                 mutation_hooks: list[MutationHook] | None = None,
                  arc_learning_rate: float = 3.0E-4):
         self.arc_learning_rate = arc_learning_rate
         super().__init__(inner_module, mutation_hooks=mutation_hooks)

     def training_step(self, batch, batch_idx):
         # grad manually
-        arc_optim = self.architecture_optimizers
+        arc_optim = self.architecture_optimizers()
+        if not isinstance(arc_optim, optim.Optimizer):
+            raise TypeError(f'Expect arc_optim to be a single Optimizer, but found: {arc_optim}')

         # The InterleavedTrainValDataLoader yields both train and val data in a batch
         trn_batch, val_batch = batch

@@ -88,12 +92,12 @@ class DartsLightningModule(BaseOneShotLightningModule):
         # phase 2: model step
         self.resample()
-        self.call_user_optimizers('zero_grad')
+        self.call_weight_optimizers('zero_grad')
         loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
         w_step_loss = loss_and_metrics['loss'] \
             if isinstance(loss_and_metrics, dict) else loss_and_metrics
         self.manual_backward(w_step_loss)
-        self.call_user_optimizers('step')
+        self.call_weight_optimizers('step')

         self.call_lr_schedulers(batch_idx)

@@ -107,7 +111,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
         # The alpha in DartsXXXChoices are the architecture parameters of DARTS. They share one optimizer.
         ctrl_params = []
         for m in self.nas_modules:
-            ctrl_params += list(m.parameters(arch=True))
+            ctrl_params += list(m.parameters(arch=True))  # type: ignore
         ctrl_optim = torch.optim.Adam(list(set(ctrl_params)), 3.e-4, betas=(0.5, 0.999),
                                       weight_decay=1.0E-3)
         return ctrl_optim

@@ -135,7 +139,7 @@ class ProxylessLightningModule(DartsLightningModule):
         module_params=BaseOneShotLightningModule._inner_module_note,
     )

-    def default_mutation_hooks(self) -> List[MutationHook]:
+    def default_mutation_hooks(self) -> list[MutationHook]:
         """Replace modules with gumbel-differentiable versions"""
         hooks = [
             ProxylessMixedLayer.mutate,

@@ -147,7 +151,7 @@ class ProxylessLightningModule(DartsLightningModule):
     def finalize_grad(self):
         for m in self.nas_modules:
-            m.finalize_grad()
+            m.finalize_grad()  # type: ignore


 class GumbelDartsLightningModule(DartsLightningModule):

@@ -177,7 +181,7 @@ class GumbelDartsLightningModule(DartsLightningModule):
         Learning rate for architecture optimizer. Default: 3.0e-4
     """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)

-    def default_mutation_hooks(self) -> List[MutationHook]:
+    def default_mutation_hooks(self) -> list[MutationHook]:
         """Replace modules with gumbel-differentiable versions"""
         hooks = [
             DifferentiableMixedLayer.mutate,

@@ -195,7 +199,7 @@ class GumbelDartsLightningModule(DartsLightningModule):
     }

     def __init__(self, inner_module,
-                 mutation_hooks: List[MutationHook] = None,
+                 mutation_hooks: list[MutationHook] | None = None,
                  arc_learning_rate: float = 3.0e-4,
                  gumbel_temperature: float = 1.,
                  use_temp_anneal: bool = False,

@@ -206,12 +210,13 @@ class GumbelDartsLightningModule(DartsLightningModule):
         self.use_temp_anneal = use_temp_anneal
         self.min_temp = min_temp

-    def on_epoch_start(self):
+    def on_train_epoch_end(self):
         if self.use_temp_anneal:
             self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
             self.temp = max(self.temp, self.min_temp)

         for module in self.nas_modules:
-            module._softmax.temp = self.temp
+            if hasattr(module, '_softmax'):
+                module._softmax.temp = self.temp  # type: ignore

-        return self.model.on_epoch_start()
+        return self.model.on_train_epoch_end()
nni/retiarii/oneshot/pytorch/enas.py

@@ -2,10 +2,14 @@
 # Licensed under the MIT license.

 import logging
+import warnings
+from typing import cast

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
+from torch.utils.data import SubsetRandomSampler, DataLoader

 from ..interface import BaseOneShotTrainer
 from .random import PathSamplingLayerChoice, PathSamplingInputChoice

@@ -113,9 +117,9 @@ class ReinforceController(nn.Module):
         self._h = [torch.zeros((1, self.lstm_size),
                                dtype=self._inputs.dtype,
                                device=self._inputs.device)
                    for _ in range(self.lstm_num_layers)]
-        self.sample_log_prob = 0
-        self.sample_entropy = 0
-        self.sample_skip_penalty = 0
+        self.sample_log_prob: torch.Tensor = cast(torch.Tensor, 0)
+        self.sample_entropy: torch.Tensor = cast(torch.Tensor, 0)
+        self.sample_skip_penalty: torch.Tensor = cast(torch.Tensor, 0)

     def _lstm_next_step(self):
         self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

@@ -143,7 +147,7 @@ class ReinforceController(nn.Module):
         if sampled.sum().item():
             self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
         else:
-            self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)
+            self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)  # type: ignore

         sampled = sampled.detach().cpu().numpy().tolist()
         self.sample_log_prob += self.entropy_reduction(log_prob)

@@ -205,6 +209,8 @@ class EnasTrainer(BaseOneShotTrainer):
                  batch_size=64, workers=4, device=None, log_frequency=None,
                  grad_clip=5., entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
                  ctrl_lr=0.00035, ctrl_steps_aggregate=20, ctrl_kwargs=None):
+        warnings.warn('EnasTrainer is deprecated. Please use strategy.ENAS instead.', DeprecationWarning)
+
         self.model = model
         self.loss = loss
         self.metrics = metrics

@@ -246,16 +252,16 @@ class EnasTrainer(BaseOneShotTrainer):
         n_train = len(self.dataset)
         split = n_train // 2
         indices = list(range(n_train))
-        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
-        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
-        self.train_loader = torch.utils.data.DataLoader(self.dataset,
-                                                        batch_size=self.batch_size,
-                                                        sampler=train_sampler,
-                                                        num_workers=self.workers)
-        self.valid_loader = torch.utils.data.DataLoader(self.dataset,
-                                                        batch_size=self.batch_size,
-                                                        sampler=valid_sampler,
-                                                        num_workers=self.workers)
+        train_sampler = SubsetRandomSampler(indices[:-split])
+        valid_sampler = SubsetRandomSampler(indices[-split:])
+        self.train_loader = DataLoader(self.dataset,
+                                       batch_size=self.batch_size,
+                                       sampler=train_sampler,
+                                       num_workers=self.workers)
+        self.valid_loader = DataLoader(self.dataset,
+                                       batch_size=self.batch_size,
+                                       sampler=valid_sampler,
+                                       num_workers=self.workers)

     def _train_model(self, epoch):
         self.model.train()

@@ -294,15 +300,15 @@ class EnasTrainer(BaseOneShotTrainer):
             metrics = self.metrics(logits, y)
             reward = self.reward_function(logits, y)
             if self.entropy_weight:
-                reward += self.entropy_weight * self.controller.sample_entropy.item()
+                reward += self.entropy_weight * self.controller.sample_entropy.item()  # type: ignore
             self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
             loss = self.controller.sample_log_prob * (reward - self.baseline)
             if self.skip_weight:
                 loss += self.skip_weight * self.controller.sample_skip_penalty
             metrics['reward'] = reward
             metrics['loss'] = loss.item()
-            metrics['ent'] = self.controller.sample_entropy.item()
-            metrics['log_prob'] = self.controller.sample_log_prob.item()
+            metrics['ent'] = self.controller.sample_entropy.item()  # type: ignore
+            metrics['log_prob'] = self.controller.sample_log_prob.item()  # type: ignore
             metrics['baseline'] = self.baseline
             metrics['skip'] = self.controller.sample_skip_penalty
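Note: the reward handling in `enas.py` above keeps an exponential moving average of the reward as the REINFORCE baseline (`baseline = baseline * decay + reward * (1 - decay)`). A small self-contained sketch of that update rule; the reward values are made up:

    def update_baseline(baseline: float, reward: float, decay: float = 0.999) -> float:
        # Exponential moving average of past rewards, used to reduce policy-gradient variance.
        return baseline * decay + reward * (1 - decay)

    baseline = 0.0
    for reward in [0.52, 0.61, 0.58]:
        baseline = update_baseline(baseline, reward)
        advantage = reward - baseline  # the policy-gradient loss scales log-probabilities by this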
nni/retiarii/oneshot/pytorch/proxyless.py

@@ -4,6 +4,7 @@
 # type: ignore

 import logging
+import warnings

 import torch
 import torch.nn as nn

@@ -230,6 +231,8 @@ class ProxylessTrainer(BaseOneShotTrainer):
                  grad_reg_loss_type=None, grad_reg_loss_params=None,
                  applied_hardware=None, dummy_input=(1, 3, 224, 224),
                  ref_latency=65.0):
+        warnings.warn('ProxylessTrainer is deprecated. Please use strategy.Proxyless instead.', DeprecationWarning)
+
         self.model = model
         self.loss = loss
         self.metrics = metrics
nni/retiarii/oneshot/pytorch/random.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+# type: ignore
+
 import logging
 import random
+import warnings

 import torch
 import torch.nn as nn

@@ -122,6 +125,8 @@ class SinglePathTrainer(BaseOneShotTrainer):
     def __init__(self, model, loss, metrics,
                  optimizer, num_epochs, dataset_train, dataset_valid,
                  batch_size=64, workers=4, device=None, log_frequency=None):
+        warnings.warn('SinglePathTrainer is deprecated. Please use strategy.RandomOneShot instead.', DeprecationWarning)
+
         self.model = model
         self.loss = loss
         self.metrics = metrics
nni/retiarii/oneshot/pytorch/sampling.py

@@ -3,7 +3,8 @@
 """Experimental version of sampling-based one-shot implementation."""

-from typing import Dict, Any, List
+from __future__ import annotations
+
+from typing import Any

 import pytorch_lightning as pl
 import torch

@@ -33,9 +34,11 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
     )

     # turn on automatic optimization because nothing interesting is going on here.
-    automatic_optimization = True
+    @property
+    def automatic_optimization(self) -> bool:
+        return True

-    def default_mutation_hooks(self) -> List[MutationHook]:
+    def default_mutation_hooks(self) -> list[MutationHook]:
         """Replace modules with differentiable versions"""
         hooks = [
             PathSamplingLayer.mutate,

@@ -80,6 +83,12 @@ class EnasLightningModule(RandomSamplingLightningModule):
         Number of steps that will be aggregated into one mini-batch for RL controller.
     ctrl_grad_clip : float
         Gradient clipping value of controller.
+    reward_metric_name : str or None
+        The name of the metric which is treated as reward.
+        This will be not effective when there's only one metric returned from evaluator.
+        If there are multiple, it will find the metric with key name ``reward_metric_name``,
+        which is "default" by default.
+        Otherwise it raises an exception indicating multiple metrics are found.
     """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)

     __doc__ = _enas_note.format(

@@ -87,23 +96,26 @@ class EnasLightningModule(RandomSamplingLightningModule):
         module_params=BaseOneShotLightningModule._inner_module_note,
     )

-    automatic_optimization = False
+    @property
+    def automatic_optimization(self) -> bool:
+        return False

     def __init__(self, inner_module: pl.LightningModule,
                  *,
-                 ctrl_kwargs: Dict[str, Any] = None,
+                 ctrl_kwargs: dict[str, Any] | None = None,
                  entropy_weight: float = 1e-4,
                  skip_weight: float = .8,
                  baseline_decay: float = .999,
                  ctrl_steps_aggregate: float = 20,
                  ctrl_grad_clip: float = 0,
-                 mutation_hooks: List[MutationHook] = None):
+                 reward_metric_name: str | None = None,
+                 mutation_hooks: list[MutationHook] | None = None):
         super().__init__(inner_module, mutation_hooks)

         # convert parameter spec to legacy ReinforceField
         # this part will be refactored
-        self.nas_fields: List[ReinforceField] = []
+        self.nas_fields: list[ReinforceField] = []
         for name, param_spec in self.search_space_spec().items():
             if param_spec.chosen_size not in (1, None):
                 raise ValueError('ENAS does not support n_chosen to be values other than 1 or None.')

@@ -116,6 +128,7 @@ class EnasLightningModule(RandomSamplingLightningModule):
         self.baseline = 0.
         self.ctrl_steps_aggregate = ctrl_steps_aggregate
         self.ctrl_grad_clip = ctrl_grad_clip
+        self.reward_metric_name = reward_metric_name

     def configure_architecture_optimizers(self):
         return optim.Adam(self.controller.parameters(), lr=3.5e-4)

@@ -127,34 +140,35 @@ class EnasLightningModule(RandomSamplingLightningModule):
         if source == 'train':
             # step 1: train model params
             self.resample()
-            self.call_user_optimizers('zero_grad')
+            self.call_weight_optimizers('zero_grad')
             loss_and_metrics = self.model.training_step(batch, batch_idx)
             w_step_loss = loss_and_metrics['loss'] \
                 if isinstance(loss_and_metrics, dict) else loss_and_metrics
             self.manual_backward(w_step_loss)
-            self.call_user_optimizers('step')
+            self.call_weight_optimizers('step')
             return loss_and_metrics

         if source == 'val':
             # step 2: train ENAS agent
-            x, y = batch
-            arc_opt = self.architecture_optimizers
+            arc_opt = self.architecture_optimizers()
+            if not isinstance(arc_opt, optim.Optimizer):
+                raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
             arc_opt.zero_grad()
             self.resample()
-            with torch.no_grad():
-                logits = self.model(x)
             self.model.validation_step(batch, batch_idx)

             # use the default metric of self.model as reward function
-            if len(self.model.metrics) == 1:
-                _, metric = next(iter(self.model.metrics.items()))
+            if len(self.trainer.callback_metrics) == 1:
+                _, metric = next(iter(self.trainer.callback_metrics.items()))
             else:
-                if 'default' not in self.model.metrics.keys():
-                    raise KeyError('model.metrics should contain a ``default`` key when'
-                                   'there are multiple metrics')
-                metric = self.model.metrics['default']
+                metric_name = self.reward_metric_name or 'default'
+                if metric_name not in self.trainer.callback_metrics:
+                    raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
+                                   f'found multiple metrics without default: {self.trainer.callback_metrics.keys()}')
+                metric = self.trainer.callback_metrics[metric_name]

-            reward = metric(logits, y)
+            reward: float = metric.item()
             if self.entropy_weight:
-                reward = reward + self.entropy_weight * self.controller.sample_entropy.item()
+                reward = reward + self.entropy_weight * self.controller.sample_entropy.item()  # type: ignore
             self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
             rnn_step_loss = self.controller.sample_log_prob * (reward - self.baseline)
             if self.skip_weight:

@@ -183,7 +197,7 @@ class EnasLightningModule(RandomSamplingLightningModule):
         with torch.no_grad():
             return self._interpret_controller_sampling_result(self.controller.resample())

-    def _interpret_controller_sampling_result(self, sample: Dict[str, int]) -> Dict[str, Any]:
+    def _interpret_controller_sampling_result(self, sample: dict[str, int]) -> dict[str, Any]:
         """Convert ``{label: index}`` to ``{label: name}``"""
         space_spec = self.search_space_spec()
         for key in list(sample.keys()):
nni/retiarii/oneshot/pytorch/strategy.py

@@ -10,8 +10,10 @@ For example, ``nni.retiarii.strategy.DartsStrategy`` (this requires pytorch to b
 When adding/modifying a new strategy in this file, don't forget to link it in strategy/oneshot.py.
 """

+from __future__ import annotations
+
 import warnings
-from typing import Any, List, Optional, Type, Union, Tuple
+from typing import Any, Type

 import torch.nn as nn
 from torch.utils.data import DataLoader

@@ -33,10 +35,10 @@ class OneShotStrategy(BaseStrategy):
         self.oneshot_module = oneshot_module
         self.oneshot_kwargs = kwargs

-        self.model: Optional[BaseOneShotLightningModule] = None
+        self.model: BaseOneShotLightningModule | None = None

-    def _get_dataloader(self, train_dataloader: DataLoader, val_dataloaders: DataLoader) \
-        -> Union[DataLoader, Tuple[DataLoader, DataLoader]]:
+    def _get_dataloader(self, train_dataloader: DataLoader, val_dataloaders: DataLoader | list[DataLoader]) \
+        -> DataLoader | tuple[DataLoader, DataLoader]:
         """
         One-shot strategy typically requires a customized dataloader.

@@ -51,9 +53,9 @@ class OneShotStrategy(BaseStrategy):
         _reason = 'The reason might be that you have used the wrong execution engine. Try to set engine to `oneshot` and try again.'

-        py_model: nn.Module = base_model.python_object
-        if not isinstance(py_model, nn.Module):
+        if not isinstance(base_model.python_object, nn.Module):
             raise TypeError('Model is not a nn.Module. ' + _reason)
+        py_model: nn.Module = base_model.python_object

         if applied_mutators:
             raise ValueError('Mutator is not empty. ' + _reason)

@@ -64,8 +66,10 @@ class OneShotStrategy(BaseStrategy):
         evaluator_module: LightningModule = base_model.evaluator.module
         evaluator_module.set_model(py_model)
-        self.model: BaseOneShotLightningModule = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
+
+        self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
         evaluator: Lightning = base_model.evaluator
+        if evaluator.train_dataloader is None or evaluator.val_dataloaders is None:
+            raise TypeError('Train or val dataloader is not set.')
         dataloader = self._get_dataloader(evaluator.train_dataloader, evaluator.val_dataloaders)
         if isinstance(dataloader, tuple):
             dataloader, val_loader = dataloader

@@ -73,7 +77,7 @@ class OneShotStrategy(BaseStrategy):
         else:
             evaluator.trainer.fit(self.model, dataloader)

-    def export_top_models(self, top_k: int = 1) -> List[Any]:
+    def export_top_models(self, top_k: int = 1) -> list[Any]:
         if self.model is None:
             raise RuntimeError('One-shot strategy needs to be run before export.')
         if top_k != 1:
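Note: the reordered check in `strategy.py` (the `isinstance` guard before the typed assignment) is the usual way to narrow a dynamically typed value for pyright without a cast. A standalone sketch of the same idea with hypothetical names:

    import torch.nn as nn

    def require_module(obj: object) -> nn.Module:
        if not isinstance(obj, nn.Module):
            raise TypeError('Model is not a nn.Module.')
        # After the isinstance guard, type checkers narrow obj to nn.Module,
        # so the annotated assignment below needs no cast.
        model: nn.Module = obj
        return model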
nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py

@@ -26,8 +26,10 @@ The fixed/weighted slice is fed into ``_slice_weight``,
 which interprets the slice and apply it on a tensor.
 """

+from __future__ import annotations
+
 import operator
-from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic, cast
+from typing import Callable, Iterator, TypeVar, Any, Optional, Tuple, Union, List, Dict, Generic, cast

 import numpy as np
 import torch

@@ -58,8 +60,8 @@ def _eliminate_list_slice(shape: tuple, slice_: multidim_slice) -> multidim_slic
     for i in range(len(slice_)):
         if isinstance(slice_[i], list):
             # convert list of slices to mask
-            mask = np.zeros(shape[i], dtype=np.bool)
-            for sl in slice_[i]:
+            mask = np.zeros(shape[i], dtype=np.bool)  # type: ignore
+            for sl in cast(List[slice], slice_[i]):
                 mask[sl] = 1
             result.append(mask)
         else:

@@ -67,7 +69,7 @@ def _eliminate_list_slice(shape: tuple, slice_: multidim_slice) -> multidim_slic
     return tuple(result)


-def _slice_weight(weight: T, slice_: Union[multidim_slice, List[Tuple[multidim_slice, float]]]) -> T:
+def _slice_weight(weight: T, slice_: multidim_slice | list[tuple[multidim_slice, float]]) -> T:
     # slice_ can be a tuple of slice, e.g., ([3:6], [2:4])
     # or tuple of slice -> float, e.g. {([3:6],): 0.6, ([2:4],): 0.3}

@@ -84,27 +86,27 @@ def _slice_weight(weight: T, slice_: Union[multidim_slice, List[Tuple[multidim_s
             # create a mask with weight w
             with torch.no_grad():
                 mask = zeros_like(weight)
-                mask[_eliminate_list_slice(weight.shape, sl)] = 1
+                mask[_eliminate_list_slice(weight.shape, sl)] = 1  # type: ignore

             # track gradients here
-            masks.append((mask * wt))
+            masks.append(mask * wt)  # type: ignore

         masks = sum(masks)

-        return masks * weight
+        return masks * weight  # type: ignore

     else:
         # for unweighted case, we slice it directly.
         def _do_slice(arr, slice_):
-            return arr[_eliminate_list_slice(arr.shape, slice_)]
+            return arr[_eliminate_list_slice(arr.shape, slice_)]  # type: ignore

         # sometimes, we don't need slice.
         # this saves an op on computational graph, which will hopefully make training faster
         # Use a dummy array to check this. Otherwise it would be too complex.
-        dummy_arr = np.zeros(weight.shape, dtype=np.bool)
-        no_effect = _do_slice(dummy_arr, slice_).shape == dummy_arr.shape
+        dummy_arr = np.zeros(weight.shape, dtype=np.bool)  # type: ignore
+        no_effect = cast(Any, _do_slice(dummy_arr, slice_)).shape == dummy_arr.shape

         if no_effect:
             return weight

@@ -128,14 +130,14 @@ class Slicable(Generic[T]):
             raise TypeError(f'Unsuppoted weight type: {type(weight)}')
         self.weight = weight

-    def __getitem__(self, index: Union[slice_type, multidim_slice]) -> T:
+    def __getitem__(self, index: slice_type | multidim_slice) -> T:
         if not isinstance(index, tuple):
             index = (index, )
         index = cast(multidim_slice, index)

         # Get the dict value in index's leafs
         # There can be at most one dict
-        leaf_dict: Optional[Dict[int, float]] = None
+        leaf_dict: dict[int, float] | None = None
         for maybe_weighted in _iterate_over_multidim_slice(index):
             for d in maybe_weighted.leaf_values():
                 if isinstance(d, dict):

@@ -166,10 +168,10 @@ class MaybeWeighted:
     """

-    def __init__(self, value: Optional[int_or_int_dict] = None, *,
-                 lhs: Optional[Union['MaybeWeighted', int]] = None,
-                 rhs: Optional[Union['MaybeWeighted', int]] = None,
-                 operation: Optional[Callable[[int, int], int]] = None):
+    def __init__(self, value: int_or_int_dict | None = None, *,
+                 lhs: 'MaybeWeighted' | int | None = None,
+                 rhs: 'MaybeWeighted' | int | None = None,
+                 operation: Callable[[int_or_int_dict, int_or_int_dict], int_or_int_dict] | None = None):
         if operation is None:
             if not isinstance(value, (int, dict)):
                 raise TypeError(f'Unsupported value type: {type(value)}')

@@ -178,7 +180,7 @@ class MaybeWeighted:
         self.rhs = rhs
         self.operation = operation

-    def leaf_values(self) -> Iterator[Dict[int, float]]:
+    def leaf_values(self) -> Iterator[int_or_int_dict]:
         """Iterate over values on leaf nodes."""
         if self.value is not None:
             yield self.value

@@ -188,7 +190,7 @@ class MaybeWeighted:
         if isinstance(self.rhs, MaybeWeighted):
             yield from self.rhs.leaf_values()

-    def evaluate(self, value_fn: _value_fn_type = None) -> int:
+    def evaluate(self, value_fn: _value_fn_type = None) -> int_or_int_dict:
         """Evaluate the value on root node, after replacing every value on leaf node with ``value_fn``.

         If ``value_fn`` is none, no replacement will happen and the raw value will be used.
         """

@@ -200,11 +202,12 @@ class MaybeWeighted:
         if isinstance(self.lhs, MaybeWeighted):
             eval_lhs = self.lhs.evaluate(value_fn)
         else:
-            eval_lhs = self.lhs
+            eval_lhs = cast(int, self.lhs)
         if isinstance(self.rhs, MaybeWeighted):
             eval_rhs = self.rhs.evaluate(value_fn)
         else:
-            eval_rhs = self.rhs
+            eval_rhs = cast(int, self.rhs)
+        assert self.operation is not None
         return self.operation(eval_lhs, eval_rhs)

     def __repr__(self):
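Note: `_slice_weight` above realizes a weighted slice by building a 0/1 mask per candidate slice, scaling each mask by its weight, summing them, and multiplying with the tensor. A small illustrative sketch of that idea on a plain tensor (this is not the library's API, just the technique):

    from __future__ import annotations

    import torch

    def weighted_slice(weight: torch.Tensor, slices: list[tuple[slice, float]]) -> torch.Tensor:
        # Blend several candidate slices of `weight` by masking instead of indexing,
        # so the result keeps the original shape and stays differentiable.
        total_mask = torch.zeros_like(weight)
        for sl, wt in slices:
            mask = torch.zeros_like(weight)
            mask[sl] = 1.0
            total_mask = total_mask + mask * wt
        return total_mask * weight

    w = torch.randn(6, 4)
    blended = weighted_slice(w, [(slice(0, 3), 0.7), (slice(0, 5), 0.3)])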
nni/retiarii/oneshot/pytorch/supermodule/_singlepathnas.py

@@ -2,6 +2,7 @@
 # Licensed under the MIT license.

 # pylint: skip-file
+# type: ignore

 """This file is an incomplete implementation of `Single-path NAS <https://arxiv.org/abs/1904.02877>`__.
 These are merely some components of the algorithm. The complete support is an undergoing work item.

@@ -96,9 +97,9 @@ class DifferentiableSuperConv2d(nn.Conv2d):
         ----------
         input_weight : Tensor
             the weight to be weighted summed
-        masks : List[Tensor]
+        masks : list[Tensor]
             weight masks.
-        thresholds : List[float]
+        thresholds : list[float]
             thresholds, should have a length of ``len(masks) - 1``
         indicator : Callable[[Tensor, float], float]
             take a tensor and a threshold as input, and output the weight
nni/retiarii/oneshot/pytorch/supermodule/_valuechoice_utils.py

@@ -4,19 +4,26 @@
 """Utilities to process the value choice compositions,
 in the way that is most convenient to one-shot algorithms."""

+from __future__ import annotations
+
 import itertools
-from typing import List, Any, Dict, Tuple, Optional, Union
+from typing import Any, TypeVar, List, cast

+import numpy as np
+import torch

 from nni.common.hpo_utils import ParameterSpec
-from nni.retiarii.nn.pytorch.api import ValueChoiceX
+from nni.retiarii.nn.pytorch.api import ChoiceOf, ValueChoiceX

 Choice = Any

+T = TypeVar('T')
+
 __all__ = ['dedup_inner_choices', 'evaluate_value_choice_with_dict', 'traverse_all_options']


-def dedup_inner_choices(value_choices: List[ValueChoiceX]) -> Dict[str, ParameterSpec]:
+def dedup_inner_choices(value_choices: list[ValueChoiceX]) -> dict[str, ParameterSpec]:
     """Find all leaf nodes in ``value_choices``,
     save them into in the format of ``{label: parameter_spec}``.
     """

@@ -33,7 +40,7 @@ def dedup_inner_choices(value_choices: List[ValueChoiceX]) -> Dict[str, Paramete
     return result


-def evaluate_value_choice_with_dict(value_choice: ValueChoiceX, chosen: Dict[str, Choice]) -> Any:
+def evaluate_value_choice_with_dict(value_choice: ChoiceOf[T], chosen: dict[str, Choice]) -> T:
     """To evaluate a composition of value-choice with a dict,
     with format of ``{label: chosen_value}``.

     The implementation is two-pass. We first get a list of values,

@@ -56,8 +63,10 @@ def evaluate_value_choice_with_dict(value_choice: ValueChoiceX, chosen: Dict[str
     return value_choice.evaluate(choice_inner_values)


-def traverse_all_options(value_choice: ValueChoiceX,
-                         weights: Optional[Dict[str, List[float]]] = None) -> List[Union[Tuple[Any, float], Any]]:
+def traverse_all_options(
+    value_choice: ChoiceOf[T],
+    weights: dict[str, list[float]] | dict[str, np.ndarray] | dict[str, torch.Tensor] | None = None
+) -> list[tuple[T, float]] | list[T]:
     """Traverse all possible computation outcome of a value choice.
     If ``weights`` is not None, it will also compute the probability of each possible outcome.

@@ -65,33 +74,33 @@ def traverse_all_options(value_choice: ValueChoiceX,
     ----------
     value_choice : ValueChoiceX
         The value choice to traverse.
-    weights : Optional[Dict[str, List[float]]], default = None
+    weights : Optional[dict[str, list[float]]], default = None
         If there's a prior on leaf nodes, and we intend to know the (joint) prior on results,
         weights can be provided. The key is label, value are list of float indicating probability.
         Normally, they should sum up to 1, but we will not check them in this function.

     Returns
     -------
-    List[Union[Tuple[Any, float], Any]]
+    list[Union[tuple[Any, float], Any]]
         Results will be sorted and duplicates will be eliminated.
         If weights is provided, the return value will be a list of tuple, with option and its weight.
         Otherwise, it will be a list of options.
     """
     # get a dict of {label: list of tuple of choice and weight}
-    leafs: Dict[str, List[Tuple[Choice, float]]] = {}
+    leafs: dict[str, list[tuple[T, float]]] = {}
     for label, param_spec in dedup_inner_choices([value_choice]).items():
         if weights is not None:
             if label not in weights:
                 raise KeyError(f'{value_choice} depends on a weight with key {label}, but not found in {weights}')
             if len(weights[label]) != param_spec.size:
                 raise KeyError(f'Expect weights with {label} to be of length {param_spec.size}, but {len(weights[label])} found')
-            leafs[label] = list(zip(param_spec.values, weights[label]))
+            leafs[label] = list(zip(param_spec.values, cast(List[float], weights[label])))
         else:
             # create a dummy weight of zero, in case that weights are not provided.
             leafs[label] = list(zip(param_spec.values, itertools.repeat(0., param_spec.size)))

     # result is a dict from a option to its weight
-    result: Dict[str, Optional[float]] = {}
+    result: dict[T, float | None] = {}
     labels, values = list(leafs.keys()), list(leafs.values())
     if not labels:

@@ -126,6 +135,6 @@ def traverse_all_options(value_choice: ValueChoiceX,
         result[eval_res] = chosen_weight

     if weights is None:
-        return sorted(result.keys())
+        return sorted(result.keys())  # type: ignore
     else:
-        return sorted(result.items())
+        return sorted(result.items())  # type: ignore
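Note: `traverse_all_options` enumerates every outcome of a composed value choice and, when per-label weights are given, the joint probability of each outcome. A standalone sketch of that enumeration with `itertools.product` (an illustration of the idea, not the nni implementation):

    from __future__ import annotations

    import itertools
    from typing import Callable

    def all_options(leafs: dict[str, list[tuple[int, float]]],
                    evaluate: Callable[[dict[str, int]], int]) -> dict[int, float]:
        # Enumerate joint assignments of the leaf choices and accumulate their probabilities.
        labels = list(leafs)
        result: dict[int, float] = {}
        for combo in itertools.product(*leafs.values()):
            chosen = {label: value for label, (value, _) in zip(labels, combo)}
            prob = 1.0
            for _, weight in combo:
                prob *= weight
            outcome = evaluate(chosen)
            result[outcome] = result.get(outcome, 0.0) + prob
        return result

    # e.g. a value choice that computes channels = a * b
    print(all_options({'a': [(2, 0.5), (4, 0.5)], 'b': [(3, 0.2), (5, 0.8)]},
                      evaluate=lambda chosen: chosen['a'] * chosen['b']))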
nni/retiarii/oneshot/pytorch/supermodule/base.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from typing import Any, Dict, Tuple, Union
+from __future__ import annotations
+
+from typing import Any

 import torch.nn as nn

@@ -24,13 +26,13 @@ class BaseSuperNetModule(nn.Module):
     rather than their compositions.
     """

-    def resample(self, memo: Dict[str, Any]) -> Dict[str, Any]:
+    def resample(self, memo: dict[str, Any]) -> dict[str, Any]:
         """
         Resample the super-net module.

         Parameters
         ----------
-        memo : Dict[str, Any]
+        memo : dict[str, Any]
             Used to ensure the consistency of samples with the same label.

         Returns

@@ -40,19 +42,19 @@ class BaseSuperNetModule(nn.Module):
         """
         raise NotImplementedError()

-    def export(self, memo: Dict[str, Any]) -> Dict[str, Any]:
+    def export(self, memo: dict[str, Any]) -> dict[str, Any]:
         """
         Export the final architecture within this module.
         It should have the same keys as ``search_space_spec()``.

         Parameters
         ----------
-        memo : Dict[str, Any]
+        memo : dict[str, Any]
             Use memo to avoid the same label gets exported multiple times.
         """
         raise NotImplementedError()

-    def search_space_spec(self) -> Dict[str, ParameterSpec]:
+    def search_space_spec(self) -> dict[str, ParameterSpec]:
         """
         Space specification (sample points).
         Mapping from spec name to ParameterSpec. The names in choices should be in the same format of export.

@@ -64,8 +66,8 @@ class BaseSuperNetModule(nn.Module):
         raise NotImplementedError()

     @classmethod
-    def mutate(cls, module: nn.Module, name: str, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> \
-            Union['BaseSuperNetModule', bool, Tuple['BaseSuperNetModule', bool]]:
+    def mutate(cls, module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> \
+            'BaseSuperNetModule' | bool | tuple['BaseSuperNetModule', bool]:
         """This is a mutation hook that creates a :class:`BaseSuperNetModule`.
         The method should be implemented in each specific super-net module,
         because they usually have specific rules about what kind of modules to operate on.

@@ -84,7 +86,7 @@ class BaseSuperNetModule(nn.Module):
         Returns
         -------
-        Union[BaseSuperNetModule, bool, Tuple[BaseSuperNetModule, bool]]
+        Union[BaseSuperNetModule, bool, tuple[BaseSuperNetModule, bool]]
             The mutation result, along with an optional boolean flag indicating whether to suppress follow-up mutation hooks.
             See :class:`nni.retiarii.oneshot.pytorch.base.BaseOneShotLightningModule` for details.
         """
nni/retiarii/oneshot/pytorch/supermodule/differentiable.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
 import functools
 import warnings
-from typing import List, Tuple, Optional, Dict, Any, Union
+from typing import Any, cast

 import torch
 import torch.nn as nn

@@ -21,7 +23,9 @@ from ._valuechoice_utils import traverse_all_options
 class GumbelSoftmax(nn.Softmax):
     """Wrapper of ``F.gumbel_softmax``. dim = -1 by default."""

-    def __init__(self, dim: Optional[int] = -1) -> None:
+    dim: int
+
+    def __init__(self, dim: int = -1) -> None:
         super().__init__(dim)
         self.tau = 1
         self.hard = False

@@ -42,7 +46,7 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
     Parameters
     ----------
-    paths : List[Tuple[str, nn.Module]]
+    paths : list[tuple[str, nn.Module]]
         Layers to choose from. Each is a tuple of name, and its module.
     alpha : Tensor
         Tensor that stores the "learnable" weights.

@@ -59,9 +63,9 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
         Name of the choice.
     """

-    _arch_parameter_names: List[str] = ['_arch_alpha']
+    _arch_parameter_names: list[str] = ['_arch_alpha']

-    def __init__(self, paths: List[Tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
+    def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
         super().__init__()
         self.op_names = []
         if len(alpha) != len(paths):

@@ -82,7 +86,7 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
         """Choose the operator with the maximum logit."""
         if self.label in memo:
             return {}  # nothing new to export
-        return {self.label: self.op_names[torch.argmax(self._arch_alpha).item()]}
+        return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}

     def search_space_spec(self):
         return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),

@@ -149,9 +153,9 @@ class DifferentiableMixedInput(BaseSuperNetModule):
         Name of the choice.
     """

-    _arch_parameter_names: List[str] = ['_arch_alpha']
+    _arch_parameter_names: list[str] = ['_arch_alpha']

-    def __init__(self, n_candidates: int, n_chosen: Optional[int], alpha: torch.Tensor, softmax: nn.Module, label: str):
+    def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
         super().__init__()
         self.n_candidates = n_candidates
         if len(alpha) != n_candidates:

@@ -240,9 +244,9 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
         won't be optimized.
     """

-    _arch_parameter_names: List[str] = ['_arch_alpha']
+    _arch_parameter_names: list[str] = ['_arch_alpha']

-    def __init__(self, operation: MixedOperation, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> None:
+    def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
         # Sampling arguments. This should have the same keys with `operation.mutable_arguments`
         operation._arch_alpha = nn.ParameterDict()
         for name, spec in operation.search_space_spec().items():

@@ -254,20 +258,20 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
             alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
             operation._arch_alpha[name] = alpha

-        operation.parameters = functools.partial(self.parameters, self=operation)        # bind self
-        operation.named_parameters = functools.partial(self.named_parameters, self=operation)
+        operation.parameters = functools.partial(self.parameters, module=operation)        # bind self
+        operation.named_parameters = functools.partial(self.named_parameters, module=operation)
         operation._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))

     @staticmethod
-    def parameters(self, *args, **kwargs):
-        for _, p in self.named_parameters(*args, **kwargs):
+    def parameters(module, *args, **kwargs):
+        for _, p in module.named_parameters(*args, **kwargs):
             yield p

     @staticmethod
-    def named_parameters(self, *args, **kwargs):
+    def named_parameters(module, *args, **kwargs):
         arch = kwargs.pop('arch', False)
-        for name, p in super(self.__class__, self).named_parameters(*args, **kwargs):  # pylint: disable=bad-super-call
+        for name, p in super(module.__class__, module).named_parameters(*args, **kwargs):  # pylint: disable=bad-super-call
             if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
                 if arch:
                     yield name, p

@@ -275,22 +279,24 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
             if not arch:
                 yield name, p

-    def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
+    def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
         """Differentiable. Do nothing in resample."""
         return {}

-    def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
+    def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
         """Export is also random for each leaf value choice."""
         result = {}
         for name, spec in operation.search_space_spec().items():
             if name in result:
                 continue
-            chosen_index = torch.argmax(operation._arch_alpha[name]).item()
+            chosen_index = int(torch.argmax(cast(dict, operation._arch_alpha)[name]).item())
             result[name] = spec.values[chosen_index]
         return result

-    def forward_argument(self, operation: MixedOperation, name: str) -> Union[Dict[Any, float], Any]:
+    def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any:
         if name in operation.mutable_arguments:
-            weights = {label: operation._softmax(alpha) for label, alpha in operation._arch_alpha.items()}
+            weights: dict[str, torch.Tensor] = {
+                label: cast(nn.Module, operation._softmax)(alpha) for label, alpha in cast(dict, operation._arch_alpha).items()
+            }
             return dict(traverse_all_options(operation.mutable_arguments[name], weights=weights))
         return operation.init_arguments[name]
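Note: the policy above patches `parameters`/`named_parameters` onto the operation with `functools.partial`, now passing the target as a keyword named `module` instead of `self`, so the binding cannot collide with an explicit `self` argument. A tiny standalone sketch of this binding trick (the classes are made up):

    import functools

    class Policy:
        @staticmethod
        def greet(module, *args, **kwargs):
            return f'hello from {module.name}'

    class Operation:
        def __init__(self, name: str) -> None:
            self.name = name

    op = Operation('conv')
    # Bind the static function to one instance; op.greet() then needs no arguments.
    op.greet = functools.partial(Policy.greet, module=op)
    print(op.greet())  # -> 'hello from conv'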
nni/retiarii/oneshot/pytorch/supermodule/operation.py
View file @
05c7d6e9
...
...
@@ -6,9 +6,11 @@ Operations that support weight sharing at a fine-grained level,
which is commonly known as super-kernel (as in channel search), or weight entanglement.
"""
from
__future__
import
annotations
import
inspect
import
itertools
from
typing
import
Union
,
Tuple
,
Dict
,
List
,
Any
,
Type
,
Optional
,
TypeVar
,
cast
from
typing
import
Any
,
Type
,
TypeVar
,
cast
,
Union
,
Tuple
import
torch
import
torch.nn
as
nn
...
...
@@ -37,7 +39,7 @@ class MixedOperationSamplingPolicy:
One SamplingStrategy corresponds to one mixed operation.
"""
def
__init__
(
self
,
operation
:
'MixedOperation'
,
memo
:
D
ict
[
str
,
Any
],
mutate_kwargs
:
D
ict
[
str
,
Any
])
->
None
:
def
__init__
(
self
,
operation
:
'MixedOperation'
,
memo
:
d
ict
[
str
,
Any
],
mutate_kwargs
:
d
ict
[
str
,
Any
])
->
None
:
"""At init, the sampling policy can prepare basic parameters,
and store them in operation if they need back propagation.
...
...
@@ -47,11 +49,11 @@ class MixedOperationSamplingPolicy:
"""
pass
def
resample
(
self
,
operation
:
'MixedOperation'
,
memo
:
D
ict
[
str
,
Any
])
->
D
ict
[
str
,
Any
]:
def
resample
(
self
,
operation
:
'MixedOperation'
,
memo
:
d
ict
[
str
,
Any
])
->
d
ict
[
str
,
Any
]:
"""The handler of :meth:`MixedOperation.resample`."""
raise
NotImplementedError
()
def
export
(
self
,
operation
:
'MixedOperation'
,
memo
:
D
ict
[
str
,
Any
])
->
D
ict
[
str
,
Any
]:
def
export
(
self
,
operation
:
'MixedOperation'
,
memo
:
d
ict
[
str
,
Any
])
->
d
ict
[
str
,
Any
]:
"""The handler of :meth:`MixedOperation.export`."""
raise
NotImplementedError
()
...
...
@@ -90,7 +92,7 @@ class MixedOperation(BaseSuperNetModule):
"""
bound_type
:
Type
[
nn
.
Module
]
# defined in subclass
argument_list
:
L
ist
[
str
]
# defined in subclass
argument_list
:
l
ist
[
str
]
# defined in subclass
sampling_policy
:
MixedOperationSamplingPolicy
...
...
@@ -114,11 +116,11 @@ class MixedOperation(BaseSuperNetModule):
appended by forward arguments in the ``bound_type``."""
raise
NotImplementedError
()
def
__init__
(
self
,
module_kwargs
:
D
ict
[
str
,
Any
])
->
None
:
def
__init__
(
self
,
module_kwargs
:
d
ict
[
str
,
Any
])
->
None
:
# Concerned arguments
self
.
mutable_arguments
:
D
ict
[
str
,
ValueChoiceX
]
=
{}
self
.
mutable_arguments
:
d
ict
[
str
,
ValueChoiceX
]
=
{}
# Useful when retrieving arguments without ValueChoice
self
.
init_arguments
:
D
ict
[
str
,
Any
]
=
{
**
module_kwargs
}
self
.
init_arguments
:
d
ict
[
str
,
Any
]
=
{
**
module_kwargs
}
self
.
_fill_missing_init_arguments
()
# get init default
...
@@ -134,7 +136,7 @@ class MixedOperation(BaseSuperNetModule):
                super_init_kwargs[key] = value

        # get all inner leaf value choices
-        self._space_spec: Dict[str, ParameterSpec] = dedup_inner_choices(self.mutable_arguments.values())
+        self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices(list(self.mutable_arguments.values()))

        super().__init__(**super_init_kwargs)
...
@@ -156,17 +158,17 @@ class MixedOperation(BaseSuperNetModule):
        """Find value choice in module's arguments and replace the whole module"""
        has_valuechoice = False
        if isinstance(module, cls.bound_type) and is_traceable(module):
-            for arg in itertools.chain(module.trace_args, module.trace_kwargs.values()):
+            for arg in itertools.chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()):
                if isinstance(arg, ValueChoiceX):
                    has_valuechoice = True

        if has_valuechoice:
            if module.trace_args:
                raise ValueError('ValueChoice on class arguments cannot appear together with ``trace_args``. '
                                 'Please enable ``kw_only`` on nni.trace.')
            # save type and kwargs
-            mixed_op = cls(module.trace_kwargs)
+            mixed_op = cls(cast(dict, module.trace_kwargs))

            if 'mixed_op_sampling' not in mutate_kwargs:
                raise ValueError('Need to sampling policy of mixed op, but not found in `mutate_kwargs`.')
...
@@ -229,15 +231,15 @@ class MixedLinear(MixedOperation, nn.Linear):
                          out_features: int_or_int_dict,
                          inputs: torch.Tensor) -> torch.Tensor:

-        in_features = _W(in_features)
-        out_features = _W(out_features)
+        in_features_ = _W(in_features)
+        out_features_ = _W(out_features)

-        weight = _S(self.weight)[:out_features]
-        weight = _S(weight)[:, :in_features]
+        weight = _S(self.weight)[:out_features_]
+        weight = _S(weight)[:, :in_features_]
        if self.bias is None:
            bias = self.bias
        else:
-            bias = _S(self.bias)[:out_features]
+            bias = _S(self.bias)[:out_features_]

        return F.linear(inputs, weight, bias)
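MixedLinear above illustrates the weight-entanglement trick: only the largest weight matrix is stored, and a candidate (in_features, out_features) pair is served by slicing a prefix of it. Ignoring the weighted-sum case handled by _W/_S, a bare-bones version of the same slicing could look like the sketch below; the class name SlicedLinear and the forward_with_size helper are invented here.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SlicedLinear(nn.Linear):
        """Holds the largest weight once; smaller candidates reuse its top-left corner."""

        def forward_with_size(self, x: torch.Tensor, in_features: int, out_features: int) -> torch.Tensor:
            weight = self.weight[:out_features, :in_features]   # slice rows, then columns
            bias = self.bias[:out_features] if self.bias is not None else None
            return F.linear(x, weight, bias)

    layer = SlicedLinear(64, 128)                 # super-kernel sized to the largest choice
    out = layer.forward_with_size(torch.randn(2, 32), in_features=32, out_features=100)
    print(out.shape)                              # torch.Size([2, 100])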
...
@@ -278,7 +280,7 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
    ]

    @staticmethod
-    def _to_tuple(value: scalar_or_scalar_dict[T]) -> Tuple[T, T]:
+    def _to_tuple(value: scalar_or_scalar_dict[Any]) -> tuple[Any, Any]:
        if not isinstance(value, tuple):
            return (value, value)
        return value
...
@@ -318,33 +320,37 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
        if any(isinstance(arg, dict) for arg in [stride, dilation, groups]):
            raise ValueError('stride, dilation, groups does not support weighted sampling.')

-        in_channels = _W(in_channels)
-        out_channels = _W(out_channels)
+        in_channels_ = _W(in_channels)
+        out_channels_ = _W(out_channels)

        # slice prefix
        # For groups > 1, we use groups to slice input weights
-        weight = _S(self.weight)[:out_channels]
-        weight = _S(weight)[:, :in_channels // groups]
+        weight = _S(self.weight)[:out_channels_]
+        weight = _S(weight)[:, :in_channels_ // groups]

        # slice center
        if isinstance(kernel_size, dict):
            # If kernel size is a dict, ignore choices in padding.
            if isinstance(self.padding, str):
                raise ValueError(f'Use "{self.padding}" in padding is not supported.')
            padding = self.padding  # max padding, must be a tuple

        kernel_a, kernel_b = self._to_tuple(kernel_size)
-        kernel_a, kernel_b = _W(kernel_a), _W(kernel_b)
+        kernel_a_, kernel_b_ = _W(kernel_a), _W(kernel_b)
        max_kernel_a, max_kernel_b = self.kernel_size  # self.kernel_size must be a tuple
-        kernel_a_left, kernel_b_top = (max_kernel_a - kernel_a) // 2, (max_kernel_b - kernel_b) // 2
-        weight = _S(weight)[:, :, kernel_a_left:kernel_a_left + kernel_a, kernel_b_top:kernel_b_top + kernel_b]
+        kernel_a_left, kernel_b_top = (max_kernel_a - kernel_a_) // 2, (max_kernel_b - kernel_b_) // 2
+        weight = _S(weight)[:, :, kernel_a_left:kernel_a_left + kernel_a_, kernel_b_top:kernel_b_top + kernel_b_]

-        bias = _S(self.bias)[:out_channels] if self.bias is not None else None
+        bias = _S(self.bias)[:out_channels_] if self.bias is not None else None

        # The rest parameters only need to be converted to tuple
-        stride = self._to_tuple(stride)
-        dilation = self._to_tuple(dilation)
+        stride_ = self._to_tuple(stride)
+        dilation_ = self._to_tuple(dilation)

        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(inputs, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            weight, bias, stride, (0, 0), dilation, groups)
-        return F.conv2d(inputs, weight, bias, stride, padding, dilation, groups)
+                            weight, bias, stride_, (0, 0), dilation_, groups)
+        return F.conv2d(inputs, weight, bias, stride_, cast('int | tuple', padding), dilation_, groups)


class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
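The convolution case adds one twist to the slicing scheme: channel choices are sliced from the prefix of the weight, while kernel-size choices are sliced from the center of the largest kernel, so a 3x3 candidate reuses the middle of the 7x7 parameters. A standalone sketch of the center slice, with the helper name center_slice and the shapes made up for illustration:

    import torch

    # Super-kernel conv weight: (out_channels, in_channels, 7, 7) is the largest choice.
    weight = torch.randn(16, 8, 7, 7)

    def center_slice(weight: torch.Tensor, kernel: int) -> torch.Tensor:
        """Take a kernel x kernel window from the center of the largest kernel."""
        max_k = weight.shape[-1]
        start = (max_k - kernel) // 2
        return weight[:, :, start:start + kernel, start:start + kernel]

    print(center_slice(weight, 3).shape)   # torch.Size([16, 8, 3, 3]) -- shares the 7x7 parameters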
...
@@ -388,13 +394,15 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
        if num_features < self.num_features:
            weight = weight[:num_features]
            bias = bias[:num_features]
-            running_mean = running_mean[:num_features]
-            running_var = running_var[:num_features]
+            if running_mean is not None:
+                running_mean = running_mean[:num_features]
+            if running_var is not None:
+                running_var = running_var[:num_features]

        if self.training:
            bn_training = True
        else:
-            bn_training = (self.running_mean is None) and (self.running_var is None)
+            bn_training = (running_mean is None) and (running_var is None)

        return F.batch_norm(inputs,
...
@@ -473,7 +481,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        return max(traverse_all_options(value_choice))

-    def _to_proj_slice(self, embed_dim: _W) -> List[slice]:
+    def _to_proj_slice(self, embed_dim: _W) -> list[slice]:
        # slice three parts, corresponding to q, k, v respectively
        return [slice(embed_dim),
...
@@ -484,12 +492,12 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
    def forward_with_args(
        self,
        embed_dim: int_or_int_dict, num_heads: int,
-        kdim: Optional[int_or_int_dict], vdim: Optional[int_or_int_dict],
+        kdim: int_or_int_dict | None, vdim: int_or_int_dict | None,
        dropout: float,
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
-        key_padding_mask: Optional[torch.Tensor] = None,
-        need_weights: bool = True, attn_mask: Optional[torch.Tensor] = None
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        key_padding_mask: torch.Tensor | None = None,
+        need_weights: bool = True, attn_mask: torch.Tensor | None = None
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:

        if any(isinstance(arg, dict) for arg in [num_heads, dropout]):
            raise ValueError('num_heads, dropout do not support weighted sampling.')
...
@@ -511,26 +519,26 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
        else:
            used_embed_dim = embed_dim

-        embed_dim = _W(embed_dim)
+        embed_dim_ = _W(embed_dim)

        # in projection weights & biases has q, k, v weights concatenated together
-        in_proj_bias: Optional[Tensor] = None
-        in_proj_weight: Optional[Tensor] = None
+        in_proj_bias: Tensor | None = None
+        in_proj_weight: Tensor | None = None
        if self.in_proj_bias is not None:
-            in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim)]
+            in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim_)]
        if self.in_proj_weight is not None:
-            in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim), :embed_dim]
+            in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim_), :embed_dim_]

-        bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim] if self.bias_k is not None else None
-        bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim] if self.bias_v is not None else None
-        out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim, :embed_dim]
-        out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim] if self.out_proj.bias is not None else None
+        bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim_] if self.bias_k is not None else None
+        bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim_] if self.bias_v is not None else None
+        out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim_, :embed_dim_]
+        out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim_] if self.out_proj.bias is not None else None

        if not qkv_same_embed_dim:
-            q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim, :embed_dim]
-            k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim]
+            q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim_, :embed_dim_]
+            k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim_]
            k_proj = _S(k_proj)[:, :_W(kdim)]
-            v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim]
+            v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim_]
            v_proj = _S(v_proj)[:, :_W(vdim)]

            # The rest part is basically same as pytorch
...
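_to_proj_slice exists because nn.MultiheadAttention stores the q, k and v projections stacked in a single in_proj_weight of shape (3 * embed_dim, embed_dim); a smaller sampled embed_dim therefore needs a prefix out of each of the three row blocks, not a single prefix of the whole matrix. A plain-tensor sketch of that slicing, with shapes and variable names invented here:

    import torch

    max_embed = 8      # largest embed_dim in the search space
    embed_dim = 6      # currently sampled embed_dim

    # MultiheadAttention packs q, k, v projections into one (3 * max_embed, max_embed) weight.
    in_proj_weight = torch.randn(3 * max_embed, max_embed)

    # Take the first `embed_dim` rows of each of the three blocks, then the first `embed_dim` columns.
    rows = torch.cat([in_proj_weight[i * max_embed : i * max_embed + embed_dim] for i in range(3)])
    sliced = rows[:, :embed_dim]
    print(sliced.shape)    # torch.Size([18, 6]) -- 3 * embed_dim rows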
@@ -560,7 +568,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
        return attn_output, attn_output_weights


-NATIVE_MIXED_OPERATIONS: List[Type[MixedOperation]] = [
+NATIVE_MIXED_OPERATIONS: list[Type[MixedOperation]] = [
    MixedLinear,
    MixedConv2d,
    MixedBatchNorm2d,
...
nni/retiarii/oneshot/pytorch/supermodule/proxyless.py
View file @ 05c7d6e9
...
@@ -9,7 +9,9 @@ The support remains limited. Known limitations include:
- The code contains duplicates. Needs refactor.
"""

-from typing import List, Tuple, Optional, cast
+from __future__ import annotations
+
+from typing import cast

import torch
import torch.nn as nn
...
@@ -48,13 +50,13 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
    _arch_parameter_names = ['_arch_alpha', '_binary_gates']

-    def __init__(self, paths: List[Tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
+    def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__(paths, alpha, softmax, label)
        self._binary_gates = nn.Parameter(torch.randn(len(paths)) * 1E-3)

        # like sampling-based methods, it has a ``_sampled``.
-        self._sampled: Optional[str] = None
-        self._sample_idx: Optional[int] = None
+        self._sampled: str | None = None
+        self._sample_idx: int | None = None

    def forward(self, *args, **kwargs):
        def run_function(ops, active_id, **kwargs):
...
@@ -130,10 +132,10 @@ class ProxylessMixedInput(DifferentiableMixedInput):
    _arch_parameter_names = ['_arch_alpha', '_binary_gates']

-    def __init__(self, n_candidates: int, n_chosen: Optional[int], alpha: torch.Tensor, softmax: nn.Module, label: str):
+    def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__(n_candidates, n_chosen, alpha, softmax, label)
        self._binary_gates = nn.Parameter(torch.randn(n_candidates) * 1E-3)
-        self._sampled: Optional[int] = None
+        self._sampled: int | None = None

    def forward(self, inputs):
        def run_function(active_sample):
...
nni/retiarii/oneshot/pytorch/supermodule/sampling.py
View file @ 05c7d6e9

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

+from __future__ import annotations
+
import random
-from typing import Optional, List, Tuple, Union, Dict, Any
+from typing import Any

import torch
import torch.nn as nn
...
@@ -28,14 +30,14 @@ class PathSamplingLayer(BaseSuperNetModule):
        Name of the choice.
    """

-    def __init__(self, paths: List[Tuple[str, nn.Module]], label: str):
+    def __init__(self, paths: list[tuple[str, nn.Module]], label: str):
        super().__init__()
        self.op_names = []
        for name, module in paths:
            self.add_module(name, module)
            self.op_names.append(name)
        assert self.op_names, 'There has to be at least one op to choose from.'
-        self._sampled: Optional[Union[List[str], str]] = None  # sampled can be either a list of indices or an index
+        self._sampled: list[str] | str | None = None  # sampled can be either a list of indices or an index
        self.label = label

    def resample(self, memo):
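PathSamplingLayer is the sampling-based counterpart of the differentiable layers: every candidate is registered as a submodule, resample picks a name, and forward only runs the chosen submodule. A compact standalone sketch of that pattern follows; the class name TinyPathSampler and the candidate ops are made up, and the real class additionally handles multi-path sampling and memoisation across shared labels.

    from __future__ import annotations

    import random
    import torch
    import torch.nn as nn

    class TinyPathSampler(nn.Module):
        """Register every candidate once; run only the currently sampled one."""

        def __init__(self, paths: list[tuple[str, nn.Module]]):
            super().__init__()
            self.op_names = []
            for name, module in paths:
                self.add_module(name, module)
                self.op_names.append(name)
            self._sampled: str | None = None

        def resample(self) -> str:
            self._sampled = random.choice(self.op_names)
            return self._sampled

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            assert self._sampled is not None, 'Call resample() before forward().'
            return getattr(self, self._sampled)(x)

    layer = TinyPathSampler([('conv3', nn.Conv2d(3, 8, 3, padding=1)),
                             ('conv5', nn.Conv2d(3, 8, 5, padding=2))])
    layer.resample()
    print(layer(torch.randn(1, 3, 16, 16)).shape)   # torch.Size([1, 8, 16, 16])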
...
@@ -89,7 +91,7 @@ class PathSamplingInput(BaseSuperNetModule):
        self.n_candidates = n_candidates
        self.n_chosen = n_chosen
        self.reduction = reduction
-        self._sampled: Optional[Union[List[int], int]] = None
+        self._sampled: list[int] | int | None = None
        self.label = label

    def _random_choose_n(self):
...
@@ -159,11 +161,11 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
    We sample the leaf nodes, and composits them into the values on arguments.
    """

-    def __init__(self, operation: MixedOperation, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> None:
+    def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
        # Sampling arguments. This should have the same keys with `operation.mutable_arguments`
-        self._sampled: Optional[Dict[str, Any]] = None
+        self._sampled: dict[str, Any] | None = None

-    def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
+    def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
        """Random sample for each leaf value choice."""
        result = {}
        space_spec = operation.search_space_spec()
...
@@ -181,7 +183,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
        return result

-    def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
+    def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
        """Export is also random for each leaf value choice."""
        result = {}
        space_spec = operation.search_space_spec()
...
nni/retiarii/oneshot/pytorch/utils.py
View file @ 05c7d6e9

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

+from __future__ import annotations
+
import logging
from collections import OrderedDict
+from typing import cast

import numpy as np
import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset

import nni.retiarii.nn.pytorch as nn
from nni.nas.pytorch.mutables import InputChoice, LayerChoice
...
@@ -155,7 +159,7 @@ def replace_layer_choice(root_module, init_fn, modules=None):
    Returns
    -------
-    List[Tuple[str, nn.Module]]
+    list[tuple[str, nn.Module]]
        A list from layer choice keys (names) and replaced modules.
    """
    return _replace_module_with_type(root_module, init_fn, (LayerChoice, nn.LayerChoice), modules)
...
@@ -176,7 +180,7 @@ def replace_input_choice(root_module, init_fn, modules=None):
    Returns
    -------
-    List[Tuple[str, nn.Module]]
+    list[tuple[str, nn.Module]]
        A list from layer choice keys (names) and replaced modules.
    """
    return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules)
...
@@ -200,15 +204,19 @@ class InterleavedTrainValDataLoader(DataLoader):
    Example
    --------
    Fit your dataloaders into a parallel one.

    >>> para_loader = InterleavedTrainValDataLoader(train_dataloader, val_dataloader)

    Then you can use the ``para_loader`` as a normal training loader.
    """
-    def __init__(self, train_dataloader, val_dataloader):
-        self.train_loader = train_dataloader
-        self.val_loader = val_dataloader
+    def __init__(self, train_dataloader: DataLoader, val_dataloader: DataLoader | list[DataLoader]):
+        if isinstance(val_dataloader, list):
+            raise TypeError('Validation dataloader of type list is not supported.')
+        self.train_loader: DataLoader = train_dataloader
+        self.val_loader: DataLoader = val_dataloader
        self.equal_len = len(train_dataloader) == len(val_dataloader)
        self.train_longer = len(train_dataloader) > len(val_dataloader)
-        super().__init__(None)
+        super().__init__(cast(Dataset, None))

    def __iter__(self):
        self.train_iter = iter(self.train_loader)
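Conceptually, the interleaved loader lets a one-shot strategy consume one training batch (for weights) and one validation batch (for architecture parameters) per step. A rough generator-based sketch of that behaviour, not the class above (which tracks equal_len / train_longer instead of simply cycling the shorter loader):

    from itertools import cycle

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    train = DataLoader(TensorDataset(torch.arange(8).float()), batch_size=2)
    val = DataLoader(TensorDataset(torch.arange(100, 104).float()), batch_size=2)

    def interleave(train_loader, val_loader):
        # Yield (train_batch, val_batch) pairs; cycle the shorter validation loader.
        val_iter = cycle(val_loader)
        for train_batch in train_loader:
            yield train_batch, next(val_iter)

    for tb, vb in interleave(train, val):
        print(tb[0].tolist(), vb[0].tolist())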
...
@@ -268,13 +276,17 @@ class ConcatenateTrainValDataLoader(DataLoader):
    Example
    --------
    Fit your dataloaders into a concatenated one.

    >>> concat_loader = ConcatenateTrainValDataLoader(train_dataloader, val_datalodaer)

    Then you can use the ``concat_loader`` as a normal training loader.
    """
-    def __init__(self, train_dataloader, val_dataloader):
-        self.train_loader = train_dataloader
-        self.val_loader = val_dataloader
-        super().__init__(None)
+    def __init__(self, train_dataloader: DataLoader, val_dataloader: DataLoader | list[DataLoader]):
+        if isinstance(val_dataloader, list):
+            raise TypeError('Validation dataloader of type list is not supported.')
+        self.train_loader: DataLoader = train_dataloader
+        self.val_loader: DataLoader = val_dataloader
+        super().__init__(cast(Dataset, None))

    def __iter__(self):
        self.cur_iter = iter(self.train_loader)
...
pyrightconfig.json
View file @ 05c7d6e9
...
@@ -14,7 +14,6 @@
     "nni/retiarii/execution/cgo_engine.py",
     "nni/retiarii/execution/logical_optimizer",
     "nni/retiarii/evaluator/pytorch/cgo",
-    "nni/retiarii/oneshot",
     "nni/smartparam.py",
     "nni/tools/annotation",
     "nni/tools/gpu_tool",
...
test/ut/retiarii/test_lightning_trainer.py
View file @ 05c7d6e9
...
@@ -130,13 +130,16 @@ def test_fit_api():
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
    test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform)
-    lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
-                                  val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
-                                  max_epochs=1, limit_train_batches=0.1,  # for faster training
-                                  progress_bar_refresh_rate=progress_bar_refresh_rate)
-    lightning.fit(lambda: MNISTModel())
-    lightning.fit(MNISTModel)
-    lightning.fit(MNISTModel())
+    def lightning():
+        return pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
+                                 val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
+                                 max_epochs=1, limit_train_batches=0.1,  # for faster training
+                                 progress_bar_refresh_rate=progress_bar_refresh_rate)
+    # Lightning will have some cache in models / trainers,
+    # which is problematic if we call fit multiple times.
+    lightning().fit(lambda: MNISTModel())
+    lightning().fit(MNISTModel)
+    lightning().fit(MNISTModel())
    _reset()
...
test/ut/retiarii/test_oneshot.py
View file @ 05c7d6e9
...
@@ -12,6 +12,7 @@ from nni.retiarii import strategy, model_wrapper, basic_unit
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.evaluator.pytorch.lightning import Classification, Regression, DataLoader
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice
+from nni.retiarii.strategy import BaseStrategy


class DepthwiseSeparableConv(nn.Module):
...
@@ -237,8 +238,12 @@ def _test_strategy(strategy_, support_value_choice=True):
    ]

    for (base_model, evaluator), support_or_not in to_test:
-        print('Testing:', type(strategy_).__name__, type(base_model).__name__, type(evaluator).__name__, support_or_not)
-        experiment = RetiariiExperiment(base_model, evaluator, strategy=strategy_)
+        if isinstance(strategy_, BaseStrategy):
+            strategy = strategy_
+        else:
+            strategy = strategy_(base_model, evaluator)
+        print('Testing:', type(strategy).__name__, type(base_model).__name__, type(evaluator).__name__, support_or_not)
+        experiment = RetiariiExperiment(base_model, evaluator, strategy=strategy)
        config = RetiariiExeConfig()
        config.execution_engine = 'oneshot'
...
@@ -263,7 +268,12 @@ def test_proxyless():
@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_enas():
-    _test_strategy(strategy.ENAS())
+    def strategy_fn(base_model, evaluator):
+        if isinstance(base_model, MultiHeadAttentionNet):
+            return strategy.ENAS(reward_metric_name='val_mse')
+        return strategy.ENAS(reward_metric_name='val_acc')
+    _test_strategy(strategy_fn)


@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
...