OpenDAS / nni / Commits

Commit 05c7d6e9 (unverified), authored Apr 28, 2022 by Yuge Zhang, committed by GitHub on Apr 28, 2022

Typehint for oneshot NAS (#4811)

Parent: cbac2c5c
Showing 20 changed files with 379 additions and 254 deletions (+379 -254):

nni/retiarii/oneshot/pytorch/base_lightning.py                  +69 -41
nni/retiarii/oneshot/pytorch/darts.py                           +3  -0
nni/retiarii/oneshot/pytorch/differentiable.py                  +19 -14
nni/retiarii/oneshot/pytorch/enas.py                            +23 -17
nni/retiarii/oneshot/pytorch/proxyless.py                       +3  -0
nni/retiarii/oneshot/pytorch/random.py                          +5  -0
nni/retiarii/oneshot/pytorch/sampling.py                        +36 -22
nni/retiarii/oneshot/pytorch/strategy.py                        +12 -8
nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py    +23 -20
nni/retiarii/oneshot/pytorch/supermodule/_singlepathnas.py      +3  -2
nni/retiarii/oneshot/pytorch/supermodule/_valuechoice_utils.py  +22 -13
nni/retiarii/oneshot/pytorch/supermodule/base.py                +11 -9
nni/retiarii/oneshot/pytorch/supermodule/differentiable.py      +27 -21
nni/retiarii/oneshot/pytorch/supermodule/operation.py           +59 -51
nni/retiarii/oneshot/pytorch/supermodule/proxyless.py           +8  -6
nni/retiarii/oneshot/pytorch/supermodule/sampling.py            +10 -8
nni/retiarii/oneshot/pytorch/utils.py                           +23 -11
pyrightconfig.json                                              +0  -1
test/ut/retiarii/test_lightning_trainer.py                      +10 -7
test/ut/retiarii/test_oneshot.py                                +13 -3
nni/retiarii/oneshot/pytorch/base_lightning.py (View file @ 05c7d6e9)

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations
+
 import warnings
 from itertools import chain
-from typing import Dict, Callable, List, Union, Any, Tuple
+from typing import Callable, Any, Dict, Union, Tuple, List, cast

 import pytorch_lightning as pl
 import torch.optim as optim
 import torch.nn as nn
+from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
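
The recurring edit in this file, and across the whole commit, is replacing `typing.List`/`Dict`/`Optional`/`Union` spellings with the lowercase built-in generics and `X | None` unions. The added `from __future__ import annotations` is what makes that legal on the Python versions NNI supported at the time: annotations become lazily evaluated strings (PEP 563), so `list[MutationHook] | None` is never evaluated at runtime while pyright still checks it. A minimal sketch of the idea; the function and names below are illustrative and not from the repository:

    # Illustrative sketch (hypothetical function), showing why the future
    # import allows the new annotation style even on Python 3.7/3.8.
    from __future__ import annotations

    from typing import Any


    def pick(values: list[int], default: int | None = None) -> dict[str, Any]:
        # With the future import, the annotations above are stored as strings
        # and never evaluated at runtime, so older interpreters accept them
        # while a type checker such as pyright still understands them.
        chosen = values[0] if values else default
        return {'chosen': chosen}


    print(pick([3, 1, 2]))  # {'chosen': 3}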
...
@@ -24,8 +27,8 @@ MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[

 def traverse_and_mutate_submodules(
-    root_module: nn.Module, hooks: List[MutationHook], mutate_kwargs: Dict[str, Any], topdown: bool = True
-) -> List[BaseSuperNetModule]:
+    root_module: nn.Module, hooks: list[MutationHook], mutate_kwargs: dict[str, Any], topdown: bool = True
+) -> list[BaseSuperNetModule]:
     """
     Traverse the module-tree of ``root_module``, and call ``hooks`` on every tree node.
...
@@ -36,7 +39,7 @@ def traverse_and_mutate_submodules(
         Since this method is called in the ``__init__`` of :class:`BaseOneShotLightningModule`,
         it's usually a ``pytorch_lightning.LightningModule``.
         The mutation will be in-place on ``root_module``.
-    hooks : List[MutationHook]
+    hooks : list[MutationHook]
         List of mutation hooks. See :class:`BaseOneShotLightningModule` for how to write hooks.
         When a hook returns an module, the module will be replaced (mutated) to the new module.
     mutate_kwargs : dict
...
@@ -47,7 +50,7 @@ def traverse_and_mutate_submodules(
     Returns
     ----------
-    modules : Dict[str, nn.Module]
+    modules : dict[str, nn.Module]
         The replace result.
     """
     memo = {}
...
@@ -101,7 +104,7 @@ def traverse_and_mutate_submodules(
     return module_list


-def no_default_hook(module: nn.Module, name: str, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> bool:
+def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> bool:
     """Add this hook at the end of your hook list to raise error for unsupported mutation primitives."""
     # Forward IS NOT supernet
...
@@ -125,7 +128,7 @@ def no_default_hook(module: nn.Module, name: str, memo: Dict[str, Any], mutate_k
     if is_traceable(module):
         # check whether there is a value-choice in its arguments
         has_valuechoice = False
-        for arg in chain(module.trace_args, module.trace_kwargs.values()):
+        for arg in chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()):
             if isinstance(arg, ValueChoiceX):
                 has_valuechoice = True
                 break
...
@@ -139,7 +142,7 @@ def no_default_hook(module: nn.Module, name: str, memo: Dict[str, Any], mutate_k

 class BaseOneShotLightningModule(pl.LightningModule):

-    _mutation_hooks_note = """mutation_hooks : List[MutationHook]
+    _mutation_hooks_note = """mutation_hooks : list[MutationHook]
        Mutation hooks are callable that inputs an Module and returns a :class:`BaseSuperNetModule`.
        They are invoked in :meth:`traverse_and_mutate_submodules`, on each submodules.
        For each submodule, the hook list are invoked subsequently,
...
@@ -194,36 +197,40 @@ class BaseOneShotLightningModule(pl.LightningModule):
     Attributes
     ----------
-    nas_modules : List[BaseSuperNetModule]
+    nas_modules : list[BaseSuperNetModule]
         Modules that have been mutated, which the search algorithms should care about.

     Parameters
     ----------
     """ + _inner_module_note + _mutation_hooks_note

-    automatic_optimization = False
+    trainer: pl.Trainer
+
+    @property
+    def automatic_optimization(self) -> bool:
+        return False

-    def default_mutation_hooks(self) -> List[MutationHook]:
+    def default_mutation_hooks(self) -> list[MutationHook]:
         """Override this to define class-default mutation hooks."""
         return [no_default_hook]

-    def mutate_kwargs(self) -> Dict[str, Any]:
+    def mutate_kwargs(self) -> dict[str, Any]:
         """Extra keyword arguments passed to mutation hooks. Usually algo-specific."""
         return {}

-    def __init__(self, base_model: pl.LightningModule, mutation_hooks: List[MutationHook] = None):
+    def __init__(self, model: pl.LightningModule, mutation_hooks: list[MutationHook] | None = None):
         super().__init__()
-        assert isinstance(base_model, pl.LightningModule)
-        self.model = base_model
+        assert isinstance(model, pl.LightningModule)
+        self.model = model

         # append the default hooks
         mutation_hooks = (mutation_hooks or []) + self.default_mutation_hooks()

         # traverse the model, calling hooks on every submodule
-        self.nas_modules: List[BaseSuperNetModule] = traverse_and_mutate_submodules(
+        self.nas_modules: list[BaseSuperNetModule] = traverse_and_mutate_submodules(
             self.model, mutation_hooks, self.mutate_kwargs(), topdown=True)

-    def search_space_spec(self) -> Dict[str, ParameterSpec]:
+    def search_space_spec(self) -> dict[str, ParameterSpec]:
         """Get the search space specification from ``nas_module``.

         Returns
...
@@ -236,7 +243,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
             result.update(module.search_space_spec())
         return result

-    def resample(self) -> Dict[str, Any]:
+    def resample(self) -> dict[str, Any]:
         """Trigger the resample for each ``nas_module``.
         Sometimes (e.g., in differentiable cases), it does nothing.
...
@@ -250,7 +257,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
             result.update(module.resample(memo=result))
         return result

-    def export(self) -> Dict[str, Any]:
+    def export(self) -> dict[str, Any]:
         """
         Export the NAS result, ideally the best choice of each ``nas_module``.
         You may implement an ``export`` method for your customized ``nas_module``.
...
@@ -291,12 +298,30 @@ class BaseOneShotLightningModule(pl.LightningModule):
             arc_optimizers = [arc_optimizers]
         self.arc_optim_count = len(arc_optimizers)

+        # FIXME: this part uses non-official lightning API.
         # The return values ``frequency`` and ``monitor`` are ignored because lightning requires
         # ``len(optimizers) == len(frequency)``, and gradient backword is handled manually.
         # For data structure of variables below, please see pytorch lightning docs of ``configure_optimizers``.
-        w_optimizers, lr_schedulers, self.frequencies, monitor = \
-            self.trainer._configure_optimizers(self.model.configure_optimizers())
-        lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)
+        try:
+            # above v1.6
+            from pytorch_lightning.core.optimizer import (  # pylint: disable=import-error
+                _configure_optimizers,                      # type: ignore
+                _configure_schedulers_automatic_opt,        # type: ignore
+                _configure_schedulers_manual_opt            # type: ignore
+            )
+            w_optimizers, lr_schedulers, self.frequencies, monitor = \
+                _configure_optimizers(self.model.configure_optimizers())  # type: ignore
+            lr_schedulers = (
+                _configure_schedulers_automatic_opt(lr_schedulers, monitor)
+                if self.automatic_optimization
+                else _configure_schedulers_manual_opt(lr_schedulers)
+            )
+        except ImportError:
+            # under v1.5
+            w_optimizers, lr_schedulers, self.frequencies, monitor = \
+                self.trainer._configure_optimizers(self.model.configure_optimizers())  # type: ignore
+            lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)
         if any(sch["scheduler"].optimizer not in w_optimizers for sch in lr_schedulers):
             raise Exception(
                 "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
...
@@ -312,7 +337,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
         # redirect the access to trainer/log to this module
         # but note that we might be missing other attributes,
         # which could potentially be a problem
-        self.model.trainer = self.trainer
+        self.model.trainer = self.trainer  # type: ignore
         self.model.log = self.log
         return self.model.on_train_start()
...
@@ -359,7 +384,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
         Returns
         ----------
-        arc_optimizers : List[Optimizer], Optimizer
+        arc_optimizers : list[Optimizer], Optimizer
             Optimizers used by a specific NAS algorithm. Return None if no architecture optimizers are needed.
         """
         return None
...
@@ -376,9 +401,9 @@ class BaseOneShotLightningModule(pl.LightningModule):
         """
         def apply(lr_scheduler):
             # single scheduler is called every epoch
-            if isinstance(lr_scheduler, _LRScheduler) and \
-                    self.trainer.is_last_batch:
-                lr_schedulers.step()
+            if isinstance(lr_scheduler, _LRScheduler):
+                if self.trainer.is_last_batch:
+                    lr_scheduler.step()
             # lr_scheduler_config is called as configured
             elif isinstance(lr_scheduler, dict):
                 interval = lr_scheduler['interval']
...
@@ -392,7 +417,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
                 self.trainer.is_last_batch and
                 (self.trainer.current_epoch + 1) % frequency == 0
             ):
-                lr_scheduler.step()
+                lr_scheduler['scheduler'].step()

         lr_schedulers = self.lr_schedulers()
...
@@ -402,7 +427,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
         else:
             apply(lr_schedulers)

-    def call_user_optimizers(self, method):
+    def call_weight_optimizers(self, method):
         """
         Function that imitates lightning trainer's behavior of calling user's optimizers. Since auto_optimization is turned off by this
         class, you can use this function to make user optimizers behave as they were automatically handled by the lightning trainer.
...
@@ -418,10 +443,12 @@ class BaseOneShotLightningModule(pl.LightningModule):
             elif method == 'zero_grad':
                 optimizer.zero_grad()

-        optimizers = self.user_optimizers
+        optimizers = self.weight_optimizers()
         if optimizers is None:
             return
+        assert isinstance(optimizers, list), 'Did you forget to set use_pl_optimizers to true?'

         if len(self.frequencies) > 0:
             self.cur_optimizer_step += 1
             if self.frequencies[self.cur_optimizer_index] == self.cur_optimizer_step:
...
@@ -434,14 +461,13 @@ class BaseOneShotLightningModule(pl.LightningModule):
             for optimizer in optimizers:
                 apply_method(optimizer, method)

-    @property
-    def architecture_optimizers(self):
+    def architecture_optimizers(self) -> list[Optimizer] | Optimizer | None:
         """
         Get architecture optimizers from all optimizers. Use this to get your architecture optimizers in ``training_step``.

         Returns
         ----------
-        opts : List[Optimizer], Optimizer, None
+        opts : list[Optimizer], Optimizer, None
             Architecture optimizers defined in ``configure_architecture_optimizers``. This will be None if there is no
             architecture optimizers.
         """
...
@@ -450,28 +476,30 @@ class BaseOneShotLightningModule(pl.LightningModule):
             # pylint: disable=unsubscriptable-object
             arc_opts = opts[:self.arc_optim_count]
             if len(arc_opts) == 1:
-                arc_opts = arc_opts[0]
-            return arc_opts
+                return cast(Optimizer, arc_opts[0])
+            return cast(List[Optimizer], arc_opts)
         # If there is only 1 optimizer and it is the architecture optimizer
         if self.arc_optim_count == 1:
-            return opts
+            return cast(Union[List[Optimizer], Optimizer], opts)
         return None

-    @property
-    def user_optimizers(self):
+    def weight_optimizers(self) -> list[Optimizer] | Optimizer | None:
         """
         Get user optimizers from all optimizers. Use this to get user optimizers in ``training_step``.

         Returns
         ----------
-        opts : List[Optimizer], Optimizer, None
+        opts : list[Optimizer], Optimizer, None
             Optimizers defined by user's model. This will be None if there is no user optimizers.
         """
+        # Since use_pl_optimizer is set true (by default) here.
+        # opts always return a list
         opts = self.optimizers()
         if isinstance(opts, list):
             # pylint: disable=unsubscriptable-object
-            return opts[self.arc_optim_count:]
+            return cast(List[Optimizer], opts[self.arc_optim_count:])
+        # FIXME: this case is actually not correctly handled
         # If there is only 1 optimizer and no architecture optimizer
         if self.arc_optim_count == 0:
-            return opts
+            return cast(Union[List[Optimizer], Optimizer], opts)
         return None
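
A second pattern in this file is worth noting: the `architecture_optimizers`/`user_optimizers` properties become plain methods (`architecture_optimizers()` and `weight_optimizers()`) whose return type is `list[Optimizer] | Optimizer | None`, and `typing.cast` is used to tell the type checker which shape is being returned. `cast` has no runtime effect; the object passes through unchanged. A small sketch under that assumption, with a hypothetical helper name that is not part of the commit:

    # Illustrative sketch (hypothetical helper, not from the repository).
    from typing import List, Union, cast

    from torch.optim import Optimizer


    def split_architecture_optimizers(opts: List[Optimizer], arc_optim_count: int
                                      ) -> Union[List[Optimizer], Optimizer, None]:
        arc_opts = opts[:arc_optim_count]
        if not arc_opts:
            return None
        if len(arc_opts) == 1:
            return cast(Optimizer, arc_opts[0])   # a single optimizer, not a list
        return cast(List[Optimizer], arc_opts)    # an actual list of optimizers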
nni/retiarii/oneshot/pytorch/darts.py (View file @ 05c7d6e9)

...
@@ -5,6 +5,7 @@
 import copy
 import logging
+import warnings
 from collections import OrderedDict

 import torch
...
@@ -111,6 +112,8 @@ class DartsTrainer(BaseOneShotTrainer):
                  learning_rate=2.5E-3, batch_size=64, workers=4,
                  device=None, log_frequency=None,
                  arc_learning_rate=3.0E-4, unrolled=False):
+        warnings.warn('DartsTrainer is deprecated. Please use strategy.DARTS instead.', DeprecationWarning)
+
         self.model = model
         self.loss = loss
         self.metrics = metrics
...
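
This commit also marks the legacy trainers (DartsTrainer here, and EnasTrainer, ProxylessTrainer, SinglePathTrainer below) as deprecated by emitting a warning in `__init__`. Note that `DeprecationWarning` is filtered out by default in most contexts, so callers usually surface it explicitly while migrating. A minimal sketch of the pattern, with a made-up class name:

    # Minimal sketch of the deprecation pattern (class name is hypothetical).
    import warnings


    class OldTrainer:
        def __init__(self):
            warnings.warn('OldTrainer is deprecated. Please use the new strategy class instead.',
                          DeprecationWarning)


    # Enable the warning explicitly to see the message when migrating code.
    warnings.simplefilter('always', DeprecationWarning)
    OldTrainer()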
nni/retiarii/oneshot/pytorch/differentiable.py (View file @ 05c7d6e9)

...
@@ -3,9 +3,11 @@
 """Experimental version of differentiable one-shot implementation."""

-from typing import List
+from __future__ import annotations

 import pytorch_lightning as pl
 import torch
+import torch.optim as optim

 from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
 from .supermodule.differentiable import (
...
@@ -45,7 +47,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
         module_params=BaseOneShotLightningModule._inner_module_note,
     )

-    def default_mutation_hooks(self) -> List[MutationHook]:
+    def default_mutation_hooks(self) -> list[MutationHook]:
         """Replace modules with differentiable versions"""
         hooks = [
             DifferentiableMixedLayer.mutate,
...
@@ -62,14 +64,16 @@ class DartsLightningModule(BaseOneShotLightningModule):
     }

     def __init__(self, inner_module: pl.LightningModule,
-                 mutation_hooks: List[MutationHook] = None,
+                 mutation_hooks: list[MutationHook] | None = None,
                  arc_learning_rate: float = 3.0E-4):
         self.arc_learning_rate = arc_learning_rate
         super().__init__(inner_module, mutation_hooks=mutation_hooks)

     def training_step(self, batch, batch_idx):
         # grad manually
-        arc_optim = self.architecture_optimizers
+        arc_optim = self.architecture_optimizers()
+        if not isinstance(arc_optim, optim.Optimizer):
+            raise TypeError(f'Expect arc_optim to be a single Optimizer, but found: {arc_optim}')

         # The InterleavedTrainValDataLoader yields both train and val data in a batch
         trn_batch, val_batch = batch
...
@@ -88,12 +92,12 @@ class DartsLightningModule(BaseOneShotLightningModule):
         # phase 2: model step
         self.resample()
-        self.call_user_optimizers('zero_grad')
+        self.call_weight_optimizers('zero_grad')
         loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
         w_step_loss = loss_and_metrics['loss'] \
             if isinstance(loss_and_metrics, dict) else loss_and_metrics
         self.manual_backward(w_step_loss)
-        self.call_user_optimizers('step')
+        self.call_weight_optimizers('step')

         self.call_lr_schedulers(batch_idx)
...
@@ -107,7 +111,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
         # The alpha in DartsXXXChoices are the architecture parameters of DARTS. They share one optimizer.
         ctrl_params = []
         for m in self.nas_modules:
-            ctrl_params += list(m.parameters(arch=True))
+            ctrl_params += list(m.parameters(arch=True))  # type: ignore
         ctrl_optim = torch.optim.Adam(list(set(ctrl_params)), 3.e-4, betas=(0.5, 0.999),
                                       weight_decay=1.0E-3)
         return ctrl_optim
...
@@ -135,7 +139,7 @@ class ProxylessLightningModule(DartsLightningModule):
         module_params=BaseOneShotLightningModule._inner_module_note,
     )

-    def default_mutation_hooks(self) -> List[MutationHook]:
+    def default_mutation_hooks(self) -> list[MutationHook]:
         """Replace modules with gumbel-differentiable versions"""
         hooks = [
             ProxylessMixedLayer.mutate,
...
@@ -147,7 +151,7 @@ class ProxylessLightningModule(DartsLightningModule):
     def finalize_grad(self):
         for m in self.nas_modules:
-            m.finalize_grad()
+            m.finalize_grad()  # type: ignore


 class GumbelDartsLightningModule(DartsLightningModule):
...
@@ -177,7 +181,7 @@ class GumbelDartsLightningModule(DartsLightningModule):
         Learning rate for architecture optimizer. Default: 3.0e-4
     """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)

-    def default_mutation_hooks(self) -> List[MutationHook]:
+    def default_mutation_hooks(self) -> list[MutationHook]:
         """Replace modules with gumbel-differentiable versions"""
         hooks = [
             DifferentiableMixedLayer.mutate,
...
@@ -195,7 +199,7 @@ class GumbelDartsLightningModule(DartsLightningModule):
     }

     def __init__(self, inner_module,
-                 mutation_hooks: List[MutationHook] = None,
+                 mutation_hooks: list[MutationHook] | None = None,
                  arc_learning_rate: float = 3.0e-4,
                  gumbel_temperature: float = 1.,
                  use_temp_anneal: bool = False,
...
@@ -206,12 +210,13 @@ class GumbelDartsLightningModule(DartsLightningModule):
         self.use_temp_anneal = use_temp_anneal
         self.min_temp = min_temp

-    def on_epoch_start(self):
+    def on_train_epoch_end(self):
         if self.use_temp_anneal:
             self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
             self.temp = max(self.temp, self.min_temp)

         for module in self.nas_modules:
-            module._softmax.temp = self.temp
+            if hasattr(module, '_softmax'):
+                module._softmax.temp = self.temp  # type: ignore

-        return self.model.on_epoch_start()
+        return self.model.on_train_epoch_end()
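
The Gumbel temperature update moved from `on_epoch_start` to `on_train_epoch_end` follows a linear schedule: the temperature decays from the initial value toward `min_temp` as the run progresses. A quick standalone check of the formula, with made-up values:

    # Quick check of the linear annealing formula above (values are arbitrary).
    init_temp, min_temp, max_epochs = 5.0, 0.5, 10

    for current_epoch in (0, 5, 9):
        temp = (1 - current_epoch / max_epochs) * (init_temp - min_temp) + min_temp
        temp = max(temp, min_temp)
        print(current_epoch, round(temp, 2))
    # epoch 0 -> 5.0 (starts at init_temp), 5 -> 2.75 (halfway), 9 -> 0.95 (approaching min_temp)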
nni/retiarii/oneshot/pytorch/enas.py (View file @ 05c7d6e9)

...
@@ -2,10 +2,14 @@
 # Licensed under the MIT license.

 import logging
+import warnings
+from typing import cast

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
+from torch.utils.data import SubsetRandomSampler, DataLoader

 from ..interface import BaseOneShotTrainer
 from .random import PathSamplingLayerChoice, PathSamplingInputChoice
...
@@ -113,9 +117,9 @@ class ReinforceController(nn.Module):
         self._h = [torch.zeros((1, self.lstm_size),
                                dtype=self._inputs.dtype,
                                device=self._inputs.device) for _ in range(self.lstm_num_layers)]
-        self.sample_log_prob = 0
-        self.sample_entropy = 0
-        self.sample_skip_penalty = 0
+        self.sample_log_prob: torch.Tensor = cast(torch.Tensor, 0)
+        self.sample_entropy: torch.Tensor = cast(torch.Tensor, 0)
+        self.sample_skip_penalty: torch.Tensor = cast(torch.Tensor, 0)

     def _lstm_next_step(self):
         self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
...
@@ -143,7 +147,7 @@ class ReinforceController(nn.Module):
         if sampled.sum().item():
             self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
         else:
-            self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)
+            self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)  # type: ignore

         sampled = sampled.detach().cpu().numpy().tolist()
         self.sample_log_prob += self.entropy_reduction(log_prob)
...
@@ -205,6 +209,8 @@ class EnasTrainer(BaseOneShotTrainer):
                  batch_size=64, workers=4, device=None, log_frequency=None,
                  grad_clip=5., entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
                  ctrl_lr=0.00035, ctrl_steps_aggregate=20, ctrl_kwargs=None):
+        warnings.warn('EnasTrainer is deprecated. Please use strategy.ENAS instead.', DeprecationWarning)
+
         self.model = model
         self.loss = loss
         self.metrics = metrics
...
@@ -246,13 +252,13 @@ class EnasTrainer(BaseOneShotTrainer):
         n_train = len(self.dataset)
         split = n_train // 2
         indices = list(range(n_train))
-        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
-        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
-        self.train_loader = torch.utils.data.DataLoader(self.dataset,
-                                                        batch_size=self.batch_size,
-                                                        sampler=train_sampler,
-                                                        num_workers=self.workers)
-        self.valid_loader = torch.utils.data.DataLoader(self.dataset,
-                                                        batch_size=self.batch_size,
-                                                        sampler=valid_sampler,
-                                                        num_workers=self.workers)
+        train_sampler = SubsetRandomSampler(indices[:-split])
+        valid_sampler = SubsetRandomSampler(indices[-split:])
+        self.train_loader = DataLoader(self.dataset,
+                                       batch_size=self.batch_size,
+                                       sampler=train_sampler,
+                                       num_workers=self.workers)
+        self.valid_loader = DataLoader(self.dataset,
+                                       batch_size=self.batch_size,
+                                       sampler=valid_sampler,
+                                       num_workers=self.workers)
...
@@ -294,15 +300,15 @@ class EnasTrainer(BaseOneShotTrainer):
             metrics = self.metrics(logits, y)
             reward = self.reward_function(logits, y)
             if self.entropy_weight:
-                reward += self.entropy_weight * self.controller.sample_entropy.item()
+                reward += self.entropy_weight * self.controller.sample_entropy.item()  # type: ignore
             self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
             loss = self.controller.sample_log_prob * (reward - self.baseline)
             if self.skip_weight:
                 loss += self.skip_weight * self.controller.sample_skip_penalty
             metrics['reward'] = reward
             metrics['loss'] = loss.item()
-            metrics['ent'] = self.controller.sample_entropy.item()
-            metrics['log_prob'] = self.controller.sample_log_prob.item()
+            metrics['ent'] = self.controller.sample_entropy.item()  # type: ignore
+            metrics['log_prob'] = self.controller.sample_log_prob.item()  # type: ignore
             metrics['baseline'] = self.baseline
             metrics['skip'] = self.controller.sample_skip_penalty
...
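
The controller update in this file uses a REINFORCE objective with a moving-average baseline: `baseline = baseline * baseline_decay + reward * (1 - baseline_decay)` is an exponential moving average of recent rewards, and `reward - baseline` (the advantage) is what scales `sample_log_prob` to reduce gradient variance. A small numeric sketch of that bookkeeping, with arbitrary example rewards:

    # Numeric sketch of the EMA baseline used by the ENAS controller update.
    baseline, decay = 0.0, 0.999
    for reward in (0.60, 0.62, 0.65, 0.61):
        baseline = baseline * decay + reward * (1 - decay)
        advantage = reward - baseline   # what actually multiplies sample_log_prob
        print(round(baseline, 5), round(advantage, 5))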
nni/retiarii/oneshot/pytorch/proxyless.py (View file @ 05c7d6e9)

...
@@ -4,6 +4,7 @@
 # type: ignore

 import logging
+import warnings

 import torch
 import torch.nn as nn
...
@@ -230,6 +231,8 @@ class ProxylessTrainer(BaseOneShotTrainer):
                  grad_reg_loss_type=None, grad_reg_loss_params=None,
                  applied_hardware=None, dummy_input=(1, 3, 224, 224),
                  ref_latency=65.0):
+        warnings.warn('ProxylessTrainer is deprecated. Please use strategy.Proxyless instead.', DeprecationWarning)
+
         self.model = model
         self.loss = loss
         self.metrics = metrics
...
nni/retiarii/oneshot/pytorch/random.py (View file @ 05c7d6e9)

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+# type: ignore
+
 import logging
 import random
+import warnings

 import torch
 import torch.nn as nn
...
@@ -122,6 +125,8 @@ class SinglePathTrainer(BaseOneShotTrainer):
     def __init__(self, model, loss, metrics,
                  optimizer, num_epochs, dataset_train, dataset_valid,
                  batch_size=64, workers=4, device=None, log_frequency=None):
+        warnings.warn('SinglePathTrainer is deprecated. Please use strategy.RandomOneShot instead.', DeprecationWarning)
+
         self.model = model
         self.loss = loss
         self.metrics = metrics
...
nni/retiarii/oneshot/pytorch/sampling.py (View file @ 05c7d6e9)

...
@@ -3,7 +3,8 @@
 """Experimental version of sampling-based one-shot implementation."""

-from typing import Dict, Any, List
+from __future__ import annotations
+from typing import Any

 import pytorch_lightning as pl
 import torch
...
@@ -33,9 +34,11 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
     )

     # turn on automatic optimization because nothing interesting is going on here.
-    automatic_optimization = True
+    @property
+    def automatic_optimization(self) -> bool:
+        return True

-    def default_mutation_hooks(self) -> List[MutationHook]:
+    def default_mutation_hooks(self) -> list[MutationHook]:
         """Replace modules with differentiable versions"""
         hooks = [
             PathSamplingLayer.mutate,
...
@@ -80,6 +83,12 @@ class EnasLightningModule(RandomSamplingLightningModule):
         Number of steps that will be aggregated into one mini-batch for RL controller.
     ctrl_grad_clip : float
         Gradient clipping value of controller.
+    reward_metric_name : str or None
+        The name of the metric which is treated as reward.
+        This will be not effective when there's only one metric returned from evaluator.
+        If there are multiple, it will find the metric with key name ``reward_metric_name``,
+        which is "default" by default.
+        Otherwise it raises an exception indicating multiple metrics are found.
     """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)

     __doc__ = _enas_note.format(
...
@@ -87,23 +96,26 @@ class EnasLightningModule(RandomSamplingLightningModule):
         module_params=BaseOneShotLightningModule._inner_module_note,
     )

-    automatic_optimization = False
+    @property
+    def automatic_optimization(self) -> bool:
+        return False

     def __init__(self, inner_module: pl.LightningModule,
                  *,
-                 ctrl_kwargs: Dict[str, Any] = None,
+                 ctrl_kwargs: dict[str, Any] | None = None,
                  entropy_weight: float = 1e-4,
                  skip_weight: float = .8,
                  baseline_decay: float = .999,
                  ctrl_steps_aggregate: float = 20,
                  ctrl_grad_clip: float = 0,
-                 mutation_hooks: List[MutationHook] = None):
+                 reward_metric_name: str | None = None,
+                 mutation_hooks: list[MutationHook] | None = None):
         super().__init__(inner_module, mutation_hooks)

         # convert parameter spec to legacy ReinforceField
         # this part will be refactored
-        self.nas_fields: List[ReinforceField] = []
+        self.nas_fields: list[ReinforceField] = []
         for name, param_spec in self.search_space_spec().items():
             if param_spec.chosen_size not in (1, None):
                 raise ValueError('ENAS does not support n_chosen to be values other than 1 or None.')
...
@@ -116,6 +128,7 @@ class EnasLightningModule(RandomSamplingLightningModule):
         self.baseline = 0.
         self.ctrl_steps_aggregate = ctrl_steps_aggregate
         self.ctrl_grad_clip = ctrl_grad_clip
+        self.reward_metric_name = reward_metric_name

     def configure_architecture_optimizers(self):
         return optim.Adam(self.controller.parameters(), lr=3.5e-4)
...
@@ -127,34 +140,35 @@ class EnasLightningModule(RandomSamplingLightningModule):
         if source == 'train':
             # step 1: train model params
             self.resample()
-            self.call_user_optimizers('zero_grad')
+            self.call_weight_optimizers('zero_grad')
             loss_and_metrics = self.model.training_step(batch, batch_idx)
             w_step_loss = loss_and_metrics['loss'] \
                 if isinstance(loss_and_metrics, dict) else loss_and_metrics
             self.manual_backward(w_step_loss)
-            self.call_user_optimizers('step')
+            self.call_weight_optimizers('step')
             return loss_and_metrics

         if source == 'val':
             # step 2: train ENAS agent
-            x, y = batch
-            arc_opt = self.architecture_optimizers
+            arc_opt = self.architecture_optimizers()
+            if not isinstance(arc_opt, optim.Optimizer):
+                raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
             arc_opt.zero_grad()
             self.resample()
-            with torch.no_grad():
-                logits = self.model(x)
+            self.model.validation_step(batch, batch_idx)
             # use the default metric of self.model as reward function
-            if len(self.model.metrics) == 1:
-                _, metric = next(iter(self.model.metrics.items()))
+            if len(self.trainer.callback_metrics) == 1:
+                _, metric = next(iter(self.trainer.callback_metrics.items()))
             else:
-                if 'default' not in self.model.metrics.keys():
-                    raise KeyError('model.metrics should contain a ``default`` key when'
-                                   'there are multiple metrics')
-                metric = self.model.metrics['default']
-            reward = metric(logits, y)
+                metric_name = self.reward_metric_name or 'default'
+                if metric_name not in self.trainer.callback_metrics:
+                    raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
+                                   f'found multiple metrics without default: {self.trainer.callback_metrics.keys()}')
+                metric = self.trainer.callback_metrics[metric_name]
+            reward: float = metric.item()
             if self.entropy_weight:
-                reward = reward + self.entropy_weight * self.controller.sample_entropy.item()
+                reward = reward + self.entropy_weight * self.controller.sample_entropy.item()  # type: ignore
             self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
             rnn_step_loss = self.controller.sample_log_prob * (reward - self.baseline)
             if self.skip_weight:
...
@@ -183,7 +197,7 @@ class EnasLightningModule(RandomSamplingLightningModule):
         with torch.no_grad():
             return self._interpret_controller_sampling_result(self.controller.resample())

-    def _interpret_controller_sampling_result(self, sample: Dict[str, int]) -> Dict[str, Any]:
+    def _interpret_controller_sampling_result(self, sample: dict[str, int]) -> dict[str, Any]:
         """Convert ``{label: index}`` to ``{label: name}``"""
         space_spec = self.search_space_spec()
         for key in list(sample.keys()):
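
The reward lookup above changes behavior: instead of calling a metric object on logits, the new code runs the user's `validation_step` and reads the reward from `trainer.callback_metrics`, i.e. whatever the evaluator logged. With a single logged metric it is used directly; with several, the key named by `reward_metric_name` (falling back to "default") must be present. A standalone sketch of that selection logic, with a plain dict standing in for `trainer.callback_metrics` and made-up values:

    # Sketch of the reward-selection logic above (values are illustrative).
    import torch

    callback_metrics = {'default': torch.tensor(0.91), 'val_loss': torch.tensor(0.30)}
    reward_metric_name = None

    if len(callback_metrics) == 1:
        _, metric = next(iter(callback_metrics.items()))
    else:
        metric_name = reward_metric_name or 'default'
        if metric_name not in callback_metrics:
            raise KeyError(f'Expected a ``{metric_name}`` key, found: {list(callback_metrics)}')
        metric = callback_metrics[metric_name]

    reward: float = metric.item()
    print(reward)  # approximately 0.91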
nni/retiarii/oneshot/pytorch/strategy.py (View file @ 05c7d6e9)

...
@@ -10,8 +10,10 @@ For example, ``nni.retiarii.strategy.DartsStrategy`` (this requires pytorch to b
 When adding/modifying a new strategy in this file, don't forget to link it in strategy/oneshot.py.
 """

+from __future__ import annotations
+
 import warnings
-from typing import Any, List, Optional, Type, Union, Tuple
+from typing import Any, Type

 import torch.nn as nn
 from torch.utils.data import DataLoader
...
@@ -33,10 +35,10 @@ class OneShotStrategy(BaseStrategy):
         self.oneshot_module = oneshot_module
         self.oneshot_kwargs = kwargs

-        self.model: Optional[BaseOneShotLightningModule] = None
+        self.model: BaseOneShotLightningModule | None = None

-    def _get_dataloader(self, train_dataloader: DataLoader, val_dataloaders: DataLoader) \
-        -> Union[DataLoader, Tuple[DataLoader, DataLoader]]:
+    def _get_dataloader(self, train_dataloader: DataLoader, val_dataloaders: DataLoader | list[DataLoader]) \
+        -> DataLoader | tuple[DataLoader, DataLoader]:
         """
         One-shot strategy typically requires a customized dataloader.
...
@@ -51,9 +53,9 @@ class OneShotStrategy(BaseStrategy):
         _reason = 'The reason might be that you have used the wrong execution engine. Try to set engine to `oneshot` and try again.'

+        py_model: nn.Module = base_model.python_object
-        if not isinstance(base_model.python_object, nn.Module):
+        if not isinstance(py_model, nn.Module):
             raise TypeError('Model is not a nn.Module. ' + _reason)
-        py_model: nn.Module = base_model.python_object

         if applied_mutators:
             raise ValueError('Mutator is not empty. ' + _reason)
...
@@ -64,8 +66,10 @@ class OneShotStrategy(BaseStrategy):
         evaluator_module: LightningModule = base_model.evaluator.module
         evaluator_module.set_model(py_model)
-        self.model: BaseOneShotLightningModule = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
+        self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)

         evaluator: Lightning = base_model.evaluator
+        if evaluator.train_dataloader is None or evaluator.val_dataloaders is None:
+            raise TypeError('Train or val dataloader is not set.')
         dataloader = self._get_dataloader(evaluator.train_dataloader, evaluator.val_dataloaders)
         if isinstance(dataloader, tuple):
             dataloader, val_loader = dataloader
...
@@ -73,7 +77,7 @@ class OneShotStrategy(BaseStrategy):
         else:
             evaluator.trainer.fit(self.model, dataloader)

-    def export_top_models(self, top_k: int = 1) -> List[Any]:
+    def export_top_models(self, top_k: int = 1) -> list[Any]:
         if self.model is None:
             raise RuntimeError('One-shot strategy needs to be run before export.')
         if top_k != 1:
...
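
`_get_dataloader` is declared to return either a single (typically interleaved) loader or a `(train, val)` pair, and the caller branches on which shape it got. A small sketch of consuming such a union return type; the helper below is made up and only mirrors the branch in `OneShotStrategy.run`:

    # Sketch of handling "one loader or a (train, val) pair" (hypothetical helper).
    from __future__ import annotations

    import torch
    from torch.utils.data import DataLoader, TensorDataset


    def get_dataloader(train: DataLoader, val: DataLoader,
                       interleave: bool = False) -> DataLoader | tuple[DataLoader, DataLoader]:
        # A real strategy would wrap or interleave the two loaders; this just passes them through.
        return train if interleave else (train, val)


    ds = TensorDataset(torch.arange(8.0))
    result = get_dataloader(DataLoader(ds), DataLoader(ds))
    if isinstance(result, tuple):
        train_loader, val_loader = result   # e.g. trainer.fit(model, train_loader, val_loader)
    else:
        train_loader = result               # e.g. trainer.fit(model, train_loader)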
nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py
View file @
05c7d6e9
...
@@ -26,8 +26,10 @@ The fixed/weighted slice is fed into ``_slice_weight``,
...
@@ -26,8 +26,10 @@ The fixed/weighted slice is fed into ``_slice_weight``,
which interprets the slice and apply it on a tensor.
which interprets the slice and apply it on a tensor.
"""
"""
from
__future__
import
annotations
import
operator
import
operator
from
typing
import
Tuple
,
Union
,
List
,
Dict
,
Callable
,
Optional
,
Iterator
,
TypeVar
,
Any
,
Generic
,
cast
from
typing
import
Callable
,
Iterator
,
TypeVar
,
Any
,
Optional
,
Tuple
,
Union
,
List
,
Dict
,
Generic
,
cast
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -58,8 +60,8 @@ def _eliminate_list_slice(shape: tuple, slice_: multidim_slice) -> multidim_slic
...
@@ -58,8 +60,8 @@ def _eliminate_list_slice(shape: tuple, slice_: multidim_slice) -> multidim_slic
for
i
in
range
(
len
(
slice_
)):
for
i
in
range
(
len
(
slice_
)):
if
isinstance
(
slice_
[
i
],
list
):
if
isinstance
(
slice_
[
i
],
list
):
# convert list of slices to mask
# convert list of slices to mask
mask
=
np
.
zeros
(
shape
[
i
],
dtype
=
np
.
bool
)
mask
=
np
.
zeros
(
shape
[
i
],
dtype
=
np
.
bool
)
# type: ignore
for
sl
in
slice_
[
i
]:
for
sl
in
cast
(
List
[
slice
],
slice_
[
i
]
)
:
mask
[
sl
]
=
1
mask
[
sl
]
=
1
result
.
append
(
mask
)
result
.
append
(
mask
)
else
:
else
:
...
@@ -67,7 +69,7 @@ def _eliminate_list_slice(shape: tuple, slice_: multidim_slice) -> multidim_slic
...
@@ -67,7 +69,7 @@ def _eliminate_list_slice(shape: tuple, slice_: multidim_slice) -> multidim_slic
return
tuple
(
result
)
return
tuple
(
result
)
def
_slice_weight
(
weight
:
T
,
slice_
:
Union
[
multidim_slice
,
L
ist
[
T
uple
[
multidim_slice, float]]]) -> T:
+def _slice_weight(weight: T, slice_: multidim_slice | list[tuple[multidim_slice, float]]) -> T:
     # slice_ can be a tuple of slice, e.g., ([3:6], [2:4])
     # or tuple of slice -> float, e.g. {([3:6],): 0.6, ([2:4],): 0.3}
     ...
@@ -84,27 +86,27 @@ def _slice_weight(weight: T, slice_: Union[multidim_slice, List[Tuple[multidim_s
             # create a mask with weight w
             with torch.no_grad():
                 mask = zeros_like(weight)
-                mask[_eliminate_list_slice(weight.shape, sl)] = 1
+                mask[_eliminate_list_slice(weight.shape, sl)] = 1  # type: ignore
             # track gradients here
-            masks.append((mask * wt))
+            masks.append(mask * wt)  # type: ignore
         masks = sum(masks)
-        return masks * weight
+        return masks * weight  # type: ignore
     else:
         # for unweighted case, we slice it directly.
         def _do_slice(arr, slice_):
-            return arr[_eliminate_list_slice(arr.shape, slice_)]
+            return arr[_eliminate_list_slice(arr.shape, slice_)]  # type: ignore
         # sometimes, we don't need slice.
         # this saves an op on computational graph, which will hopefully make training faster
         # Use a dummy array to check this. Otherwise it would be too complex.
-        dummy_arr = np.zeros(weight.shape, dtype=np.bool)
+        dummy_arr = np.zeros(weight.shape, dtype=np.bool)  # type: ignore
-        no_effect = _do_slice(dummy_arr, slice_).shape == dummy_arr.shape
+        no_effect = cast(Any, _do_slice(dummy_arr, slice_)).shape == dummy_arr.shape
         if no_effect:
             return weight
     ...
@@ -128,14 +130,14 @@ class Slicable(Generic[T]):
             raise TypeError(f'Unsuppoted weight type: {type(weight)}')
         self.weight = weight
-    def __getitem__(self, index: Union[slice_type, multidim_slice]) -> T:
+    def __getitem__(self, index: slice_type | multidim_slice) -> T:
         if not isinstance(index, tuple):
             index = (index, )
         index = cast(multidim_slice, index)
         # Get the dict value in index's leafs
         # There can be at most one dict
-        leaf_dict: Optional[Dict[int, float]] = None
+        leaf_dict: dict[int, float] | None = None
         for maybe_weighted in _iterate_over_multidim_slice(index):
             for d in maybe_weighted.leaf_values():
                 if isinstance(d, dict):
     ...
@@ -166,10 +168,10 @@ class MaybeWeighted:
     """
     def __init__(self,
-                 value: Optional[int_or_int_dict] = None, *,
-                 lhs: Optional[Union['MaybeWeighted', int]] = None,
-                 rhs: Optional[Union['MaybeWeighted', int]] = None,
-                 operation: Optional[Callable[[int, int], int]] = None):
+                 value: int_or_int_dict | None = None, *,
+                 lhs: 'MaybeWeighted' | int | None = None,
+                 rhs: 'MaybeWeighted' | int | None = None,
+                 operation: Callable[[int_or_int_dict, int_or_int_dict], int_or_int_dict] | None = None):
         if operation is None:
             if not isinstance(value, (int, dict)):
                 raise TypeError(f'Unsupported value type: {type(value)}')
     ...
@@ -178,7 +180,7 @@ class MaybeWeighted:
         self.rhs = rhs
         self.operation = operation
-    def leaf_values(self) -> Iterator[Dict[int, float]]:
+    def leaf_values(self) -> Iterator[int_or_int_dict]:
         """Iterate over values on leaf nodes."""
         if self.value is not None:
             yield self.value
     ...
@@ -188,7 +190,7 @@ class MaybeWeighted:
         if isinstance(self.rhs, MaybeWeighted):
             yield from self.rhs.leaf_values()
-    def evaluate(self, value_fn: _value_fn_type = None) -> int:
+    def evaluate(self, value_fn: _value_fn_type = None) -> int_or_int_dict:
         """Evaluate the value on root node, after replacing every value on leaf node with ``value_fn``.
         If ``value_fn`` is none, no replacement will happen and the raw value will be used.
         """
     ...
@@ -200,11 +202,12 @@ class MaybeWeighted:
         if isinstance(self.lhs, MaybeWeighted):
             eval_lhs = self.lhs.evaluate(value_fn)
         else:
-            eval_lhs = self.lhs
+            eval_lhs = cast(int, self.lhs)
         if isinstance(self.rhs, MaybeWeighted):
             eval_rhs = self.rhs.evaluate(value_fn)
         else:
-            eval_rhs = self.rhs
+            eval_rhs = cast(int, self.rhs)
+        assert self.operation is not None
         return self.operation(eval_lhs, eval_rhs)
     def __repr__(self):
     ...
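The comments in the weighted branch above describe the trick this file relies on: build a 0/1 mask per candidate slice, scale each mask by that candidate's probability, sum the masks, and multiply the result with the full weight so gradients still reach every candidate. A minimal standalone sketch of that idea (plain torch, names and sizes are illustrative, not the module's actual helpers):

```python
import torch

weight = torch.randn(6, 6, requires_grad=True)            # full super-kernel weight
candidates = {(slice(0, 3),): 0.6, (slice(0, 5),): 0.4}    # slice -> probability (plain floats here)

masks = []
for sl, prob in candidates.items():
    with torch.no_grad():
        mask = torch.zeros_like(weight)
        mask[sl] = 1                                       # 0/1 mask for this candidate slice
    masks.append(mask * prob)                              # weighted mask; in the real code prob is a tensor
effective_weight = sum(masks) * weight                     # weighted mixture of sliced weights
effective_weight.sum().backward()                          # gradients flow back to the full weight
```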
nni/retiarii/oneshot/pytorch/supermodule/_singlepathnas.py
View file @
05c7d6e9
...
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.
 # pylint: skip-file
+# type: ignore
 """This file is an incomplete implementation of `Single-path NAS <https://arxiv.org/abs/1904.02877>`__.
 These are merely some components of the algorithm. The complete support is an undergoing work item.
...
@@ -96,9 +97,9 @@ class DifferentiableSuperConv2d(nn.Conv2d):
     ----------
     input_weight : Tensor
         the weight to be weighted summed
-    masks : List[Tensor]
+    masks : list[Tensor]
         weight masks.
-    thresholds : List[float]
+    thresholds : list[float]
         thresholds, should have a length of ``len(masks) - 1``
     indicator : Callable[[Tensor, float], float]
         take a tensor and a threshold as input, and output the weight
...
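For readers unfamiliar with Single-path NAS, a hypothetical sketch of the weighted sum documented above: the smallest mask is always kept, each larger mask is gated by `indicator(masked_weight, threshold)`, and there is one threshold fewer than there are masks. All names below are illustrative, not this file's actual API.

```python
import torch

def weighted_sum(input_weight, masks, thresholds, indicator):
    # Base kernel is always active; larger masks are gated one after another.
    out = input_weight * masks[0]
    gate = torch.tensor(1.0)
    for mask, threshold in zip(masks[1:], thresholds):
        gate = gate * indicator(input_weight * mask, threshold)
        out = out + gate * (input_weight * mask)
    return out

def sigmoid_indicator(weight, threshold):
    # Soft step on the norm of the candidate weights (one plausible indicator).
    return torch.sigmoid(torch.norm(weight) - threshold)

w = torch.randn(5, 5)
masks = [torch.ones(5, 5), torch.zeros(5, 5)]
masks[1][0, :] = 1  # an extra "ring" of weights, purely for illustration
print(weighted_sum(w, masks, [1.0], sigmoid_indicator).shape)
```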
nni/retiarii/oneshot/pytorch/supermodule/_valuechoice_utils.py
View file @
05c7d6e9
...
@@ -4,19 +4,26 @@
 """Utilities to process the value choice compositions,
 in the way that is most convenient to one-shot algorithms."""
+from __future__ import annotations
+
 import itertools
-from typing import List, Any, Dict, Tuple, Optional, Union
+from typing import Any, TypeVar, List, cast
+
+import numpy as np
+import torch

 from nni.common.hpo_utils import ParameterSpec
-from nni.retiarii.nn.pytorch.api import ValueChoiceX
+from nni.retiarii.nn.pytorch.api import ChoiceOf, ValueChoiceX

 Choice = Any
+T = TypeVar('T')

 __all__ = ['dedup_inner_choices', 'evaluate_value_choice_with_dict', 'traverse_all_options']

-def dedup_inner_choices(value_choices: List[ValueChoiceX]) -> Dict[str, ParameterSpec]:
+def dedup_inner_choices(value_choices: list[ValueChoiceX]) -> dict[str, ParameterSpec]:
     """Find all leaf nodes in ``value_choices``,
     save them into in the format of ``{label: parameter_spec}``.
     """
...
@@ -33,7 +40,7 @@ def dedup_inner_choices(value_choices: List[ValueChoiceX]) -> Dict[str, Paramete
     return result

-def evaluate_value_choice_with_dict(value_choice: ValueChoiceX, chosen: Dict[str, Choice]) -> Any:
+def evaluate_value_choice_with_dict(value_choice: ChoiceOf[T], chosen: dict[str, Choice]) -> T:
     """To evaluate a composition of value-choice with a dict,
     with format of ``{label: chosen_value}``.
     The implementation is two-pass. We first get a list of values,
...
@@ -56,8 +63,10 @@ def evaluate_value_choice_with_dict(value_choice: ValueChoiceX, chosen: Dict[str
     return value_choice.evaluate(choice_inner_values)

-def traverse_all_options(value_choice: ValueChoiceX,
-                         weights: Optional[Dict[str, List[float]]] = None) -> List[Union[Tuple[Any, float], Any]]:
+def traverse_all_options(
+    value_choice: ChoiceOf[T],
+    weights: dict[str, list[float]] | dict[str, np.ndarray] | dict[str, torch.Tensor] | None = None
+) -> list[tuple[T, float]] | list[T]:
     """Traverse all possible computation outcome of a value choice.
     If ``weights`` is not None, it will also compute the probability of each possible outcome.
...
@@ -65,33 +74,33 @@ def traverse_all_options(value_choice: ValueChoiceX,
     ----------
     value_choice : ValueChoiceX
         The value choice to traverse.
-    weights : Optional[Dict[str, List[float]]], default = None
+    weights : Optional[dict[str, list[float]]], default = None
         If there's a prior on leaf nodes, and we intend to know the (joint) prior on results,
         weights can be provided. The key is label, value are list of float indicating probability.
         Normally, they should sum up to 1, but we will not check them in this function.

     Returns
     -------
-    List[Union[Tuple[Any, float], Any]]
+    list[Union[tuple[Any, float], Any]]
         Results will be sorted and duplicates will be eliminated.
         If weights is provided, the return value will be a list of tuple, with option and its weight.
         Otherwise, it will be a list of options.
     """
     # get a dict of {label: list of tuple of choice and weight}
-    leafs: Dict[str, List[Tuple[Choice, float]]] = {}
+    leafs: dict[str, list[tuple[T, float]]] = {}
     for label, param_spec in dedup_inner_choices([value_choice]).items():
         if weights is not None:
             if label not in weights:
                 raise KeyError(f'{value_choice} depends on a weight with key {label}, but not found in {weights}')
             if len(weights[label]) != param_spec.size:
                 raise KeyError(f'Expect weights with {label} to be of length {param_spec.size}, but {len(weights[label])} found')
-            leafs[label] = list(zip(param_spec.values, weights[label]))
+            leafs[label] = list(zip(param_spec.values, cast(List[float], weights[label])))
         else:
             # create a dummy weight of zero, in case that weights are not provided.
             leafs[label] = list(zip(param_spec.values, itertools.repeat(0., param_spec.size)))

     # result is a dict from a option to its weight
-    result: Dict[str, Optional[float]] = {}
+    result: dict[T, float | None] = {}
     labels, values = list(leafs.keys()), list(leafs.values())
     if not labels:
...
@@ -126,6 +135,6 @@ def traverse_all_options(value_choice: ValueChoiceX,
         result[eval_res] = chosen_weight
     if weights is None:
-        return sorted(result.keys())
+        return sorted(result.keys())  # type: ignore
     else:
-        return sorted(result.items())
+        return sorted(result.items())  # type: ignore
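A rough usage sketch of `traverse_all_options`, following its docstring above. The import path mirrors the file touched in this commit; the arithmetic composition of `ValueChoice` is assumed to behave as the docstring describes, so treat the exact outputs as illustrative.

```python
from nni.retiarii.nn.pytorch import ValueChoice
from nni.retiarii.oneshot.pytorch.supermodule._valuechoice_utils import traverse_all_options

# A composed value choice: either 16*2 or 32*2.
channels = ValueChoice([16, 32], label='c') * 2

print(traverse_all_options(channels))                                # [32, 64]
print(traverse_all_options(channels, weights={'c': [0.25, 0.75]}))   # [(32, 0.25), (64, 0.75)]
```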
nni/retiarii/oneshot/pytorch/supermodule/base.py
View file @
05c7d6e9
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+from __future__ import annotations
+
-from typing import Any, Dict, Tuple, Union
+from typing import Any

 import torch.nn as nn
...
@@ -24,13 +26,13 @@ class BaseSuperNetModule(nn.Module):
     rather than their compositions.
     """
-    def resample(self, memo: Dict[str, Any]) -> Dict[str, Any]:
+    def resample(self, memo: dict[str, Any]) -> dict[str, Any]:
         """
         Resample the super-net module.

         Parameters
         ----------
-        memo : Dict[str, Any]
+        memo : dict[str, Any]
             Used to ensure the consistency of samples with the same label.

         Returns
...
@@ -40,19 +42,19 @@ class BaseSuperNetModule(nn.Module):
         """
         raise NotImplementedError()
-    def export(self, memo: Dict[str, Any]) -> Dict[str, Any]:
+    def export(self, memo: dict[str, Any]) -> dict[str, Any]:
         """
         Export the final architecture within this module.
         It should have the same keys as ``search_space_spec()``.

         Parameters
         ----------
-        memo : Dict[str, Any]
+        memo : dict[str, Any]
             Use memo to avoid the same label gets exported multiple times.
         """
         raise NotImplementedError()
-    def search_space_spec(self) -> Dict[str, ParameterSpec]:
+    def search_space_spec(self) -> dict[str, ParameterSpec]:
         """
         Space specification (sample points).
         Mapping from spec name to ParameterSpec. The names in choices should be in the same format of export.
...
@@ -64,8 +66,8 @@ class BaseSuperNetModule(nn.Module):
         raise NotImplementedError()
     @classmethod
-    def mutate(cls, module: nn.Module, name: str, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> \
-        Union['BaseSuperNetModule', bool, Tuple['BaseSuperNetModule', bool]]:
+    def mutate(cls, module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> \
+        'BaseSuperNetModule' | bool | tuple['BaseSuperNetModule', bool]:
         """This is a mutation hook that creates a :class:`BaseSuperNetModule`.
         The method should be implemented in each specific super-net module,
         because they usually have specific rules about what kind of modules to operate on.
...
@@ -84,7 +86,7 @@ class BaseSuperNetModule(nn.Module):
         Returns
         -------
-        Union[BaseSuperNetModule, bool, Tuple[BaseSuperNetModule, bool]]
+        Union[BaseSuperNetModule, bool, tuple[BaseSuperNetModule, bool]]
             The mutation result, along with an optional boolean flag indicating whether to suppress follow-up mutation hooks.
             See :class:`nni.retiarii.oneshot.pytorch.base.BaseOneShotLightningModule` for details.
         """
...
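The docstrings above spell out the `memo` contract: it maps a label to a decision, so two super-net modules sharing a label end up with the same sample, and export only reports a label once. A toy illustration of that contract (not the real class, and not tied to `ParameterSpec`):

```python
import random

class ToyChoice:
    def __init__(self, op_names, label):
        self.op_names, self.label = op_names, label
        self._sampled = None

    def resample(self, memo):
        # Reuse the memoized decision if the label was already sampled.
        if self.label not in memo:
            memo[self.label] = random.choice(self.op_names)
        self._sampled = memo[self.label]
        return {self.label: self._sampled}

    def export(self, memo):
        # Report nothing new if the label was already exported.
        return {} if self.label in memo else {self.label: self._sampled}

memo = {}
a = ToyChoice(['conv3x3', 'conv5x5'], label='cell')
b = ToyChoice(['conv3x3', 'conv5x5'], label='cell')
a.resample(memo)
b.resample(memo)
assert a._sampled == b._sampled  # same label -> same decision
```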
nni/retiarii/oneshot/pytorch/supermodule/differentiable.py
View file @
05c7d6e9
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+from __future__ import annotations
+
 import functools
 import warnings
-from typing import List, Tuple, Optional, Dict, Any, Union
+from typing import Any, cast

 import torch
 import torch.nn as nn
...
@@ -21,7 +23,9 @@ from ._valuechoice_utils import traverse_all_options

 class GumbelSoftmax(nn.Softmax):
     """Wrapper of ``F.gumbel_softmax``. dim = -1 by default."""
-    def __init__(self, dim: Optional[int] = -1) -> None:
+    dim: int
+
+    def __init__(self, dim: int = -1) -> None:
         super().__init__(dim)
         self.tau = 1
         self.hard = False
...
@@ -42,7 +46,7 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
     Parameters
     ----------
-    paths : List[Tuple[str, nn.Module]]
+    paths : list[tuple[str, nn.Module]]
         Layers to choose from. Each is a tuple of name, and its module.
     alpha : Tensor
         Tensor that stores the "learnable" weights.
...
@@ -59,9 +63,9 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
         Name of the choice.
     """
-    _arch_parameter_names: List[str] = ['_arch_alpha']
+    _arch_parameter_names: list[str] = ['_arch_alpha']
-    def __init__(self, paths: List[Tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
+    def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
         super().__init__()
         self.op_names = []
         if len(alpha) != len(paths):
...
@@ -82,7 +86,7 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
         """Choose the operator with the maximum logit."""
         if self.label in memo:
             return {}  # nothing new to export
-        return {self.label: self.op_names[torch.argmax(self._arch_alpha).item()]}
+        return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}
     def search_space_spec(self):
         return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
...
@@ -149,9 +153,9 @@ class DifferentiableMixedInput(BaseSuperNetModule):
         Name of the choice.
     """
-    _arch_parameter_names: List[str] = ['_arch_alpha']
+    _arch_parameter_names: list[str] = ['_arch_alpha']
-    def __init__(self, n_candidates: int, n_chosen: Optional[int], alpha: torch.Tensor, softmax: nn.Module, label: str):
+    def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
         super().__init__()
         self.n_candidates = n_candidates
         if len(alpha) != n_candidates:
...
@@ -240,9 +244,9 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
     won't be optimized.
     """
-    _arch_parameter_names: List[str] = ['_arch_alpha']
+    _arch_parameter_names: list[str] = ['_arch_alpha']
-    def __init__(self, operation: MixedOperation, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> None:
+    def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
         # Sampling arguments. This should have the same keys with `operation.mutable_arguments`
         operation._arch_alpha = nn.ParameterDict()
         for name, spec in operation.search_space_spec().items():
...
@@ -254,20 +258,20 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
             alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
             operation._arch_alpha[name] = alpha
-        operation.parameters = functools.partial(self.parameters, self=operation)            # bind self
-        operation.named_parameters = functools.partial(self.named_parameters, self=operation)
+        operation.parameters = functools.partial(self.parameters, module=operation)          # bind self
+        operation.named_parameters = functools.partial(self.named_parameters, module=operation)
         operation._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
     @staticmethod
-    def parameters(self, *args, **kwargs):
-        for _, p in self.named_parameters(*args, **kwargs):
+    def parameters(module, *args, **kwargs):
+        for _, p in module.named_parameters(*args, **kwargs):
             yield p
     @staticmethod
-    def named_parameters(self, *args, **kwargs):
+    def named_parameters(module, *args, **kwargs):
         arch = kwargs.pop('arch', False)
-        for name, p in super(self.__class__, self).named_parameters(*args, **kwargs):  # pylint: disable=bad-super-call
+        for name, p in super(module.__class__, module).named_parameters(*args, **kwargs):  # pylint: disable=bad-super-call
             if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
                 if arch:
                     yield name, p
...
@@ -275,22 +279,24 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
                 if not arch:
                     yield name, p
-    def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
+    def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
         """Differentiable. Do nothing in resample."""
         return {}
-    def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
+    def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
         """Export is also random for each leaf value choice."""
         result = {}
         for name, spec in operation.search_space_spec().items():
             if name in result:
                 continue
-            chosen_index = torch.argmax(operation._arch_alpha[name]).item()
+            chosen_index = int(torch.argmax(cast(dict, operation._arch_alpha)[name]).item())
             result[name] = spec.values[chosen_index]
         return result
-    def forward_argument(self, operation: MixedOperation, name: str) -> Union[Dict[Any, float], Any]:
+    def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any:
         if name in operation.mutable_arguments:
-            weights = {label: operation._softmax(alpha) for label, alpha in operation._arch_alpha.items()}
+            weights: dict[str, torch.Tensor] = {
+                label: cast(nn.Module, operation._softmax)(alpha)
+                for label, alpha in cast(dict, operation._arch_alpha).items()
+            }
             return dict(traverse_all_options(operation.mutable_arguments[name], weights=weights))
         return operation.init_arguments[name]
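One design choice worth calling out: `parameters` and `named_parameters` are plain `@staticmethod`s that get attached to an existing module instance via `functools.partial` (the rename from `self=` to `module=` above avoids the keyword clash). That lets the bound `named_parameters` separate architecture parameters from ordinary weights. A standalone sketch of the same binding trick, with illustrative names:

```python
import functools
import torch
import torch.nn as nn

def filtered_named_parameters(module, *args, arch=False, **kwargs):
    # Call nn.Module's own implementation, then filter by parameter name prefix.
    for name, p in nn.Module.named_parameters(module, *args, **kwargs):
        if name.startswith('_arch_') == arch:
            yield name, p

layer = nn.Linear(4, 4)
layer._arch_alpha = nn.Parameter(torch.randn(3))      # architecture parameter
layer.named_parameters = functools.partial(filtered_named_parameters, layer)

print([n for n, _ in layer.named_parameters()])              # ['weight', 'bias']
print([n for n, _ in layer.named_parameters(arch=True)])     # ['_arch_alpha']
```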
nni/retiarii/oneshot/pytorch/supermodule/operation.py
View file @
05c7d6e9
...
@@ -6,9 +6,11 @@ Operations that support weight sharing at a fine-grained level,
 which is commonly known as super-kernel (as in channel search), or weight entanglement.
 """
+from __future__ import annotations
+
 import inspect
 import itertools
-from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar, cast
+from typing import Any, Type, TypeVar, cast, Union, Tuple

 import torch
 import torch.nn as nn
...
@@ -37,7 +39,7 @@ class MixedOperationSamplingPolicy:
     One SamplingStrategy corresponds to one mixed operation.
     """
-    def __init__(self, operation: 'MixedOperation', memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> None:
+    def __init__(self, operation: 'MixedOperation', memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
         """At init, the sampling policy can prepare basic parameters,
         and store them in operation if they need back propagation.
...
@@ -47,11 +49,11 @@ class MixedOperationSamplingPolicy:
         """
         pass
-    def resample(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]:
+    def resample(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
         """The handler of :meth:`MixedOperation.resample`."""
         raise NotImplementedError()
-    def export(self, operation: 'MixedOperation', memo: Dict[str, Any]) -> Dict[str, Any]:
+    def export(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
         """The handler of :meth:`MixedOperation.export`."""
         raise NotImplementedError()
...
@@ -90,7 +92,7 @@ class MixedOperation(BaseSuperNetModule):
     """
     bound_type: Type[nn.Module]                 # defined in subclass
-    argument_list: List[str]                    # defined in subclass
+    argument_list: list[str]                    # defined in subclass
     sampling_policy: MixedOperationSamplingPolicy
...
@@ -114,11 +116,11 @@ class MixedOperation(BaseSuperNetModule):
         appended by forward arguments in the ``bound_type``."""
         raise NotImplementedError()
-    def __init__(self, module_kwargs: Dict[str, Any]) -> None:
+    def __init__(self, module_kwargs: dict[str, Any]) -> None:
         # Concerned arguments
-        self.mutable_arguments: Dict[str, ValueChoiceX] = {}
+        self.mutable_arguments: dict[str, ValueChoiceX] = {}
         # Useful when retrieving arguments without ValueChoice
-        self.init_arguments: Dict[str, Any] = {**module_kwargs}
+        self.init_arguments: dict[str, Any] = {**module_kwargs}
         self._fill_missing_init_arguments()
         # get init default
...
@@ -134,7 +136,7 @@ class MixedOperation(BaseSuperNetModule):
             super_init_kwargs[key] = value
         # get all inner leaf value choices
-        self._space_spec: Dict[str, ParameterSpec] = dedup_inner_choices(self.mutable_arguments.values())
+        self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices(list(self.mutable_arguments.values()))
         super().__init__(**super_init_kwargs)
...
@@ -156,7 +158,7 @@ class MixedOperation(BaseSuperNetModule):
         """Find value choice in module's arguments and replace the whole module"""
         has_valuechoice = False
         if isinstance(module, cls.bound_type) and is_traceable(module):
-            for arg in itertools.chain(module.trace_args, module.trace_kwargs.values()):
+            for arg in itertools.chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()):
                 if isinstance(arg, ValueChoiceX):
                     has_valuechoice = True
...
@@ -166,7 +168,7 @@ class MixedOperation(BaseSuperNetModule):
                                  'Please enable ``kw_only`` on nni.trace.')
             # save type and kwargs
-            mixed_op = cls(module.trace_kwargs)
+            mixed_op = cls(cast(dict, module.trace_kwargs))
             if 'mixed_op_sampling' not in mutate_kwargs:
                 raise ValueError('Need to sampling policy of mixed op, but not found in `mutate_kwargs`.')
...
@@ -229,15 +231,15 @@ class MixedLinear(MixedOperation, nn.Linear):
                           out_features: int_or_int_dict,
                           inputs: torch.Tensor) -> torch.Tensor:
-        in_features = _W(in_features)
-        out_features = _W(out_features)
+        in_features_ = _W(in_features)
+        out_features_ = _W(out_features)
-        weight = _S(self.weight)[:out_features]
-        weight = _S(weight)[:, :in_features]
+        weight = _S(self.weight)[:out_features_]
+        weight = _S(weight)[:, :in_features_]
         if self.bias is None:
             bias = self.bias
         else:
-            bias = _S(self.bias)[:out_features]
+            bias = _S(self.bias)[:out_features_]
         return F.linear(inputs, weight, bias)
...
@@ -278,7 +280,7 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
     ]
     @staticmethod
-    def _to_tuple(value: scalar_or_scalar_dict[T]) -> Tuple[T, T]:
+    def _to_tuple(value: scalar_or_scalar_dict[Any]) -> tuple[Any, Any]:
         if not isinstance(value, tuple):
             return (value, value)
         return value
...
@@ -318,33 +320,37 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
         if any(isinstance(arg, dict) for arg in [stride, dilation, groups]):
             raise ValueError('stride, dilation, groups does not support weighted sampling.')
-        in_channels = _W(in_channels)
-        out_channels = _W(out_channels)
+        in_channels_ = _W(in_channels)
+        out_channels_ = _W(out_channels)
         # slice prefix
         # For groups > 1, we use groups to slice input weights
-        weight = _S(self.weight)[:out_channels]
-        weight = _S(weight)[:, :in_channels // groups]
+        weight = _S(self.weight)[:out_channels_]
+        weight = _S(weight)[:, :in_channels_ // groups]
         # slice center
         if isinstance(kernel_size, dict):
+            # If kernel size is a dict, ignore choices in padding.
+            if isinstance(self.padding, str):
+                raise ValueError(f'Use "{self.padding}" in padding is not supported.')
             padding = self.padding  # max padding, must be a tuple
             kernel_a, kernel_b = self._to_tuple(kernel_size)
-            kernel_a, kernel_b = _W(kernel_a), _W(kernel_b)
+            kernel_a_, kernel_b_ = _W(kernel_a), _W(kernel_b)
             max_kernel_a, max_kernel_b = self.kernel_size  # self.kernel_size must be a tuple
-            kernel_a_left, kernel_b_top = (max_kernel_a - kernel_a) // 2, (max_kernel_b - kernel_b) // 2
-            weight = _S(weight)[:, :, kernel_a_left:kernel_a_left + kernel_a, kernel_b_top:kernel_b_top + kernel_b]
+            kernel_a_left, kernel_b_top = (max_kernel_a - kernel_a_) // 2, (max_kernel_b - kernel_b_) // 2
+            weight = _S(weight)[:, :, kernel_a_left:kernel_a_left + kernel_a_, kernel_b_top:kernel_b_top + kernel_b_]
-        bias = _S(self.bias)[:out_channels] if self.bias is not None else None
+        bias = _S(self.bias)[:out_channels_] if self.bias is not None else None
         # The rest parameters only need to be converted to tuple
-        stride = self._to_tuple(stride)
-        dilation = self._to_tuple(dilation)
+        stride_ = self._to_tuple(stride)
+        dilation_ = self._to_tuple(dilation)
         if self.padding_mode != 'zeros':
             return F.conv2d(F.pad(inputs, self._reversed_padding_repeated_twice, mode=self.padding_mode),
-                            weight, bias, stride, (0, 0), dilation, groups)
+                            weight, bias, stride_, (0, 0), dilation_, groups)
-        return F.conv2d(inputs, weight, bias, stride, padding, dilation, groups)
+        return F.conv2d(inputs, weight, bias, stride_, cast('int | tuple', padding), dilation_, groups)

 class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
...
@@ -388,13 +394,15 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
         if num_features < self.num_features:
             weight = weight[:num_features]
             bias = bias[:num_features]
+            if running_mean is not None:
                 running_mean = running_mean[:num_features]
+            if running_var is not None:
                 running_var = running_var[:num_features]
         if self.training:
             bn_training = True
         else:
-            bn_training = (self.running_mean is None) and (self.running_var is None)
+            bn_training = (running_mean is None) and (running_var is None)
         return F.batch_norm(
             inputs,
...
@@ -473,7 +481,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
     def super_init_argument(self, name: str, value_choice: ValueChoiceX):
         return max(traverse_all_options(value_choice))
-    def _to_proj_slice(self, embed_dim: _W) -> List[slice]:
+    def _to_proj_slice(self, embed_dim: _W) -> list[slice]:
         # slice three parts, corresponding to q, k, v respectively
         return [
             slice(embed_dim),
...
@@ -484,12 +492,12 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
     def forward_with_args(
         self,
         embed_dim: int_or_int_dict, num_heads: int,
-        kdim: Optional[int_or_int_dict], vdim: Optional[int_or_int_dict],
+        kdim: int_or_int_dict | None, vdim: int_or_int_dict | None,
         dropout: float,
         query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
-        key_padding_mask: Optional[torch.Tensor] = None,
-        need_weights: bool = True, attn_mask: Optional[torch.Tensor] = None
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        key_padding_mask: torch.Tensor | None = None,
+        need_weights: bool = True, attn_mask: torch.Tensor | None = None
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
         if any(isinstance(arg, dict) for arg in [num_heads, dropout]):
             raise ValueError('num_heads, dropout do not support weighted sampling.')
...
@@ -511,26 +519,26 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
         else:
             used_embed_dim = embed_dim
-        embed_dim = _W(embed_dim)
+        embed_dim_ = _W(embed_dim)
         # in projection weights & biases has q, k, v weights concatenated together
-        in_proj_bias: Optional[Tensor] = None
-        in_proj_weight: Optional[Tensor] = None
+        in_proj_bias: Tensor | None = None
+        in_proj_weight: Tensor | None = None
         if self.in_proj_bias is not None:
-            in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim)]
+            in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim_)]
         if self.in_proj_weight is not None:
-            in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim), :embed_dim]
+            in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim_), :embed_dim_]
-        bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim] if self.bias_k is not None else None
-        bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim] if self.bias_v is not None else None
-        out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim, :embed_dim]
-        out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim] if self.out_proj.bias is not None else None
+        bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim_] if self.bias_k is not None else None
+        bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim_] if self.bias_v is not None else None
+        out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim_, :embed_dim_]
+        out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim_] if self.out_proj.bias is not None else None
         if not qkv_same_embed_dim:
-            q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim, :embed_dim]
-            k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim]
+            q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim_, :embed_dim_]
+            k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim_]
             k_proj = _S(k_proj)[:, :_W(kdim)]
-            v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim]
+            v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim_]
             v_proj = _S(v_proj)[:, :_W(vdim)]
             # The rest part is basically same as pytorch
...
@@ -560,7 +568,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
         return attn_output, attn_output_weights

-NATIVE_MIXED_OPERATIONS: List[Type[MixedOperation]] = [
+NATIVE_MIXED_OPERATIONS: list[Type[MixedOperation]] = [
     MixedLinear,
     MixedConv2d,
     MixedBatchNorm2d,
...
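The `MixedLinear` hunk above shows the super-kernel pattern these mixed operations share: one large weight is stored, and a smaller operator is emulated at forward time by slicing its leading rows and columns. A minimal sketch of that pattern in plain torch (sizes are illustrative; `_S`/`_W` from the real code are replaced by direct slicing):

```python
import torch
import torch.nn.functional as F

max_in, max_out = 8, 16
weight = torch.randn(max_out, max_in, requires_grad=True)   # shared super-kernel
bias = torch.randn(max_out, requires_grad=True)

def sliced_linear(x, in_features, out_features):
    # Emulate nn.Linear(in_features, out_features) by slicing the big weight.
    w = weight[:out_features, :in_features]
    b = bias[:out_features]
    return F.linear(x, w, b)

x = torch.randn(2, 4)
y = sliced_linear(x, in_features=4, out_features=10)
print(y.shape)   # torch.Size([2, 10])
```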
nni/retiarii/oneshot/pytorch/supermodule/proxyless.py
View file @
05c7d6e9
...
@@ -9,7 +9,9 @@ The support remains limited. Known limitations include:
 - The code contains duplicates. Needs refactor.
 """
+from __future__ import annotations
+
-from typing import List, Tuple, Optional, cast
+from typing import cast

 import torch
 import torch.nn as nn
...
@@ -48,13 +50,13 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
     _arch_parameter_names = ['_arch_alpha', '_binary_gates']
-    def __init__(self, paths: List[Tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
+    def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
         super().__init__(paths, alpha, softmax, label)
         self._binary_gates = nn.Parameter(torch.randn(len(paths)) * 1E-3)
         # like sampling-based methods, it has a ``_sampled``.
-        self._sampled: Optional[str] = None
-        self._sample_idx: Optional[int] = None
+        self._sampled: str | None = None
+        self._sample_idx: int | None = None
     def forward(self, *args, **kwargs):
         def run_function(ops, active_id, **kwargs):
...
@@ -130,10 +132,10 @@ class ProxylessMixedInput(DifferentiableMixedInput):
     _arch_parameter_names = ['_arch_alpha', '_binary_gates']
-    def __init__(self, n_candidates: int, n_chosen: Optional[int], alpha: torch.Tensor, softmax: nn.Module, label: str):
+    def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
         super().__init__(n_candidates, n_chosen, alpha, softmax, label)
         self._binary_gates = nn.Parameter(torch.randn(n_candidates) * 1E-3)
-        self._sampled: Optional[int] = None
+        self._sampled: int | None = None
     def forward(self, inputs):
         def run_function(active_sample):
...
nni/retiarii/oneshot/pytorch/supermodule/sampling.py
View file @
05c7d6e9
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+from __future__ import annotations
+
 import random
-from typing import Optional, List, Tuple, Union, Dict, Any
+from typing import Any

 import torch
 import torch.nn as nn
...
@@ -28,14 +30,14 @@ class PathSamplingLayer(BaseSuperNetModule):
         Name of the choice.
     """
-    def __init__(self, paths: List[Tuple[str, nn.Module]], label: str):
+    def __init__(self, paths: list[tuple[str, nn.Module]], label: str):
         super().__init__()
         self.op_names = []
         for name, module in paths:
             self.add_module(name, module)
             self.op_names.append(name)
         assert self.op_names, 'There has to be at least one op to choose from.'
-        self._sampled: Optional[Union[List[str], str]] = None  # sampled can be either a list of indices or an index
+        self._sampled: list[str] | str | None = None  # sampled can be either a list of indices or an index
         self.label = label
     def resample(self, memo):
...
@@ -89,7 +91,7 @@ class PathSamplingInput(BaseSuperNetModule):
         self.n_candidates = n_candidates
         self.n_chosen = n_chosen
         self.reduction = reduction
-        self._sampled: Optional[Union[List[int], int]] = None
+        self._sampled: list[int] | int | None = None
         self.label = label
     def _random_choose_n(self):
...
@@ -159,11 +161,11 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
     We sample the leaf nodes, and composits them into the values on arguments.
     """
-    def __init__(self, operation: MixedOperation, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> None:
+    def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
         # Sampling arguments. This should have the same keys with `operation.mutable_arguments`
-        self._sampled: Optional[Dict[str, Any]] = None
+        self._sampled: dict[str, Any] | None = None
-    def resample(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
+    def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
         """Random sample for each leaf value choice."""
         result = {}
         space_spec = operation.search_space_spec()
...
@@ -181,7 +183,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
         return result
-    def export(self, operation: MixedOperation, memo: Dict[str, Any]) -> Dict[str, Any]:
+    def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
         """Export is also random for each leaf value choice."""
         result = {}
         space_spec = operation.search_space_spec()
...
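A toy sketch of the path-sampling `resample` behaviour documented above: draw one random value per leaf label, consulting the shared `memo` first so repeated labels stay consistent. The space structure is simplified to a plain dict; labels and candidates are illustrative.

```python
import random

space_spec = {'ksize': [3, 5, 7], 'width': [16, 32]}

def resample(memo):
    sampled = {}
    for label, candidates in space_spec.items():
        if label in memo:
            sampled[label] = memo[label]
        else:
            sampled[label] = memo[label] = random.choice(candidates)
    return sampled

memo = {}
print(resample(memo))   # e.g. {'ksize': 5, 'width': 32}
print(resample(memo))   # identical: memo keeps the decisions consistent
```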
nni/retiarii/oneshot/pytorch/utils.py
View file @
05c7d6e9
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
+from __future__ import annotations
+
 import logging
 from collections import OrderedDict
+from typing import cast

 import numpy as np
 import torch
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset

 import nni.retiarii.nn.pytorch as nn
 from nni.nas.pytorch.mutables import InputChoice, LayerChoice
...
@@ -155,7 +159,7 @@ def replace_layer_choice(root_module, init_fn, modules=None):
     Returns
     -------
-    List[Tuple[str, nn.Module]]
+    list[tuple[str, nn.Module]]
         A list from layer choice keys (names) and replaced modules.
     """
     return _replace_module_with_type(root_module, init_fn, (LayerChoice, nn.LayerChoice), modules)
...
@@ -176,7 +180,7 @@ def replace_input_choice(root_module, init_fn, modules=None):
     Returns
     -------
-    List[Tuple[str, nn.Module]]
+    list[tuple[str, nn.Module]]
         A list from layer choice keys (names) and replaced modules.
     """
     return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules)
...
@@ -200,15 +204,19 @@ class InterleavedTrainValDataLoader(DataLoader):
     Example
     --------
     Fit your dataloaders into a parallel one.

     >>> para_loader = InterleavedTrainValDataLoader(train_dataloader, val_dataloader)

     Then you can use the ``para_loader`` as a normal training loader.
     """
-    def __init__(self, train_dataloader, val_dataloader):
-        self.train_loader = train_dataloader
-        self.val_loader = val_dataloader
+    def __init__(self, train_dataloader: DataLoader, val_dataloader: DataLoader | list[DataLoader]):
+        if isinstance(val_dataloader, list):
+            raise TypeError('Validation dataloader of type list is not supported.')
+        self.train_loader: DataLoader = train_dataloader
+        self.val_loader: DataLoader = val_dataloader
         self.equal_len = len(train_dataloader) == len(val_dataloader)
         self.train_longer = len(train_dataloader) > len(val_dataloader)
-        super().__init__(None)
+        super().__init__(cast(Dataset, None))
     def __iter__(self):
         self.train_iter = iter(self.train_loader)
...
@@ -268,13 +276,17 @@ class ConcatenateTrainValDataLoader(DataLoader):
     Example
     --------
     Fit your dataloaders into a concatenated one.

     >>> concat_loader = ConcatenateTrainValDataLoader(train_dataloader, val_datalodaer)

     Then you can use the ``concat_loader`` as a normal training loader.
     """
-    def __init__(self, train_dataloader, val_dataloader):
-        self.train_loader = train_dataloader
-        self.val_loader = val_dataloader
-        super().__init__(None)
+    def __init__(self, train_dataloader: DataLoader, val_dataloader: DataLoader | list[DataLoader]):
+        if isinstance(val_dataloader, list):
+            raise TypeError('Validation dataloader of type list is not supported.')
+        self.train_loader: DataLoader = train_dataloader
+        self.val_loader: DataLoader = val_dataloader
+        super().__init__(cast(Dataset, None))
     def __iter__(self):
         self.cur_iter = iter(self.train_loader)
...
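A rough sketch of what the interleaved loader yields: batches alternate between the training and validation iterators so one-shot algorithms can update model weights and architecture parameters within the same loop. Simplified to plain lists; the real class wraps torch `DataLoader`s and its handling of unequal lengths may differ from this sketch.

```python
from itertools import zip_longest

train_batches = ['t1', 't2', 't3']
val_batches = ['v1', 'v2']

def interleave(train, val):
    # Alternate train/val batches, continuing with whichever side is longer.
    for t, v in zip_longest(train, val):
        if t is not None:
            yield t
        if v is not None:
            yield v

print(list(interleave(train_batches, val_batches)))  # ['t1', 'v1', 't2', 'v2', 't3']
```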
pyrightconfig.json
View file @
05c7d6e9
...
@@ -14,7 +14,6 @@
     "nni/retiarii/execution/cgo_engine.py",
     "nni/retiarii/execution/logical_optimizer",
     "nni/retiarii/evaluator/pytorch/cgo",
-    "nni/retiarii/oneshot",
     "nni/smartparam.py",
     "nni/tools/annotation",
     "nni/tools/gpu_tool",
...
test/ut/retiarii/test_lightning_trainer.py
View file @
05c7d6e9
...
@@ -130,13 +130,16 @@ def test_fit_api():
     transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
     train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
     test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform)
-    lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
-                                  val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
-                                  max_epochs=1, limit_train_batches=0.1,  # for faster training
-                                  progress_bar_refresh_rate=progress_bar_refresh_rate)
-    lightning.fit(lambda: MNISTModel())
-    lightning.fit(MNISTModel)
-    lightning.fit(MNISTModel())
+    def lightning(): return pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
+                                               val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
+                                               max_epochs=1, limit_train_batches=0.1,  # for faster training
+                                               progress_bar_refresh_rate=progress_bar_refresh_rate)
+    # Lightning will have some cache in models / trainers,
+    # which is problematic if we call fit multiple times.
+    lightning().fit(lambda: MNISTModel())
+    lightning().fit(MNISTModel)
+    lightning().fit(MNISTModel())
     _reset()
...
test/ut/retiarii/test_oneshot.py
View file @
05c7d6e9
...
@@ -12,6 +12,7 @@ from nni.retiarii import strategy, model_wrapper, basic_unit
 from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
 from nni.retiarii.evaluator.pytorch.lightning import Classification, Regression, DataLoader
 from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice
+from nni.retiarii.strategy import BaseStrategy

 class DepthwiseSeparableConv(nn.Module):
...
@@ -237,8 +238,12 @@ def _test_strategy(strategy_, support_value_choice=True):
     ]
     for (base_model, evaluator), support_or_not in to_test:
-        print('Testing:', type(strategy_).__name__, type(base_model).__name__, type(evaluator).__name__, support_or_not)
-        experiment = RetiariiExperiment(base_model, evaluator, strategy=strategy_)
+        if isinstance(strategy_, BaseStrategy):
+            strategy = strategy_
+        else:
+            strategy = strategy_(base_model, evaluator)
+        print('Testing:', type(strategy).__name__, type(base_model).__name__, type(evaluator).__name__, support_or_not)
+        experiment = RetiariiExperiment(base_model, evaluator, strategy=strategy)
         config = RetiariiExeConfig()
         config.execution_engine = 'oneshot'
...
@@ -263,7 +268,12 @@ def test_proxyless():

 @pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
 def test_enas():
-    _test_strategy(strategy.ENAS())
+    def strategy_fn(base_model, evaluator):
+        if isinstance(base_model, MultiHeadAttentionNet):
+            return strategy.ENAS(reward_metric_name='val_mse')
+        return strategy.ENAS(reward_metric_name='val_acc')
+    _test_strategy(strategy_fn)

 @pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
...