Unverified commit 14d2966b, authored Mar 29, 2022 by Frandium, committed by GitHub on Mar 29, 2022.
Parent: 5b7dac5c

Valuechoice oneshot lightning (#4602)

Showing 19 changed files with 2878 additions and 518 deletions (+2878 −518).
Changed files:

  nni/common/hpo_utils/formatting.py                              +2    −0
  nni/retiarii/oneshot/pytorch/__init__.py                        +2    −2
  nni/retiarii/oneshot/pytorch/base_lightning.py                  +219  −86
  nni/retiarii/oneshot/pytorch/differentiable.py                  +80   −276
  nni/retiarii/oneshot/pytorch/sampling.py                        +81   −77
  nni/retiarii/oneshot/pytorch/strategy.py                        +15   −13
  nni/retiarii/oneshot/pytorch/supermodule/__init__.py            +0    −0
  nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py    +279  −0
  nni/retiarii/oneshot/pytorch/supermodule/_singlepathnas.py      +261  −0
  nni/retiarii/oneshot/pytorch/supermodule/_valuechoice_utils.py  +131  −0
  nni/retiarii/oneshot/pytorch/supermodule/base.py                +91   −0
  nni/retiarii/oneshot/pytorch/supermodule/differentiable.py      +296  −0
  nni/retiarii/oneshot/pytorch/supermodule/operation.py           +568  −0
  nni/retiarii/oneshot/pytorch/supermodule/proxyless.py           +191  −0
  nni/retiarii/oneshot/pytorch/supermodule/sampling.py            +196  −0
  nni/retiarii/strategy/__init__.py                               +1    −1
  nni/retiarii/strategy/oneshot.py                                +2    −2
  test/ut/retiarii/test_oneshot.py                                +215  −61
  test/ut/retiarii/test_oneshot_supermodules.py                   +248  −0
nni/common/hpo_utils/formatting.py:

@@ -39,6 +39,8 @@ class ParameterSpec(NamedTuple):
     categorical: bool    # Whether this paramter is categorical (unordered) or numerical (ordered)
     size: int = None     # If it's categorical, how many candidates it has
+    chosen_size: Optional[int] = 1    # If it's categorical, it should choose how many candidates.
+                                      # By default, 1. If none, arbitrary number of candidates can be chosen.

     # uniform distributed
     low: float = None    # Lower bound of uniform parameter
nni/retiarii/oneshot/pytorch/__init__.py:

@@ -5,6 +5,6 @@ from .darts import DartsTrainer
 from .enas import EnasTrainer
 from .proxyless import ProxylessTrainer
 from .random import SinglePathTrainer, RandomTrainer
-from .differentiable import DartsModule, ProxylessModule, SnasModule
-from .sampling import EnasModule, RandomSamplingModule
+from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
+from .sampling import EnasLightningModule, RandomSamplingLightningModule
 from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader
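The renaming above is import-level only; a minimal sketch of updated user imports under the package layout shown in this diff (old names noted in comments):

    from nni.retiarii.oneshot.pytorch import (
        DartsLightningModule,           # was DartsModule
        ProxylessLightningModule,       # was ProxylessModule
        GumbelDartsLightningModule,     # was SnasModule
        EnasLightningModule,            # was EnasModule
        RandomSamplingLightningModule,  # was RandomSamplingModule
    )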
nni/retiarii/oneshot/pytorch/base_lightning.py:

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from typing import Dict, Type, Callable, List, Optional
+import warnings
+from itertools import chain
+from typing import Dict, Callable, List, Union, Any, Tuple

 import pytorch_lightning as pl
 import torch.optim as optim
@@ -9,51 +11,163 @@ import torch.nn as nn
 from torch.optim.lr_scheduler import _LRScheduler

-ReplaceDictType = Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]]
+import nni.retiarii.nn.pytorch as nas_nn
+from nni.common.hpo_utils import ParameterSpec
+from nni.common.serializer import is_traceable
+from nni.retiarii.nn.pytorch.api import ValueChoiceX
+from .supermodule.base import BaseSuperNetModule
+
+__all__ = ['MutationHook', 'BaseSuperNetModule', 'BaseOneShotLightningModule', 'traverse_and_mutate_submodules']
+
+MutationHook = Callable[[nn.Module, str, Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]

-def _replace_module_with_type(root_module: nn.Module, replace_dict: ReplaceDictType, modules: List[nn.Module]):
+def traverse_and_mutate_submodules(
+    root_module: nn.Module, hooks: List[MutationHook], mutate_kwargs: Dict[str, Any], topdown: bool = True
+) -> List[BaseSuperNetModule]:
     """
-    Replace xxxChoice in user's model with NAS modules.
+    Traverse the module-tree of ``root_module``, and call ``hooks`` on every tree node.

     Parameters
     ----------
     root_module : nn.Module
-        User-defined module with xxxChoice in it. In fact, since this method is called in the ``__init__`` of
-        ``BaseOneShotLightningModule``, this will be a pl.LightningModule.
-    replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]]
-        Functions to replace xxxChoice modules. Keys should be xxxChoice type and values should be a
-        function that return an nn.module.
-    modules : List[nn.Module]
-        The replace result. This is also the return value of this function.
+        User-defined model space.
+        Since this method is called in the ``__init__`` of :class:`BaseOneShotLightningModule`,
+        it's usually a ``pytorch_lightning.LightningModule``.
+        The mutation will be in-place on ``root_module``.
+    hooks : List[MutationHook]
+        List of mutation hooks. See :class:`BaseOneShotLightningModule` for how to write hooks.
+        When a hook returns an module, the module will be replaced (mutated) to the new module.
+    mutate_kwargs : dict
+        Extra keyword arguments passed to hooks.
+    topdown : bool, default = False
+        If topdown is true, hooks are first called, before traversing its sub-module (i.e., pre-order DFS).
+        Otherwise, sub-modules are first traversed, before calling hooks on this node (i.e., post-order DFS).

     Returns
     ----------
-    modules : List[nn.Module]
+    modules : Dict[str, nn.Module]
         The replace result.
     """
-    if modules is None:
-        modules = []
+    memo = {}
+
+    module_list = []

     def apply(m):
         for name, child in m.named_children():
-            child_type = type(child)
-            if child_type in replace_dict.keys():
-                setattr(m, name, replace_dict[child_type](child))
-                modules.append((child.key, getattr(m, name)))
-            apply(child)
-    apply(root_module)
-    return modules
+            # post-order DFS
+            if not topdown:
+                apply(child)
+
+            mutate_result = None
+
+            for hook in hooks:
+                hook_suggest = hook(child, name, memo, mutate_kwargs)
+
+                # parse the mutate result
+                if isinstance(hook_suggest, tuple):
+                    hook_suggest, suppress = hook_suggest
+                elif hook_suggest is True:
+                    hook_suggest, suppress = None, True
+                elif not hook_suggest:  # none / false
+                    hook_suggest, suppress = None, False
+                elif isinstance(hook_suggest, nn.Module):
+                    suppress = True
+                else:
+                    raise TypeError(f'Mutation hook returned {hook_suggest} of unsupported type: {type(hook_suggest)}.')
+
+                if hook_suggest is not None:
+                    if not isinstance(hook_suggest, BaseSuperNetModule):
+                        warnings.warn("Mutation hook didn't return a BaseSuperNetModule. It will be ignored in hooked module list.",
+                                      RuntimeWarning)
+                    setattr(m, name, hook_suggest)
+                    mutate_result = hook_suggest
+
+                # if suppress, no further mutation hooks are called
+                if suppress:
+                    break
+
+            if isinstance(mutate_result, BaseSuperNetModule):
+                module_list.append(mutate_result)
+
+            # pre-order DFS
+            if topdown:
+                apply(child)
+
+    apply(root_module)
+
+    return module_list
+
+
+def no_default_hook(module: nn.Module, name: str, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> bool:
+    """Add this hook at the end of your hook list to raise error for unsupported mutation primitives."""
+
+    # Forward IS NOT supernet
+    primitive_list = (
+        nas_nn.LayerChoice,
+        nas_nn.InputChoice,
+        nas_nn.ValueChoice,
+        nas_nn.Repeat,
+        nas_nn.NasBench101Cell,
+        # nas_nn.Cell,             # later
+        # nas_nn.NasBench201Cell,  # forward = supernet
+    )
+
+    if isinstance(module, primitive_list):
+        raise TypeError(f'{type(module).__name__} is not supported')
+
+    if isinstance(module, nas_nn.Cell) and module.merge_op != 'all':
+        # need output_node_indices, which depends on super-net
+        raise TypeError(f'Cell with merge_op `{module.merge_op}` is not supported')
+
+    if is_traceable(module):
+        # check whether there is a value-choice in its arguments
+        has_valuechoice = False
+        for arg in chain(module.trace_args, module.trace_kwargs.values()):
+            if isinstance(arg, ValueChoiceX):
+                has_valuechoice = True
+                break
+
+        if has_valuechoice:
+            raise TypeError(f'`basic_unit` {type(module).__name__} with value choice in its arguments is not supported. '
+                            'Please try to remove `basic_unit` to see if that works, or support this type with value choice manually.')
+
+    return True  # suppress all other hooks


 class BaseOneShotLightningModule(pl.LightningModule):

-    _custom_replace_dict_note = """custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
-        The custom xxxChoice replace method. Keys should be ``xxxChoice`` type.
-        Values should callable accepting an ``nn.Module`` and returning an ``nn.Module``.
-        This custom replace dict will override the default replace dict of each NAS method.
+    _mutation_hooks_note = """mutation_hooks : List[MutationHook]
+        Mutation hooks are callable that inputs an Module and returns a :class:`BaseSuperNetModule`.
+        They are invoked in :meth:`traverse_and_mutate_submodules`, on each submodules.
+        For each submodule, the hook list are invoked subsequently,
+        the later hooks can see the result from previous hooks.
+        The modules that are processed by ``mutation_hooks`` will be replaced by the returned module,
+        stored in ``nas_modules``, and be the focus of the NAS algorithm.
+
+        The hook list will be appended by ``default_mutation_hooks`` in each one-shot module.
+
+        To be more specific, the input arguments are three arguments:
+
+        #. a module that might be processed,
+        #. name of the module in its parent module,
+        #. a memo dict whose usage depends on the particular algorithm.
+
+        Note that the memo should be read/written by hooks.
+        There won't be any hooks called on root module.
+
+        The returned arguments can be also one of the three kinds:
+
+        #. tuple of: :class:`BaseSuperNetModule` or None, and boolean,
+        #. boolean,
+        #. :class:`BaseSuperNetModule` or None.
+
+        The boolean value is ``suppress`` indicates whether the folliwng hooks should be called.
+        When it's true, it suppresses the subsequent hooks, and they will never be invoked.
+        Without boolean value specified, it's assumed to be false.
+        If a none value appears on the place of :class:`BaseSuperNetModule`, it means the hook suggests to
+        keep the module unchanged, and nothing will happen.
+        """

     _inner_module_note = """inner_module : pytorch_lightning.LightningModule
@@ -79,30 +193,76 @@ class BaseOneShotLightningModule(pl.LightningModule):
     Attributes
     ----------
-    nas_modules : List[nn.Module]
-        The replace result of a specific NAS method.
-        xxxChoice will be replaced with some other modules with respect to the NAS method.
+    nas_modules : List[BaseSuperNetModule]
+        Modules that have been mutated, which the search algorithms should care about.

     Parameters
     ----------
-    """ + _inner_module_note + _custom_replace_dict_note
+    """ + _inner_module_note + _mutation_hooks_note

     automatic_optimization = False

-    def __init__(self, inner_module: pl.LightningModule, custom_replace_dict: Optional[ReplaceDictType] = None):
+    def default_mutation_hooks(self) -> List[MutationHook]:
+        """Override this to define class-default mutation hooks."""
+        return [no_default_hook]
+
+    def mutate_kwargs(self) -> Dict[str, Any]:
+        """Extra keyword arguments passed to mutation hooks. Usually algo-specific."""
+        return {}
+
+    def __init__(self, base_model: pl.LightningModule, mutation_hooks: List[MutationHook] = None):
         super().__init__()
-        assert isinstance(inner_module, pl.LightningModule)
-        self.model = inner_module
-
-        # replace xxxChoice with respect to NAS alg
-        # replaced modules are stored in self.nas_modules
-        self.nas_modules = []
-        choice_replace_dict = self.default_replace_dict
-        if custom_replace_dict is not None:
-            for k, v in custom_replace_dict.items():
-                assert isinstance(v, nn.Module)
-                choice_replace_dict[k] = v
-        _replace_module_with_type(self.model, choice_replace_dict, self.nas_modules)
+        assert isinstance(base_model, pl.LightningModule)
+        self.model = base_model
+
+        # append the default hooks
+        mutation_hooks = (mutation_hooks or []) + self.default_mutation_hooks()
+
+        # traverse the model, calling hooks on every submodule
+        self.nas_modules: List[BaseSuperNetModule] = traverse_and_mutate_submodules(
+            self.model, mutation_hooks, self.mutate_kwargs(), topdown=True)
+
+    def search_space_spec(self) -> Dict[str, ParameterSpec]:
+        """Get the search space specification from ``nas_module``.
+
+        Returns
+        -------
+        dict
+            Key is the name of the choice, value is the corresponding :class:`ParameterSpec`.
+        """
+        result = {}
+        for module in self.nas_modules:
+            result.update(module.search_space_spec())
+        return result
+
+    def resample(self) -> Dict[str, Any]:
+        """Trigger the resample for each ``nas_module``.
+        Sometimes (e.g., in differentiable cases), it does nothing.
+
+        Returns
+        -------
+        dict
+            Sampled architecture.
+        """
+        result = {}
+        for module in self.nas_modules:
+            result.update(module.resample(memo=result))
+        return result
+
+    def export(self) -> Dict[str, Any]:
+        """
+        Export the NAS result, ideally the best choice of each ``nas_module``.
+        You may implement an ``export`` method for your customized ``nas_module``.
+
+        Returns
+        --------
+        dict
+            Keys are names of ``nas_modules``, and values are the choice indices of them.
+        """
+        result = {}
+        for module in self.nas_modules:
+            result.update(module.export(memo=result))
+        return result

     def forward(self, x):
         return self.model(x)
@@ -148,6 +308,9 @@ class BaseOneShotLightningModule(pl.LightningModule):
         return arc_optimizers + w_optimizers, lr_schedulers

     def on_train_start(self):
+        # redirect the access to trainer/log to this module
+        # but note that we might be missing other attributes,
+        # which could potentially be a problem
         self.model.trainer = self.trainer
         self.model.log = self.log
         return self.model.on_train_start()
@@ -161,10 +324,10 @@ class BaseOneShotLightningModule(pl.LightningModule):
     def on_fit_end(self):
         return self.model.on_train_end()

     def on_train_batch_start(self, batch, batch_idx, unused=0):
         return self.model.on_train_batch_start(batch, batch_idx, unused)

     def on_train_batch_end(self, outputs, batch, batch_idx, unused=0):
         return self.model.on_train_batch_end(outputs, batch, batch_idx, unused)

     def on_epoch_start(self):
@@ -185,7 +348,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
     def on_after_backward(self):
         return self.model.on_after_backward()

     def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None):
         return self.model.configure_gradient_clipping(optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm)

     def configure_architecture_optimizers(self):
@@ -200,20 +363,6 @@ class BaseOneShotLightningModule(pl.LightningModule):
         """
         return None

-    @property
-    def default_replace_dict(self):
-        """
-        Default ``xxxChoice`` replace dict. This is called in ``__init__`` to get the default replace functions for your NAS algorithm.
-        Note that your default replace functions may be overridden by user-defined ``custom_replace_dict``.
-
-        Returns
-        ----------
-        replace_dict : Dict[Type, Callable[nn.Module, nn.Module]]
-            Same as ``custom_replace_dict`` in ``__init__``, but this will be overridden if users define their own replace functions.
-        """
-        replace_dict = {}
-        return replace_dict
-
     def call_lr_schedulers(self, batch_index):
         """
         Function that imitates lightning trainer's behaviour of calling user's lr schedulers. Since auto_optimization is turned off
@@ -254,13 +403,13 @@ class BaseOneShotLightningModule(pl.LightningModule):
     def call_user_optimizers(self, method):
         """
-        Function that imitates lightning trainer's behaviour of calling user's optimizers. Since auto_optimization is turned off by this
+        Function that imitates lightning trainer's behavior of calling user's optimizers. Since auto_optimization is turned off by this
         class, you can use this function to make user optimizers behave as they were automatically handled by the lightning trainer.

         Parameters
         ----------
         method : str
-            Method to call. Only 'step' and 'zero_grad' are supported now.
+            Method to call. Only ``step`` and ``zero_grad`` are supported now.
         """
         def apply_method(optimizer, method):
             if method == 'step':
@@ -296,7 +445,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
         architecture optimizers.
         """
         opts = self.optimizers()
         if isinstance(opts, list):
             # pylint: disable=unsubscriptable-object
             arc_opts = opts[:self.arc_optim_count]
             if len(arc_opts) == 1:
@@ -310,7 +459,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
     @property
     def user_optimizers(self):
         """
-        Get user optimizers from all optimizers. Use this to get user optimizers in ``training step``.
+        Get user optimizers from all optimizers. Use this to get user optimizers in ``training_step``.

         Returns
         ----------
@@ -318,26 +467,10 @@ class BaseOneShotLightningModule(pl.LightningModule):
             Optimizers defined by user's model. This will be None if there is no user optimizers.
         """
         opts = self.optimizers()
         if isinstance(opts, list):
             # pylint: disable=unsubscriptable-object
             return opts[self.arc_optim_count:]
         # If there is only 1 optimizer and no architecture optimizer
         if self.arc_optim_count == 0:
             return opts
         return None
-
-    def export(self):
-        """
-        Export the NAS result, ideally the best choice of each nas_modules.
-        You may implement an ``export`` method for your customized nas_module.
-
-        Returns
-        --------
-        result : Dict[str, int]
-            Keys are names of nas_modules, and values are the choice indices of them.
-        """
-        result = {}
-        for name, module in self.nas_modules:
-            if name not in result:
-                result[name] = module.export()
-        return result
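Since this hook protocol replaces the old ``custom_replace_dict`` mechanism, here is a minimal, self-contained sketch of a user-defined mutation hook (``TinySuperNetLinear`` is a hypothetical stand-in; a real hook would return a ``BaseSuperNetModule`` subclass):

    import torch.nn as nn

    class TinySuperNetLinear(nn.Linear):
        """Hypothetical stand-in for the super-net module a hook would build."""

    def my_mutation_hook(module, name, memo, mutate_kwargs):
        # Hook protocol (see _mutation_hooks_note): return a module, optionally paired
        # with a boolean `suppress` that stops later hooks on this submodule.
        if isinstance(module, nn.Linear) and not isinstance(module, TinySuperNetLinear):
            replacement = TinySuperNetLinear(module.in_features, module.out_features)
            return replacement, True    # replace the submodule and suppress later hooks
        return None                     # keep the module unchanged

    # User hooks are prepended to the algorithm defaults, e.g. (assumed usage):
    # DartsLightningModule(base_model, mutation_hooks=[my_mutation_hook])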
nni/retiarii/oneshot/pytorch/differentiable.py:

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from collections import OrderedDict
-from typing import Optional
+"""Experimental version of differentiable one-shot implementation."""
+
+from typing import List

 import pytorch_lightning as pl
 import torch
-import torch.nn as nn
-import torch.nn.functional as F

-from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
-from .base_lightning import BaseOneShotLightningModule, ReplaceDictType
-
-
-class DartsLayerChoice(nn.Module):
-    def __init__(self, layer_choice):
-        super(DartsLayerChoice, self).__init__()
-        self.name = layer_choice.label
-        self.op_choices = nn.ModuleDict(OrderedDict([(name, layer_choice[name]) for name in layer_choice.names]))
-        self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)
-
-    def forward(self, *args, **kwargs):
-        op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
-        alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
-        return torch.sum(op_results * F.softmax(self.alpha, -1).view(*alpha_shape), 0)
-
-    def parameters(self):
-        for _, p in self.named_parameters():
-            yield p
-
-    def named_parameters(self, recurse=False):
-        for name, p in super(DartsLayerChoice, self).named_parameters():
-            if name == 'alpha':
-                continue
-            yield name, p
-
-    def export(self):
-        return list(self.op_choices.keys())[torch.argmax(self.alpha).item()]
-
-
-class DartsInputChoice(nn.Module):
-    def __init__(self, input_choice):
-        super(DartsInputChoice, self).__init__()
-        self.name = input_choice.label
-        self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
-        self.n_chosen = input_choice.n_chosen or 1
-
-    def forward(self, inputs):
-        inputs = torch.stack(inputs)
-        alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
-        return torch.sum(inputs * F.softmax(self.alpha, -1).view(*alpha_shape), 0)
-
-    def parameters(self):
-        for _, p in self.named_parameters():
-            yield p
-
-    def named_parameters(self, recurse=False):
-        for name, p in super(DartsInputChoice, self).named_parameters():
-            if name == 'alpha':
-                continue
-            yield name, p
-
-    def export(self):
-        return torch.argsort(-self.alpha).cpu().numpy().tolist()[:self.n_chosen]
+from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
+from .supermodule.differentiable import (
+    DifferentiableMixedLayer, DifferentiableMixedInput,
+    MixedOpDifferentiablePolicy, GumbelSoftmax
+)
+from .supermodule.proxyless import ProxylessMixedInput, ProxylessMixedLayer
+from .supermodule.operation import NATIVE_MIXED_OPERATIONS


-class DartsModule(BaseOneShotLightningModule):
+class DartsLightningModule(BaseOneShotLightningModule):
     _darts_note = """
     DARTS :cite:p:`liu2018darts` algorithm is one of the most fundamental one-shot algorithm.
@@ -74,6 +26,10 @@ class DartsModule(BaseOneShotLightningModule):
     The current implementation is for DARTS in first order. Second order (unrolled) is not supported yet.

+    *New in v2.8*: Supports searching for ValueChoices on operations, with the technique described in
+    `FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions <https://arxiv.org/abs/2004.05565>`__.
+    One difference is that, in DARTS, we are using Softmax instead of GumbelSoftmax.
+
     {{module_notes}}

     Parameters
@@ -82,18 +38,34 @@ class DartsModule(BaseOneShotLightningModule):
     {base_params}
     arc_learning_rate : float
         Learning rate for architecture optimizer. Default: 3.0e-4
-    """.format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
+    """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)

     __doc__ = _darts_note.format(
         module_notes='The DARTS Module should be trained with :class:`nni.retiarii.oneshot.utils.InterleavedTrainValDataLoader`.',
         module_params=BaseOneShotLightningModule._inner_module_note,
     )

+    def default_mutation_hooks(self) -> List[MutationHook]:
+        """Replace modules with differentiable versions"""
+        hooks = [
+            DifferentiableMixedLayer.mutate,
+            DifferentiableMixedInput.mutate,
+        ]
+        hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
+        hooks.append(no_default_hook)
+        return hooks
+
+    def mutate_kwargs(self):
+        """Use differentiable strategy for mixed operations."""
+        return {'mixed_op_sampling': MixedOpDifferentiablePolicy}
+
     def __init__(self, inner_module: pl.LightningModule,
-                 custom_replace_dict: Optional[ReplaceDictType] = None,
+                 mutation_hooks: List[MutationHook] = None,
                  arc_learning_rate: float = 3.0E-4):
-        super().__init__(inner_module, custom_replace_dict=custom_replace_dict)
         self.arc_learning_rate = arc_learning_rate
+        super().__init__(inner_module, mutation_hooks=mutation_hooks)

     def training_step(self, batch, batch_idx):
         # grad manually
@@ -105,7 +77,7 @@ class DartsModule(BaseOneShotLightningModule):
         # phase 1: architecture step
         # The _resample hook is kept for some darts-based NAS methods like proxyless.
         # See code of those methods for details.
-        self._resample()
+        self.resample()
         arc_optim.zero_grad()
         arc_step_loss = self.model.training_step(val_batch, 2 * batch_idx)
         if isinstance(arc_step_loss, dict):
@@ -115,7 +87,7 @@ class DartsModule(BaseOneShotLightningModule):
         arc_optim.step()

         # phase 2: model step
-        self._resample()
+        self.resample()
         self.call_user_optimizers('zero_grad')
         loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
         w_step_loss = loss_and_metrics['loss'] \
@@ -127,178 +99,21 @@ class DartsModule(BaseOneShotLightningModule):
         return loss_and_metrics

-    def _resample(self):
-        # Note: This hook is kept for following darts-based NAS algs.
-        pass
-
     def finalize_grad(self):
         # Note: This hook is currently kept for Proxyless NAS.
         pass

-    @property
-    def default_replace_dict(self):
-        return {LayerChoice: DartsLayerChoice, InputChoice: DartsInputChoice}
-
     def configure_architecture_optimizers(self):
-        # The alpha in DartsXXXChoices is the architecture parameter of DARTS. All alphas share one optimizer.
-        ctrl_params = {}
-        for _, m in self.nas_modules:
-            if m.name in ctrl_params:
-                assert m.alpha.size() == ctrl_params[m.name].size(), 'Size of parameters with the same label should be same.'
-                m.alpha = ctrl_params[m.name]
-            else:
-                ctrl_params[m.name] = m.alpha
-        ctrl_optim = torch.optim.Adam(list(ctrl_params.values()), 3.e-4, betas=(0.5, 0.999),
+        # The alpha in DartsXXXChoices are the architecture parameters of DARTS. They share one optimizer.
+        ctrl_params = []
+        for m in self.nas_modules:
+            ctrl_params += list(m.parameters(arch=True))
+        ctrl_optim = torch.optim.Adam(list(set(ctrl_params)), 3.e-4, betas=(0.5, 0.999),
                                       weight_decay=1.0E-3)
         return ctrl_optim


-class _ArchGradientFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, binary_gates, run_func, backward_func):
-        ctx.run_func = run_func
-        ctx.backward_func = backward_func
-
-        detached_x = x.detach()
-        detached_x.requires_grad = x.requires_grad
-        with torch.enable_grad():
-            output = run_func(detached_x)
-        ctx.save_for_backward(detached_x, output)
-        return output.data
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        detached_x, output = ctx.saved_tensors
-
-        grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
-        # compute gradients w.r.t. binary_gates
-        binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)
-
-        return grad_x[0], binary_grads, None, None
-
-
-class ProxylessLayerChoice(nn.Module):
-    def __init__(self, ops):
-        super(ProxylessLayerChoice, self).__init__()
-        self.ops = nn.ModuleList(ops)
-        self.alpha = nn.Parameter(torch.randn(len(self.ops)) * 1E-3)
-        self._binary_gates = nn.Parameter(torch.randn(len(self.ops)) * 1E-3)
-        self.sampled = None
-
-    def forward(self, *args, **kwargs):
-        if self.training:
-            def run_function(ops, active_id, **kwargs):
-                def forward(_x):
-                    return ops[active_id](_x, **kwargs)
-                return forward
-
-            def backward_function(ops, active_id, binary_gates, **kwargs):
-                def backward(_x, _output, grad_output):
-                    binary_grads = torch.zeros_like(binary_gates.data)
-                    with torch.no_grad():
-                        for k in range(len(ops)):
-                            if k != active_id:
-                                out_k = ops[k](_x.data, **kwargs)
-                            else:
-                                out_k = _output.data
-                            grad_k = torch.sum(out_k * grad_output)
-                            binary_grads[k] = grad_k
-                    return binary_grads
-                return backward
-
-            assert len(args) == 1
-            x = args[0]
-            return _ArchGradientFunction.apply(
-                x, self._binary_gates, run_function(self.ops, self.sampled, **kwargs),
-                backward_function(self.ops, self.sampled, self._binary_gates, **kwargs)
-            )
-        return super().forward(*args, **kwargs)
-
-    def resample(self):
-        probs = F.softmax(self.alpha, dim=-1)
-        sample = torch.multinomial(probs, 1)[0].item()
-        self.sampled = sample
-        with torch.no_grad():
-            self._binary_gates.zero_()
-            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
-            self._binary_gates.data[sample] = 1.0
-
-    def finalize_grad(self):
-        binary_grads = self._binary_gates.grad
-        with torch.no_grad():
-            if self.alpha.grad is None:
-                self.alpha.grad = torch.zeros_like(self.alpha.data)
-            probs = F.softmax(self.alpha, dim=-1)
-            for i in range(len(self.ops)):
-                for j in range(len(self.ops)):
-                    self.alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
-
-    def export(self):
-        return torch.argmax(self.alpha).item()
-
-    def export_prob(self):
-        return F.softmax(self.alpha, dim=-1)
-
-
-class ProxylessInputChoice(nn.Module):
-    def __init__(self, input_choice):
-        super().__init__()
-        self.num_input_candidates = input_choice.n_candidates
-        self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1E-3)
-        self._binary_gates = nn.Parameter(torch.randn(input_choice.n_candidates) * 1E-3)
-        self.sampled = None
-
-    def forward(self, inputs):
-        if self.training:
-            def run_function(active_sample):
-                return lambda x: x[active_sample]
-
-            def backward_function(binary_gates):
-                def backward(_x, _output, grad_output):
-                    binary_grads = torch.zeros_like(binary_gates.data)
-                    with torch.no_grad():
-                        for k in range(self.num_input_candidates):
-                            out_k = _x[k].data
-                            grad_k = torch.sum(out_k * grad_output)
-                            binary_grads[k] = grad_k
-                    return binary_grads
-                return backward
-
-            inputs = torch.stack(inputs, 0)
-            return _ArchGradientFunction.apply(
-                inputs, self._binary_gates, run_function(self.sampled),
-                backward_function(self._binary_gates)
-            )
-        return super().forward(inputs)
-
-    def resample(self, sample=None):
-        if sample is None:
-            probs = F.softmax(self.alpha, dim=-1)
-            sample = torch.multinomial(probs, 1)[0].item()
-        self.sampled = sample
-        with torch.no_grad():
-            self._binary_gates.zero_()
-            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
-            self._binary_gates.data[sample] = 1.0
-        return self.sampled
-
-    def finalize_grad(self):
-        binary_grads = self._binary_gates.grad
-        with torch.no_grad():
-            if self.alpha.grad is None:
-                self.alpha.grad = torch.zeros_like(self.alpha.data)
-            probs = F.softmax(self.alpha, dim=-1)
-            for i in range(self.num_input_candidates):
-                for j in range(self.num_input_candidates):
-                    self.alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
-
-
-class ProxylessModule(DartsModule):
+class ProxylessLightningModule(DartsLightningModule):
     _proxyless_note = """
     Implementation of ProxylessNAS :cite:p:`cai2018proxylessnas`.
     It's a DARTS-based method that resamples the architecture to reduce memory consumption.
@@ -313,54 +128,38 @@ class ProxylessModule(DartsModule):
     {base_params}
     arc_learning_rate : float
         Learning rate for architecture optimizer. Default: 3.0e-4
-    """.format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
+    """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)

     __doc__ = _proxyless_note.format(
         module_notes='This module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.',
         module_params=BaseOneShotLightningModule._inner_module_note,
     )

-    @property
-    def default_replace_dict(self):
-        return {
-            LayerChoice: ProxylessLayerChoice,
-            InputChoice: ProxylessInputChoice
-        }
-
-    def _resample(self):
-        for _, m in self.nas_modules:
-            m.resample()
+    def default_mutation_hooks(self) -> List[MutationHook]:
+        """Replace modules with gumbel-differentiable versions"""
+        hooks = [
+            ProxylessMixedLayer.mutate,
+            ProxylessMixedInput.mutate,
+            no_default_hook,
+        ]
+        # FIXME: no support for mixed operation currently
+        return hooks

     def finalize_grad(self):
-        for _, m in self.nas_modules:
+        for m in self.nas_modules:
             m.finalize_grad()


-class SNASLayerChoice(DartsLayerChoice):
-    def forward(self, *args, **kwargs):
-        one_hot = F.gumbel_softmax(self.alpha, self.temp)
-        op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
-        alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
-        yhat = torch.sum(op_results * one_hot.view(*alpha_shape), 0)
-        return yhat
-
-
-class SNASInputChoice(DartsInputChoice):
-    def forward(self, inputs):
-        one_hot = F.gumbel_softmax(self.alpha, self.temp)
-        inputs = torch.stack(inputs)
-        alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
-        yhat = torch.sum(inputs * one_hot.view(*alpha_shape), 0)
-        return yhat
-
-
-class SnasModule(DartsModule):
-    _snas_note = """
+class GumbelDartsLightningModule(DartsLightningModule):
+    _gumbel_darts_note = """
     Implementation of SNAS :cite:p:`xie2018snas`.
     It's a DARTS-based method that uses gumbel-softmax to simulate one-hot distribution.
     Essentially, it samples one path on forward,
     and implements its own backward to update the architecture parameters based on only one path.

+    *New in v2.8*: Supports searching for ValueChoices on operations, with the technique described in
+    `FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions <https://arxiv.org/abs/2004.05565>`__.
+
     {{module_notes}}

     Parameters
@@ -376,20 +175,32 @@ class SnasModule(DartsModule):
         The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` False.
     arc_learning_rate : float
         Learning rate for architecture optimizer. Default: 3.0e-4
-    """.format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
+    """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)

-    __doc__ = _snas_note.format(
-        module_notes='This module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.',
-        module_params=BaseOneShotLightningModule._inner_module_note,
-    )
+    def default_mutation_hooks(self) -> List[MutationHook]:
+        """Replace modules with gumbel-differentiable versions"""
+        hooks = [
+            DifferentiableMixedLayer.mutate,
+            DifferentiableMixedInput.mutate,
+        ]
+        hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
+        hooks.append(no_default_hook)
+        return hooks
+
+    def mutate_kwargs(self):
+        """Use gumbel softmax."""
+        return {
+            'mixed_op_sampling': MixedOpDifferentiablePolicy,
+            'softmax': GumbelSoftmax(),
+        }

     def __init__(self, inner_module,
-                 custom_replace_dict: Optional[ReplaceDictType] = None,
+                 mutation_hooks: List[MutationHook] = None,
                  arc_learning_rate: float = 3.0e-4,
                  gumbel_temperature: float = 1.,
                  use_temp_anneal: bool = False,
                  min_temp: float = .33):
-        super().__init__(inner_module, custom_replace_dict, arc_learning_rate=arc_learning_rate)
+        super().__init__(inner_module, mutation_hooks, arc_learning_rate=arc_learning_rate)
         self.temp = gumbel_temperature
         self.init_temp = gumbel_temperature
         self.use_temp_anneal = use_temp_anneal
@@ -400,14 +211,7 @@ class SnasModule(DartsModule):
             self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
             self.temp = max(self.temp, self.min_temp)

-        for _, nas_module in self.nas_modules:
-            nas_module.temp = self.temp
+        for module in self.nas_modules:
+            module._softmax.temp = self.temp

         return self.model.on_epoch_start()
-
-    @property
-    def default_replace_dict(self):
-        return {LayerChoice: SNASLayerChoice, InputChoice: SNASInputChoice}
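For reference, the per-epoch temperature annealing that ``GumbelDartsLightningModule`` performs in the last hunk above is a linear schedule clamped at ``min_temp``; a self-contained sketch using the same defaults as the diff (``gumbel_temperature=1.0``, ``min_temp=0.33``):

    def annealed_temperature(current_epoch: int, max_epochs: int,
                             init_temp: float = 1.0, min_temp: float = 0.33) -> float:
        # Linear interpolation from init_temp at epoch 0 down to min_temp at the last
        # epoch, clamped so the temperature never drops below min_temp.
        temp = (1 - current_epoch / max_epochs) * (init_temp - min_temp) + min_temp
        return max(temp, min_temp)

    # With 10 epochs: 1.0 at epoch 0, ~0.665 at epoch 5, 0.33 at epoch 10.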
nni/retiarii/oneshot/pytorch/sampling.py:

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from typing import Dict, Any, Optional
-import random
+"""Experimental version of sampling-based one-shot implementation."""
+
+from typing import Dict, Any, List

 import pytorch_lightning as pl
 import torch
 import torch.nn as nn
 import torch.optim as optim

-from nni.retiarii.nn.pytorch.api import LayerChoice, InputChoice
-from .random import PathSamplingLayerChoice, PathSamplingInputChoice
-from .base_lightning import BaseOneShotLightningModule, ReplaceDictType
+from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
+from .supermodule.sampling import PathSamplingInput, PathSamplingLayer, MixedOpPathSamplingPolicy
+from .supermodule.operation import NATIVE_MIXED_OPERATIONS
 from .enas import ReinforceController, ReinforceField


-class EnasModule(BaseOneShotLightningModule):
+class RandomSamplingLightningModule(BaseOneShotLightningModule):
+    _random_note = """
+    Random Sampling NAS Algorithm.
+    In each epoch, model parameters are trained after a uniformly random sampling of each choice.
+    Notably, the exporting result is **also a random sample** of the search space.
+
+    Parameters
+    ----------
+    {{module_params}}
+    {base_params}
+    """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
+
+    __doc__ = _random_note.format(
+        module_params=BaseOneShotLightningModule._inner_module_note,
+    )
+
+    # turn on automatic optimization because nothing interesting is going on here.
+    automatic_optimization = True
+
+    def default_mutation_hooks(self) -> List[MutationHook]:
+        """Replace modules with differentiable versions"""
+        hooks = [
+            PathSamplingLayer.mutate,
+            PathSamplingInput.mutate,
+        ]
+        hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
+        hooks.append(no_default_hook)
+        return hooks
+
+    def mutate_kwargs(self):
+        """Use path sampling strategy for mixed-operations."""
+        return {'mixed_op_sampling': MixedOpPathSamplingPolicy}
+
+    def training_step(self, batch, batch_idx):
+        self.resample()
+        return self.model.training_step(batch, batch_idx)
+
+
+class EnasLightningModule(RandomSamplingLightningModule):
     _enas_note = """
     The implementation of ENAS :cite:p:`pham2018efficient`. There are 2 steps in an epoch.
     Firstly, training model parameters.
@@ -39,27 +80,34 @@ class EnasModule(BaseOneShotLightningModule):
         Number of steps that will be aggregated into one mini-batch for RL controller.
     ctrl_grad_clip : float
         Gradient clipping value of controller.
-    """.format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
+    """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)

     __doc__ = _enas_note.format(
         module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`.',
         module_params=BaseOneShotLightningModule._inner_module_note,
     )

+    automatic_optimization = False
+
     def __init__(self, inner_module: pl.LightningModule,
+                 *,
                  ctrl_kwargs: Dict[str, Any] = None,
                  entropy_weight: float = 1e-4,
                  skip_weight: float = .8,
                  baseline_decay: float = .999,
                  ctrl_steps_aggregate: float = 20,
                  ctrl_grad_clip: float = 0,
-                 custom_replace_dict: Optional[ReplaceDictType] = None):
-        super().__init__(inner_module, custom_replace_dict)
-        self.nas_fields = [ReinforceField(name, len(module),
-                                          isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1)
-                           for name, module in self.nas_modules]
+                 mutation_hooks: List[MutationHook] = None):
+        super().__init__(inner_module, mutation_hooks)
+
+        # convert parameter spec to legacy ReinforceField
+        # this part will be refactored
+        self.nas_fields: List[ReinforceField] = []
+        for name, param_spec in self.search_space_spec().items():
+            if param_spec.chosen_size not in (1, None):
+                raise ValueError('ENAS does not support n_chosen to be values other than 1 or None.')
+            self.nas_fields.append(ReinforceField(name, param_spec.size, param_spec.chosen_size == 1))
         self.controller = ReinforceController(self.nas_fields, **(ctrl_kwargs or {}))
         self.entropy_weight = entropy_weight
@@ -72,20 +120,13 @@ class EnasModule(BaseOneShotLightningModule):
     def configure_architecture_optimizers(self):
         return optim.Adam(self.controller.parameters(), lr=3.5e-4)

-    @property
-    def default_replace_dict(self):
-        return {LayerChoice: PathSamplingLayerChoice, InputChoice: PathSamplingInputChoice}
-
     def training_step(self, batch, batch_idx):
         # The ConcatenateTrainValDataloader yields both data and which dataloader it comes from.
         batch, source = batch
         if source == 'train':
             # step 1: train model params
-            self._resample()
+            self.resample()
             self.call_user_optimizers('zero_grad')
             loss_and_metrics = self.model.training_step(batch, batch_idx)
             w_step_loss = loss_and_metrics['loss'] \
@@ -99,7 +140,7 @@ class EnasModule(BaseOneShotLightningModule):
             x, y = batch
             arc_opt = self.architecture_optimizers
             arc_opt.zero_grad()
-            self._resample()
+            self.resample()
             with torch.no_grad():
                 logits = self.model(x)
             # use the default metric of self.model as reward function
@@ -107,7 +148,7 @@ class EnasModule(BaseOneShotLightningModule):
                 _, metric = next(iter(self.model.metrics.items()))
             else:
                 if 'default' not in self.model.metrics.keys():
-                    raise KeyError('model.metrics should contain a ``default`` key when' \
+                    raise KeyError('model.metrics should contain a ``default`` key when'
                                    'there are multiple metrics')
                 metric = self.model.metrics['default']
@@ -128,60 +169,23 @@ class EnasModule(BaseOneShotLightningModule):
             arc_opt.step()
             arc_opt.zero_grad()

-    def _resample(self):
-        """
-        Resample the architecture as ENAS result. This doesn't require an ``export`` method in nas_modules to work.
-        """
-        result = self.controller.resample()
-        for name, module in self.nas_modules:
-            module.sampled = result[name]
+    def resample(self):
+        """Resample the architecture with ENAS controller."""
+        sample = self.controller.resample()
+        result = self._interpret_controller_sampling_result(sample)
+        for module in self.nas_modules:
+            module.resample(memo=result)
         return result

     def export(self):
+        """Run one more inference of ENAS controller."""
         self.controller.eval()
         with torch.no_grad():
-            return self.controller.resample()
+            return self._interpret_controller_sampling_result(self.controller.resample())

-
-class RandomSamplingModule(BaseOneShotLightningModule):
-    _random_note = """
-    Random Sampling NAS Algorithm.
-    In each epoch, model parameters are trained after a uniformly random sampling of each choice.
-    Notably, the exporting result is **also a random sample** of the search space.
-
-    Parameters
-    ----------
-    {{module_params}}
-    {base_params}
-    """.format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
-
-    __doc__ = _random_note.format(
-        module_params=BaseOneShotLightningModule._inner_module_note,
-    )
-
-    automatic_optimization = True
-
-    def training_step(self, batch, batch_idx):
-        self._resample()
-        return self.model.training_step(batch, batch_idx)
-
-    @property
-    def default_replace_dict(self):
-        return {LayerChoice: PathSamplingLayerChoice, InputChoice: PathSamplingInputChoice}
-
-    def _resample(self):
-        """
-        Resample the architecture as RandomSample result. This is simply a uniformly sampling that doesn't require an ``export``
-        method in nas_modules to work.
-        """
-        result = {}
-        for name, module in self.nas_modules:
-            if name not in result:
-                result[name] = random.randint(0, len(module) - 1)
-            module.sampled = result[name]
-        return result
-
-    def export(self):
-        return self._resample()
+    def _interpret_controller_sampling_result(self, sample: Dict[str, int]) -> Dict[str, Any]:
+        """Convert ``{label: index}`` to ``{label: name}``"""
+        space_spec = self.search_space_spec()
+        for key in list(sample.keys()):
+            sample[key] = space_spec[key].values[sample[key]]
+        return sample
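``_interpret_controller_sampling_result`` above turns the controller's ``{label: index}`` output into ``{label: value}`` via the search-space spec; a hedged sketch of that mapping using plain dicts in place of the real ``ParameterSpec`` objects:

    # Hypothetical search space spec: label -> candidate values.
    space_spec = {
        'conv_kernel': {'values': [3, 5, 7]},
        'use_skip': {'values': [False, True]},
    }
    controller_sample = {'conv_kernel': 2, 'use_skip': 1}   # {label: index} from the RL controller

    resolved = {label: space_spec[label]['values'][index]
                for label, index in controller_sample.items()}
    assert resolved == {'conv_kernel': 7, 'use_skip': True}  # {label: value}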
nni/retiarii/oneshot/pytorch/strategy.py
View file @
14d2966b
@@ -6,6 +6,8 @@
 This file is put here simply because it relies on "pytorch".
 For consistency, please consider importing strategies from ``nni.retiarii.strategy``.
 For example, ``nni.retiarii.strategy.DartsStrategy`` (this requires pytorch to be installed of course).
+
+When adding/modifying a new strategy in this file, don't forget to link it in strategy/oneshot.py.
 """
 import warnings
@@ -19,8 +21,8 @@ from nni.retiarii.strategy.base import BaseStrategy
 from nni.retiarii.evaluator.pytorch.lightning import Lightning, LightningModule
 from .base_lightning import BaseOneShotLightningModule
-from .differentiable import DartsModule, ProxylessModule, SnasModule
-from .sampling import EnasModule, RandomSamplingModule
+from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
+from .sampling import EnasLightningModule, RandomSamplingLightningModule
 from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader
@@ -80,50 +82,50 @@ class OneShotStrategy(BaseStrategy):
 class DARTS(OneShotStrategy):
-    __doc__ = DartsModule._darts_note.format(module_notes='', module_params='')
+    __doc__ = DartsLightningModule._darts_note.format(module_notes='', module_params='')

     def __init__(self, **kwargs):
-        super().__init__(DartsModule, **kwargs)
+        super().__init__(DartsLightningModule, **kwargs)

     def _get_dataloader(self, train_dataloader, val_dataloaders):
         return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)


 class Proxyless(OneShotStrategy):
-    __doc__ = ProxylessModule._proxyless_note.format(module_notes='', module_params='')
+    __doc__ = ProxylessLightningModule._proxyless_note.format(module_notes='', module_params='')

     def __init__(self, **kwargs):
-        super().__init__(EnasModule, **kwargs)
+        super().__init__(ProxylessLightningModule, **kwargs)

     def _get_dataloader(self, train_dataloader, val_dataloaders):
         return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)


-class SNAS(OneShotStrategy):
-    __doc__ = SnasModule._snas_note.format(module_notes='', module_params='')
+class GumbelDARTS(OneShotStrategy):
+    __doc__ = GumbelDartsLightningModule._gumbel_darts_note.format(module_notes='', module_params='')

     def __init__(self, **kwargs):
-        super().__init__(SnasModule, **kwargs)
+        super().__init__(GumbelDartsLightningModule, **kwargs)

     def _get_dataloader(self, train_dataloader, val_dataloaders):
         return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)


 class ENAS(OneShotStrategy):
-    __doc__ = EnasModule._enas_note.format(module_notes='', module_params='')
+    __doc__ = EnasLightningModule._enas_note.format(module_notes='', module_params='')

     def __init__(self, **kwargs):
-        super().__init__(EnasModule, **kwargs)
+        super().__init__(EnasLightningModule, **kwargs)

     def _get_dataloader(self, train_dataloader, val_dataloaders):
         return ConcatenateTrainValDataLoader(train_dataloader, val_dataloaders)


 class RandomOneShot(OneShotStrategy):
-    __doc__ = RandomSamplingModule._random_note.format(module_notes='', module_params='')
+    __doc__ = RandomSamplingLightningModule._random_note.format(module_notes='', module_params='')

     def __init__(self, **kwargs):
-        super().__init__(RandomSamplingModule, **kwargs)
+        super().__init__(RandomSamplingLightningModule, **kwargs)

     def _get_dataloader(self, train_dataloader, val_dataloaders):
         return train_dataloader, val_dataloaders
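For orientation, here is a rough sketch of how one of these strategies is typically handed to a Retiarii experiment. The model space ``MyModelSpace`` and the Lightning ``evaluator`` are placeholders, and the experiment/config entry points may differ slightly across NNI versions, so treat this as an assumption-laden illustration rather than the canonical usage:

import nni.retiarii.strategy as strategy
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig

model_space = MyModelSpace()            # search space with LayerChoice / InputChoice / ValueChoice
exploration = strategy.DARTS()          # or ENAS(), GumbelDARTS(), RandomOneShot(), Proxyless()
exp = RetiariiExperiment(model_space, evaluator, [], exploration)
exp.run(RetiariiExeConfig('local'))
print(exp.export_top_models())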
nni/retiarii/oneshot/pytorch/supermodule/__init__.py  0 → 100644  View file @ 14d2966b

nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py  0 → 100644  View file @ 14d2966b
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""This file handles "slice" commonly used in mixed-operation.

The ``slice_type`` we support here is "slice" or "list of slice".
The reason is that sometimes (e.g., in multi-head attention),
the tensor slice could be from multiple parts. This type is extensible.
We can support arbitrary masks in future if we need them.

To slice a tensor, we need ``multidim_slice``,
which is simply a tuple consisting of ``slice_type``.

Usually in python programs, the variables put into slice's start, stop and step
should be integers (or NoneType).
But in our case, it could also be a dict from integer to float,
representing a distribution of integers. When that happens,
we convert a "slice with some weighted values" to a "weighted slice".
To this end, we track the computation with ``MaybeWeighted``,
and replay the computation with each possible value.
Meanwhile, we record their weights.
Note that ``MaybeWeighted`` is also extensible.
We can support more types of objects on slice in future.

The fixed/weighted slice is fed into ``_slice_weight``,
which interprets the slice and applies it on a tensor.
"""

import operator
from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic

import numpy as np
import torch

T = TypeVar('T')

slice_type = Union[slice, List[slice]]
multidim_slice = Tuple[slice_type, ...]

scalar_or_scalar_dict = Union[T, Dict[T, float]]
int_or_int_dict = scalar_or_scalar_dict[int]

_value_fn_type = Optional[Callable[[int_or_int_dict], int]]


def zeros_like(arr: T) -> T:
    if isinstance(arr, np.ndarray):
        return np.zeros_like(arr)
    elif isinstance(arr, torch.Tensor):
        return torch.zeros_like(arr)
    else:
        raise TypeError(f'Unsupported type for {arr}: {type(arr)}')


def _eliminate_list_slice(shape: tuple, slice_: multidim_slice) -> multidim_slice:
    # get rid of list of slice
    result = []
    for i in range(len(slice_)):
        if isinstance(slice_[i], list):
            # convert list of slices to mask
            mask = np.zeros(shape[i], dtype=np.bool)
            for sl in slice_[i]:
                mask[sl] = 1
            result.append(mask)
        else:
            result.append(slice_[i])

    return tuple(result)


def _slice_weight(weight: T, slice_: Union[multidim_slice, List[Tuple[multidim_slice, float]]]) -> T:
    # slice_ can be a tuple of slice, e.g., ([3:6], [2:4])
    # or tuple of slice -> float, e.g. {([3:6],): 0.6, ([2:4],): 0.3}

    if isinstance(slice_, list):
        # for the weighted case, we get the corresponding masks. e.g.,
        # {([3:6],): 0.6, ([2:4],): 0.3} => [0, 0, 0.3, 0.9, 0.6, 0.6] (if the whole length is 6)
        # this mask is broadcast and multiplied onto the weight

        masks = []

        # the accepted argument is a list of tuples here
        # because slice can't be key of dict
        for sl, wt in slice_:
            # create a mask with weight wt
            with torch.no_grad():
                mask = zeros_like(weight)
                mask[_eliminate_list_slice(weight.shape, sl)] = 1

            # track gradients here
            masks.append((mask * wt))

        masks = sum(masks)

        return masks * weight

    else:
        # for the unweighted case, we slice it directly.

        def _do_slice(arr, slice_):
            return arr[_eliminate_list_slice(arr.shape, slice_)]

        # sometimes, we don't need the slice.
        # this saves an op on the computational graph, which will hopefully make training faster
        # Use a dummy array to check this. Otherwise it would be too complex.
        dummy_arr = np.zeros(weight.shape, dtype=np.bool)
        no_effect = _do_slice(dummy_arr, slice_).shape == dummy_arr.shape

        if no_effect:
            return weight

        return _do_slice(weight, slice_)


class Slicable(Generic[T]):
    """Wraps the weight so that it can be sliced with a ``multidim_slice``.
    The values within the slice can be instances of :class:`MaybeWeighted`.

    Examples
    --------
    >>> weight = conv2d.weight
    >>> Slicable(weight)[:MaybeWeighted({32: 0.4, 64: 0.6})]
    Tensor of shape (64, 64, 3, 3)
    """

    def __init__(self, weight: T):
        if not isinstance(weight, np.ndarray) and not torch.is_tensor(weight):
            raise TypeError(f'Unsupported weight type: {type(weight)}')
        self.weight = weight

    def __getitem__(self, index: multidim_slice) -> T:
        if not isinstance(index, tuple):
            index = (index, )

        # Get the dict value in index's leafs
        # There can be at most one dict
        leaf_dict: Optional[Dict[int, float]] = None
        for maybe_weighted in _iterate_over_multidim_slice(index):
            for d in maybe_weighted.leaf_values():
                if isinstance(d, dict):
                    if leaf_dict is None:
                        leaf_dict = d
                    elif leaf_dict is not d:
                        raise ValueError('There can be at most one distinct dict in leaf values.')

        if leaf_dict is None:
            # in case of simple types with no dict
            res_index = _evaluate_multidim_slice(index)
        else:
            # there is a dict, iterate over dict
            res_index = []
            for val, wt in leaf_dict.items():
                res_index_item = _evaluate_multidim_slice(index, lambda _: val)
                res_index.append((res_index_item, wt))

        return _slice_weight(self.weight, res_index)


class MaybeWeighted:
    """Wrap a value (int or dict with int keys), so that the computation on it can be replayed.

    It builds a binary tree. If ``value`` is not None, it's a leaf node.
    Otherwise, it has a left sub-tree, a right sub-tree and an operation.
    Only basic arithmetic operations are supported: ``+``, ``-``, ``*``, ``//``.
    """

    def __init__(self, value: Optional[int_or_int_dict] = None, *,
                 lhs: Optional[Union['MaybeWeighted', int]] = None,
                 rhs: Optional[Union['MaybeWeighted', int]] = None,
                 operation: Optional[Callable[[int, int], int]] = None):
        if operation is None:
            if not isinstance(value, (int, dict)):
                raise TypeError(f'Unsupported value type: {type(value)}')
        self.value = value
        self.lhs = lhs
        self.rhs = rhs
        self.operation = operation

    def leaf_values(self) -> Iterator[Dict[int, float]]:
        """Iterate over values on leaf nodes."""
        if self.value is not None:
            yield self.value
        else:
            if isinstance(self.lhs, MaybeWeighted):
                yield from self.lhs.leaf_values()
            if isinstance(self.rhs, MaybeWeighted):
                yield from self.rhs.leaf_values()

    def evaluate(self, value_fn: _value_fn_type = None) -> int:
        """Evaluate the value on the root node, after replacing every value on leaf nodes with ``value_fn``.
        If ``value_fn`` is none, no replacement will happen and the raw value will be used.
        """
        if self.value is not None:
            if value_fn is not None:
                return value_fn(self.value)
            return self.value
        else:
            if isinstance(self.lhs, MaybeWeighted):
                eval_lhs = self.lhs.evaluate(value_fn)
            else:
                eval_lhs = self.lhs
            if isinstance(self.rhs, MaybeWeighted):
                eval_rhs = self.rhs.evaluate(value_fn)
            else:
                eval_rhs = self.rhs
            return self.operation(eval_lhs, eval_rhs)

    def __repr__(self):
        if self.value is not None:
            return f'{self.__class__.__name__}({self.value})'
        return f'{self.__class__.__name__}(lhs={self.lhs}, rhs={self.rhs}, op={self.operation})'

    def __add__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.add)

    def __radd__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.add)

    def __sub__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.sub)

    def __rsub__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.sub)

    def __mul__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.mul)

    def __rmul__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.mul)

    def __floordiv__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.floordiv)

    def __rfloordiv__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.floordiv)


def _iterate_over_slice_type(s: slice_type):
    if isinstance(s, list):
        for se in s:
            yield from _iterate_over_slice_type(se)
    else:
        # s must be a "slice" now
        if isinstance(s.start, MaybeWeighted):
            yield s.start
        if isinstance(s.stop, MaybeWeighted):
            yield s.stop
        if isinstance(s.step, MaybeWeighted):
            yield s.step


def _iterate_over_multidim_slice(ms: multidim_slice):
    """Get :class:`MaybeWeighted` instances in ``ms``."""
    for s in ms:
        if s is not None:
            yield from _iterate_over_slice_type(s)


def _evaluate_slice_type(s: slice_type, value_fn: _value_fn_type = None):
    if isinstance(s, list):
        return [_evaluate_slice_type(se, value_fn) for se in s]
    else:
        return slice(
            s.start.evaluate(value_fn) if isinstance(s.start, MaybeWeighted) else s.start,
            s.stop.evaluate(value_fn) if isinstance(s.stop, MaybeWeighted) else s.stop,
            s.step.evaluate(value_fn) if isinstance(s.step, MaybeWeighted) else s.step
        )


def _evaluate_multidim_slice(ms: multidim_slice, value_fn: _value_fn_type = None):
    """Wraps :meth:`MaybeWeighted.evaluate` to evaluate the whole ``multidim_slice``."""
    res = []
    for s in ms:
        if s is not None:
            res.append(_evaluate_slice_type(s, value_fn))
        else:
            res.append(None)
    return tuple(res)
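To make the weighted-slice semantics above concrete, here is a small sketch (the import path is the module added in this commit; the numbers are arbitrary):

import numpy as np
from nni.retiarii.oneshot.pytorch.supermodule._operation_utils import Slicable, MaybeWeighted

weight = np.arange(12, dtype=np.float32).reshape(4, 3)

# a plain slice behaves like ordinary indexing
print(Slicable(weight)[:2].shape)     # (2, 3)

# a weighted slice turns the candidate dict into a soft mask:
# rows 0-1 are covered by both candidates (0.25 + 0.75 = 1.0),
# rows 2-3 only by the larger candidate (0.75)
w = MaybeWeighted({2: 0.25, 4: 0.75})
print(Slicable(weight)[:w])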
nni/retiarii/oneshot/pytorch/supermodule/_singlepathnas.py  0 → 100644  View file @ 14d2966b

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# pylint: skip-file

"""This file is an incomplete implementation of `Single-path NAS <https://arxiv.org/abs/1904.02877>`__.
These are merely some components of the algorithm. The complete support is an ongoing work item.
Keep this file here so that it can be "blamed".
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from nni.retiarii.nn.pytorch import ValueChoice


class DifferentiableSuperConv2d(nn.Conv2d):
    """
    Only ``kernel_size``, ``in_channels`` and ``out_channels`` are supported. Kernel size candidates should be larger or smaller
    than each other in both dimensions. See examples below:

    the following example is not allowed:
        >>> ValueChoice(candidates = [(5, 3), (3, 5)])
            □ ■ ■ ■ □ □ □ □ □ □
            □ ■ ■ ■ □ ■ ■ ■ ■ ■    # candidates are not bigger or smaller on both dimensions
            □ ■ ■ ■ □ ■ ■ ■ ■ ■
            □ ■ ■ ■ □ ■ ■ ■ ■ ■
            □ ■ ■ ■ □ □ □ □ □ □

    the following 3 examples are valid:
        >>> ValueChoice(candidates = [5, 3, 1])
            ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □
            ■ ■ ■ ■ ■ □ ■ ■ ■ □ □ □ □ □ □
            ■ ■ ■ ■ ■ □ ■ ■ ■ □ □ □ ■ □ □
            ■ ■ ■ ■ ■ □ ■ ■ ■ □ □ □ □ □ □
            ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □
        >>> ValueChoice(candidates = [(5, 7), (3, 5), (1, 3)])
            ■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □ □ □ □ □
            ■ ■ ■ ■ ■ ■ ■ □ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □
            ■ ■ ■ ■ ■ ■ ■ □ ■ ■ ■ ■ ■ □ □ □ ■ ■ ■ □ □
            ■ ■ ■ ■ ■ ■ ■ □ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □
            ■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □ □ □ □ □
        >>> # when the difference between any two candidates is not even, the left upper will be picked:
        >>> ValueChoice(candidates = [(5, 5), (4, 4), (3, 3)])
            ■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □
            ■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ ■ ■ ■ □
            ■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ ■ ■ ■ □
            ■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ ■ ■ ■ □
            ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □
    """

    def __init__(self, module, name):
        self.label = name
        args = module.trace_kwargs

        # compulsory params
        if isinstance(args['in_channels'], ValueChoice):
            args['in_channels'] = max(args['in_channels'].candidates)

        self.out_channel_candidates = None
        if isinstance(args['out_channels'], ValueChoice):
            self.out_channel_candidates = sorted(args['out_channels'].candidates, reverse=True)
            args['out_channels'] = self.out_channel_candidates[0]

        # kernel_size may be an int or tuple, we turn it into a tuple for simplicity
        self.kernel_size_candidates = None
        if isinstance(args['kernel_size'], ValueChoice):
            # unify kernel size as tuple
            candidates = args['kernel_size'].candidates
            if not isinstance(candidates[0], tuple):
                candidates = [(k, k) for k in candidates]

            # sort kernel size in descending order
            self.kernel_size_candidates = sorted(candidates, key=lambda t: t[0], reverse=True)
            for i in range(0, len(self.kernel_size_candidates) - 1):
                bigger = self.kernel_size_candidates[i]
                smaller = self.kernel_size_candidates[i + 1]
                assert bigger[1] > smaller[1] or (bigger[1] == smaller[1] and bigger[0] > smaller[0]), f'Kernel_size candidates ' \
                    f'should be larger or smaller than each other on both dimensions, but found {bigger} and {smaller}.'
            args['kernel_size'] = self.kernel_size_candidates[0]

        super().__init__(**args)
        self.generate_architecture_params()

    def forward(self, input):
        # Note that there is no need to handle ``in_channels`` here since it is already handled by the ``out_channels`` in the
        # previous module. If we multiplied alpha with reference to ``in_channels`` here again, the alpha would indeed be considered
        # twice, which is not what we expect.
        weight = self.weight

        def sum_weight(input_weight, masks, thresholds, indicator):
            """
            This is to get the weighted sum of weight.

            Parameters
            ----------
            input_weight : Tensor
                the weight to be weighted summed
            masks : List[Tensor]
                weight masks.
            thresholds : List[float]
                thresholds, should have a length of ``len(masks) - 1``
            indicator : Callable[[Tensor, float], float]
                take a tensor and a threshold as input, and output the weight

            Returns
            ----------
            weight : Tensor
                weighted sum of ``input_weight``. This has the same shape as ``input_weight``.
            """
            # Note that ``masks`` and ``thresholds`` have different lengths. Their alignment is shown below:
            # self.xxx_candidates = [ c_0   , c_1    , ... , c_n-2   , c_n-1    ]   # descending order
            # self.xxx_mask       = [ mask_0, mask_1 , ... , mask_n-2, mask_n-1 ]
            # self.t_xxx          = [ t_0   , t_1    , ... , t_n-2              ]
            # So we zip the first n-1 items, and multiply masks[-1] in the end.
            weight = torch.zeros_like(input_weight)
            for mask, t in zip(masks[:-1], thresholds):
                cur_part = input_weight * mask
                alpha = indicator(cur_part, t)
                weight = (weight + cur_part) * alpha
            # we do not consider skip-op here for out_channel/expansion candidates, which means at least the smallest channel
            # candidate is included
            weight += input_weight * masks[-1]
            return weight

        if self.kernel_size_candidates is not None:
            weight = sum_weight(weight, self.kernel_masks, self.t_kernel, self.Lasso_sigmoid)

        if self.out_channel_candidates is not None:
            weight = sum_weight(weight, self.channel_masks, self.t_expansion, self.Lasso_sigmoid)

        output = self._conv_forward(input, weight, self.bias)
        return output

    def parameters(self):
        for _, p in self.named_parameters():
            yield p

    def named_parameters(self):
        for name, p in super().named_parameters():
            if name == 'alpha':
                continue
            yield name, p

    def export(self):
        """
        result = {
            'kernel_size': i,
            'out_channels': j
        }
        which means the best candidate for an argument is the i-th one, if candidates are sorted in descending order
        """
        result = {}
        eps = 1e-5
        with torch.no_grad():
            if self.kernel_size_candidates is not None:
                weight = torch.zeros_like(self.weight)
                # ascending order
                for i in range(len(self.kernel_size_candidates) - 2, -1, -1):
                    mask = self.kernel_masks[i]
                    t = self.t_kernel[i]
                    cur_part = self.weight * mask
                    alpha = self.Lasso_sigmoid(cur_part, t)
                    if alpha <= eps:  # take the smaller one
                        result['kernel_size'] = self.kernel_size_candidates[i + 1]
                        break
                    weight = (weight + cur_part) * alpha

                if 'kernel_size' not in result:
                    result['kernel_size'] = self.kernel_size_candidates[0]
            else:
                weight = self.weight

            if self.out_channel_candidates is not None:
                for i in range(len(self.out_channel_candidates) - 2, -1, -1):
                    mask = self.channel_masks[i]
                    t = self.t_expansion[i]
                    alpha = self.Lasso_sigmoid(weight * mask, t)
                    if alpha <= eps:
                        result['out_channels'] = self.out_channel_candidates[i + 1]

                if 'out_channels' not in result:
                    result['out_channels'] = self.out_channel_candidates[0]

        return result

    @staticmethod
    def Lasso_sigmoid(matrix, t):
        """
        A trick that can make use of both the value of bool(lasso > t) and the gradient of sigmoid(lasso - t)

        Parameters
        ----------
        matrix : Tensor
            the matrix to calculate lasso norm
        t : float
            the threshold
        """
        lasso = torch.norm(matrix) - t
        indicator = (lasso > 0).float()  # torch.sign(lasso)
        with torch.no_grad():
            # indicator = indicator / 2 + .5  # realign indicator from (-1, 1) to (0, 1)
            indicator -= F.sigmoid(lasso)
        indicator += F.sigmoid(lasso)
        return indicator

    def generate_architecture_params(self):
        self.alpha = {}
        if self.kernel_size_candidates is not None:
            # kernel size arch params
            self.t_kernel = nn.Parameter(torch.rand(len(self.kernel_size_candidates) - 1))
            self.alpha['kernel_size'] = self.t_kernel
            # kernel size masks
            self.kernel_masks = []
            for i in range(0, len(self.kernel_size_candidates) - 1):
                big_size = self.kernel_size_candidates[i]
                small_size = self.kernel_size_candidates[i + 1]
                # if self.weight.shape = (out, in, 7, 7), big_size = (5, 5) and small_size = (3, 3),
                # the mask will look like:
                #   0 0 0 0 0 0 0
                #   0 1 1 1 1 1 0
                #   0 1 0 0 0 1 0
                #   0 1 0 0 0 1 0
                #   0 1 0 0 0 1 0
                #   0 1 1 1 1 1 0
                #   0 0 0 0 0 0 0
                mask = torch.zeros_like(self.weight)
                mask[:, :, :big_size[0], :big_size[1]] = 1
                mask[:, :, :small_size[0], :small_size[1]] = 0
                self.kernel_masks.append(mask)
            mask = torch.zeros_like(self.weight)
            mask[:, :, :self.kernel_size_candidates[-1][0], :self.kernel_size_candidates[-1][1]] = 1
            self.kernel_masks.append(mask)

        if self.out_channel_candidates is not None:
            # out_channel (or expansion) arch params. we do not consider skip-op here, so we
            # only generate ``len(self.out_channel_candidates) - 1`` thresholds
            self.t_expansion = nn.Parameter(torch.rand(len(self.out_channel_candidates) - 1))
            self.alpha['out_channels'] = self.t_expansion
            self.channel_masks = []
            for i in range(0, len(self.out_channel_candidates) - 1):
                big_channel, small_channel = self.out_channel_candidates[i], self.out_channel_candidates[i + 1]
                mask = torch.zeros_like(self.weight)
                mask[:big_channel] = 1
                mask[:small_channel] = 0
                # if self.weight.shape = (32, in, W, H), big_channel = 16 and small_channel = 8, mask will look like:
                # 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
                self.channel_masks.append(mask)
            mask = torch.zeros_like(self.weight)
            mask[:self.out_channel_candidates[-1]] = 1
            self.channel_masks.append(mask)


class DifferentiableBatchNorm2d(nn.BatchNorm2d):
    def __init__(self, module, name):
        self.label = name
        args = module.trace_kwargs
        if isinstance(args['num_features'], ValueChoice):
            args['num_features'] = max(args['num_features'].candidates)
        super().__init__(**args)

        # no architecture parameter is needed for BatchNorm2d layers
        self.alpha = nn.Parameter(torch.tensor([]))

    def export(self):
        """
        No need to export ``BatchNorm2d``. Refer to the ``Conv2d`` layer that has the ``ValueChoice`` as ``out_channels``.
        """
        return -1
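The ``Lasso_sigmoid`` trick above is the usual straight-through pattern: the forward value is the hard indicator, while the gradient comes from the sigmoid. Written out in plain PyTorch, independent of this file, just to illustrate the idea:

import torch

x = torch.tensor(0.3, requires_grad=True)
t = 0.5
soft = torch.sigmoid(x - t)
hard = (x > t).float()
# value equals `hard`, gradient flows through `soft`
out = hard.detach() - soft.detach() + soft
out.backward()
print(out.item(), x.grad.item())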
nni/retiarii/oneshot/pytorch/supermodule/_valuechoice_utils.py  0 → 100644  View file @ 14d2966b

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Utilities to process value choice compositions,
in the way that is most convenient to one-shot algorithms."""

import itertools
from typing import List, Any, Dict, Tuple, Optional, Union

from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch.api import ValueChoiceX

Choice = Any

__all__ = ['dedup_inner_choices', 'evaluate_value_choice_with_dict', 'traverse_all_options']


def dedup_inner_choices(value_choices: List[ValueChoiceX]) -> Dict[str, ParameterSpec]:
    """Find all leaf nodes in ``value_choices``,
    and save them in the format of ``{label: parameter_spec}``.
    """
    result = {}
    for value_choice in value_choices:
        for choice in value_choice.inner_choices():
            param_spec = ParameterSpec(choice.label, 'choice', choice.candidates, (choice.label, ), True, size=len(choice.candidates))
            if choice.label in result:
                if param_spec != result[choice.label]:
                    raise ValueError('Value choice conflict: same label with different candidates: '
                                     f'{param_spec} vs. {result[choice.label]}')
            else:
                result[choice.label] = param_spec
    return result


def evaluate_value_choice_with_dict(value_choice: ValueChoiceX, chosen: Dict[str, Choice]) -> Any:
    """Evaluate a composition of value-choice with a dict,
    with the format of ``{label: chosen_value}``.
    The implementation is two-pass. We first get a list of values,
    then feed the values into ``value_choice.evaluate``.
    This can potentially be optimized in terms of speed.

    Examples
    --------
    >>> chosen = {"exp_ratio": 3}
    >>> evaluate_value_choice_with_dict(value_choice_in, chosen)
    48
    >>> evaluate_value_choice_with_dict(value_choice_out, chosen)
    96
    """
    choice_inner_values = []
    for choice in value_choice.inner_choices():
        if choice.label not in chosen:
            raise KeyError(f'{value_choice} depends on a value with key {choice.label}, but not found in {chosen}')
        choice_inner_values.append(chosen[choice.label])
    return value_choice.evaluate(choice_inner_values)


def traverse_all_options(value_choice: ValueChoiceX,
                         weights: Optional[Dict[str, List[float]]] = None) -> List[Union[Tuple[Any, float], Any]]:
    """Traverse all possible computation outcomes of a value choice.
    If ``weights`` is not None, it will also compute the probability of each possible outcome.

    Parameters
    ----------
    value_choice : ValueChoiceX
        The value choice to traverse.
    weights : Optional[Dict[str, List[float]]], default = None
        If there's a prior on leaf nodes, and we intend to know the (joint) prior on results,
        weights can be provided. The key is a label; the value is a list of floats indicating probability.
        Normally, they should sum up to 1, but we will not check them in this function.

    Returns
    -------
    List[Union[Tuple[Any, float], Any]]
        Results will be sorted and duplicates will be eliminated.
        If weights is provided, the return value will be a list of tuples, each with an option and its weight.
        Otherwise, it will be a list of options.
    """
    # get a dict of {label: list of tuples of choice and weight}
    leafs: Dict[str, List[Tuple[Choice, float]]] = {}
    for label, param_spec in dedup_inner_choices([value_choice]).items():
        if weights is not None:
            if label not in weights:
                raise KeyError(f'{value_choice} depends on a weight with key {label}, but not found in {weights}')
            if len(weights[label]) != param_spec.size:
                raise KeyError(f'Expect weights with {label} to be of length {param_spec.size}, but {len(weights[label])} found')
            leafs[label] = list(zip(param_spec.values, weights[label]))
        else:
            # create a dummy weight of zero, in case that weights are not provided.
            leafs[label] = list(zip(param_spec.values, itertools.repeat(0., param_spec.size)))

    # result is a dict from an option to its weight
    result: Dict[str, Optional[float]] = {}
    labels, values = list(leafs.keys()), list(leafs.values())

    if not labels:
        raise ValueError(f'There should be at least one leaf value choice in {value_choice}, but nothing found')

    # get all combinations
    for prod_value in itertools.product(*values):
        # For example,
        # prod_value = ((3, 0.1), ("cat", 0.3), ({"in": 5}, 0.5))
        # the first dim is the chosen value, the second dim is the probability
        # chosen = {"ks": 3, "animal": "cat", "linear_args": {"in": 5}}
        # chosen_weight = np.prod([0.1, 0.3, 0.5])
        chosen = {label: value[0] for label, value in zip(labels, prod_value)}

        eval_res = evaluate_value_choice_with_dict(value_choice, chosen)

        if weights is None:
            result[eval_res] = None
        else:
            # we can't use reduce or inplace product here,
            # because the weight can sometimes be tensors
            chosen_weight = prod_value[0][1]
            for value in prod_value[1:]:
                if chosen_weight is None:
                    chosen_weight = value[1]
                else:
                    chosen_weight = chosen_weight * value[1]

            if eval_res in result:
                result[eval_res] = result[eval_res] + chosen_weight
            else:
                result[eval_res] = chosen_weight

    if weights is None:
        return sorted(result.keys())
    else:
        return sorted(result.items())
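A quick sketch of ``traverse_all_options`` on a composed value choice (assuming ``ValueChoice`` supports arithmetic composition, as the rest of this PR relies on; the label and numbers are arbitrary):

from nni.retiarii.nn.pytorch import ValueChoice
from nni.retiarii.oneshot.pytorch.supermodule._valuechoice_utils import traverse_all_options

vc = ValueChoice([2, 3], label='exp_ratio') * 16

print(traverse_all_options(vc))
# [32, 48]
print(traverse_all_options(vc, weights={'exp_ratio': [0.3, 0.7]}))
# [(32, 0.3), (48, 0.7)]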
nni/retiarii/oneshot/pytorch/supermodule/base.py  0 → 100644  View file @ 14d2966b

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Any, Dict, Tuple, Union

import torch.nn as nn

from nni.common.hpo_utils import ParameterSpec


class BaseSuperNetModule(nn.Module):
    """
    Mutated module in super-net.
    Usually, the feed-forward of the module itself is undefined.
    It has to be resampled with ``resample()`` so that a specific path is selected.
    (Sometimes, this is not required. For example, differentiable super-net.)

    A super-net module usually corresponds to one sample, but there are two exceptions:

    * A module can have multiple parameter specs. For example, a convolution-2d can sample kernel size and channels at the same time.
    * Multiple modules can share one parameter spec. For example, multiple layer choices with the same label.

    For value choice compositions, the parameter specs are bound to the underlying (original) value choices,
    rather than their compositions.
    """

    def resample(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
        """
        Resample the super-net module.

        Parameters
        ----------
        memo : Dict[str, Any]
            Used to ensure the consistency of samples with the same label.

        Returns
        -------
        dict
            Sampled result. If nothing new is sampled, it should return an empty dict.
        """
        raise NotImplementedError()

    def export(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
        """
        Export the final architecture within this module.
        It should have the same keys as ``search_space_spec()``.

        Parameters
        ----------
        memo : Dict[str, Any]
            Use memo to avoid exporting the same label multiple times.
        """
        raise NotImplementedError()

    def search_space_spec(self) -> Dict[str, ParameterSpec]:
        """
        Space specification (sample points).
        Mapping from spec name to ParameterSpec. The names in choices should be in the same format as export.

        For example: ::

            {"layer1": ParameterSpec(values=["conv", "pool"])}
        """
        raise NotImplementedError()

    @classmethod
    def mutate(cls, module: nn.Module, name: str, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> \
            Union['BaseSuperNetModule', bool, Tuple['BaseSuperNetModule', bool]]:
        """This is a mutation hook that creates a :class:`BaseSuperNetModule`.
        The method should be implemented in each specific super-net module,
        because they usually have specific rules about what kind of modules to operate on.

        Parameters
        ----------
        module : nn.Module
            The module to be mutated (replaced).
        name : str
            Name of this module, with full prefix. For example, ``module1.block1.conv``.
        memo : dict
            Memo to enable sharing parameters among mutated modules. It should be read and written by
            mutate functions themselves.
        mutate_kwargs : dict
            Algo-related hyper-parameters, and some auxiliary information.

        Returns
        -------
        Union[BaseSuperNetModule, bool, Tuple[BaseSuperNetModule, bool]]
            The mutation result, along with an optional boolean flag indicating whether to suppress follow-up mutation hooks.
            See :class:`nni.retiarii.oneshot.pytorch.base_lightning.BaseOneShotLightningModule` for details.
        """
        raise NotImplementedError()
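To illustrate the contract, a toy subclass might look like the sketch below. It is purely hypothetical and not part of the PR: it samples one of two fixed ops, reuses the memo for label sharing, and reports its space through ``ParameterSpec`` in the same way the real modules in this PR do.

import random
from typing import Any, Dict

import torch.nn as nn

from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.oneshot.pytorch.supermodule.base import BaseSuperNetModule


class ToyMixedLayer(BaseSuperNetModule):
    """Chooses between two fixed candidate ops. Illustration only."""

    def __init__(self, label: str):
        super().__init__()
        self.label = label
        self.ops = nn.ModuleDict({'relu': nn.ReLU(), 'identity': nn.Identity()})
        self._chosen = 'relu'

    def resample(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
        memo = memo or {}
        if self.label in memo:           # another module with the same label already sampled
            self._chosen = memo[self.label]
            return {}
        self._chosen = random.choice(list(self.ops))
        return {self.label: self._chosen}

    def export(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
        memo = memo or {}
        return {} if self.label in memo else {self.label: self._chosen}

    def search_space_spec(self) -> Dict[str, ParameterSpec]:
        return {self.label: ParameterSpec(self.label, 'choice', list(self.ops), (self.label, ),
                                          True, size=len(self.ops))}

    def forward(self, x):
        return self.ops[self._chosen](x)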
nni/retiarii/oneshot/pytorch/supermodule/differentiable.py  0 → 100644  View file @ 14d2966b

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import functools
import warnings
from typing import List, Tuple, Optional, Dict, Any, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice

from .base import BaseSuperNetModule
from .operation import MixedOperation, MixedOperationSamplingPolicy
from ._valuechoice_utils import traverse_all_options


class GumbelSoftmax(nn.Softmax):
    """Wrapper of ``F.gumbel_softmax``. dim = -1 by default."""

    def __init__(self, dim: Optional[int] = -1) -> None:
        super().__init__(dim)
        self.tau = 1
        self.hard = False

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return F.gumbel_softmax(inputs, tau=self.tau, hard=self.hard, dim=self.dim)


class DifferentiableMixedLayer(BaseSuperNetModule):
    """
    Mixed layer, in which fprop is decided by a weighted sum of several layers.
    Proposed in `DARTS: Differentiable Architecture Search <https://arxiv.org/abs/1806.09055>`__.

    The weight ``alpha`` is usually learnable, and optimized on the validation dataset.

    The differentiable sampling layer requires all operators to return the same shape for one input,
    as all outputs will be weighted-summed to get the final output.

    Parameters
    ----------
    paths : List[Tuple[str, nn.Module]]
        Layers to choose from. Each is a tuple of a name and its module.
    alpha : Tensor
        Tensor that stores the "learnable" weights.
    softmax : nn.Module
        Customizable softmax function. Usually ``nn.Softmax(-1)``.
    label : str
        Name of the choice.

    Attributes
    ----------
    op_names : str
        Operator names.
    label : str
        Name of the choice.
    """

    _arch_parameter_names: List[str] = ['_arch_alpha']

    def __init__(self, paths: List[Tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__()
        self.op_names = []
        if len(alpha) != len(paths):
            raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({len(paths)}).')
        for name, module in paths:
            self.add_module(name, module)
            self.op_names.append(name)
        assert self.op_names, 'There has to be at least one op to choose from.'
        self.label = label
        self._arch_alpha = alpha
        self._softmax = softmax

    def resample(self, memo):
        """Do nothing. A differentiable layer doesn't need resample."""
        return {}

    def export(self, memo):
        """Choose the operator with the maximum logit."""
        if self.label in memo:
            return {}  # nothing new to export
        return {self.label: self.op_names[torch.argmax(self._arch_alpha).item()]}

    def search_space_spec(self):
        return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
                                          True, size=len(self.op_names))}

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if isinstance(module, LayerChoice):
            size = len(module)
            if module.label in memo:
                alpha = memo[module.label]
                if len(alpha) != size:
                    raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
            else:
                alpha = nn.Parameter(torch.randn(size) * 1E-3)  # this can be reinitialized later
            softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
            return cls(list(module.named_children()), alpha, softmax, module.label)

    def forward(self, *args, **kwargs):
        """The forward of a mixed layer accepts the same arguments as its sub-layers."""
        op_results = torch.stack([getattr(self, op)(*args, **kwargs) for op in self.op_names])
        alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
        return torch.sum(op_results * self._softmax(self._arch_alpha).view(*alpha_shape), 0)

    def parameters(self, *args, **kwargs):
        """Parameters excluding architecture parameters."""
        for _, p in self.named_parameters(*args, **kwargs):
            yield p

    def named_parameters(self, *args, **kwargs):
        """Named parameters excluding architecture parameters."""
        arch = kwargs.pop('arch', False)
        for name, p in super().named_parameters(*args, **kwargs):
            if any(name == par_name for par_name in self._arch_parameter_names):
                if arch:
                    yield name, p
            else:
                if not arch:
                    yield name, p


class DifferentiableMixedInput(BaseSuperNetModule):
    """
    Mixed input. Forward returns a weighted sum of candidates.
    The implementation is very similar to :class:`DifferentiableMixedLayer`.

    Parameters
    ----------
    n_candidates : int
        Expected number of input candidates.
    n_chosen : int
        Expected number of inputs finally chosen.
    alpha : Tensor
        Tensor that stores the "learnable" weights.
    softmax : nn.Module
        Customizable softmax function. Usually ``nn.Softmax(-1)``.
    label : str
        Name of the choice.

    Attributes
    ----------
    label : str
        Name of the choice.
    """

    _arch_parameter_names: List[str] = ['_arch_alpha']

    def __init__(self, n_candidates: int, n_chosen: Optional[int], alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__()
        self.n_candidates = n_candidates
        if len(alpha) != n_candidates:
            raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({n_candidates}).')
        if n_chosen is None:
            warnings.warn('Differentiable architecture search does not support choosing multiple inputs. Assuming one.',
                          RuntimeWarning)
            n_chosen = 1
        self.n_chosen = n_chosen
        self.label = label

        self._softmax = softmax
        self._arch_alpha = alpha

    def resample(self, memo):
        """Do nothing. A differentiable layer doesn't need resample."""
        return {}

    def export(self, memo):
        """Choose the operators with the top ``n_chosen`` logits."""
        if self.label in memo:
            return {}  # nothing new to export
        chosen = sorted(torch.argsort(-self._arch_alpha).cpu().numpy().tolist()[:self.n_chosen])
        if len(chosen) == 1:
            chosen = chosen[0]
        return {self.label: chosen}

    def search_space_spec(self):
        return {
            self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
                                      (self.label, ), True, size=self.n_candidates, chosen_size=self.n_chosen)
        }

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if isinstance(module, InputChoice):
            if module.reduction not in ['sum', 'mean']:
                raise ValueError('Only input choice of sum/mean reduction is supported.')
            size = module.n_candidates
            if module.label in memo:
                alpha = memo[module.label]
                if len(alpha) != size:
                    raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
            else:
                alpha = nn.Parameter(torch.randn(size) * 1E-3)  # this can be reinitialized later
            softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
            return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label)

    def forward(self, inputs):
        """Forward takes a list of input candidates."""
        inputs = torch.stack(inputs)
        alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
        return torch.sum(inputs * self._softmax(self._arch_alpha).view(*alpha_shape), 0)

    def parameters(self, *args, **kwargs):
        """Parameters excluding architecture parameters."""
        for _, p in self.named_parameters(*args, **kwargs):
            yield p

    def named_parameters(self, *args, **kwargs):
        """Named parameters excluding architecture parameters."""
        arch = kwargs.pop('arch', False)
        for name, p in super().named_parameters(*args, **kwargs):
            if any(name == par_name for par_name in self._arch_parameter_names):
                if arch:
                    yield name, p
            else:
                if not arch:
                    yield name, p


class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
    """Implements the differentiable sampling in mixed operation.

    One mixed operation can have multiple value choices in its arguments.
    Thus the ``_arch_alpha`` here is a parameter dict, and ``named_parameters``
    filters out multiple parameters with ``_arch_alpha`` as their prefix.

    When this class is asked for ``forward_argument``, it returns a distribution,
    i.e., a dict from int to float based on its weights.

    All the parameters (``_arch_alpha``, ``parameters()``, ``_softmax``) are
    saved as attributes of ``operation``, rather than ``self``,
    because this class itself is not an ``nn.Module``, and parameters saved here
    won't be optimized.
    """

    _arch_parameter_names: List[str] = ['_arch_alpha']

    def __init__(self, operation: MixedOperation, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> None:
        # Sampling arguments. This should have the same keys as `operation.mutable_arguments`
        operation._arch_alpha = nn.ParameterDict()
        for name, spec in operation.search_space_spec().items():
            if name in memo:
                alpha = memo[name]
                if len(alpha) != spec.size:
                    raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
            else:
                alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
            operation._arch_alpha[name] = alpha

        operation.parameters = functools.partial(self.parameters, self=operation)              # bind self
        operation.named_parameters = functools.partial(self.named_parameters, self=operation)

        operation._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))

    @staticmethod
    def parameters(self, *args, **kwargs):
        for _, p in self.named_parameters(*args, **kwargs):
            yield p

    @staticmethod
    def named_parameters(self, *args, **kwargs):
        arch = kwargs.pop('arch', False)
        for name, p in super(self.__class__, self).named_parameters(*args, **kwargs):  # pylint: disable=bad-super-call
            if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
                if arch:
                    yield name, p
            else:
                if not arch:
                    yield name, p

    def resample(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
        """Differentiable. Do nothing in resample."""
        return {}

    def export(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
        """Export the argmax of each leaf value choice."""
        result = {}
        for name, spec in operation.search_space_spec().items():
            if name in result:
                continue
            chosen_index = torch.argmax(operation._arch_alpha[name]).item()
            result[name] = spec.values[chosen_index]
        return result

    def forward_argument(self, operation: MixedOperation, name: str) -> Union[Dict[Any, float], Any]:
        if name in operation.mutable_arguments:
            weights = {label: operation._softmax(alpha) for label, alpha in operation._arch_alpha.items()}
            return dict(traverse_all_options(operation.mutable_arguments[name], weights=weights))
        return operation.init_arguments[name]
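A small sketch of ``DifferentiableMixedLayer`` used stand-alone, outside the mutation machinery; the two convolutions are arbitrary candidates with matching output shapes:

import torch
import torch.nn as nn
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import DifferentiableMixedLayer

alpha = nn.Parameter(torch.randn(2) * 1e-3)
layer = DifferentiableMixedLayer(
    [('conv3x3', nn.Conv2d(3, 8, 3, padding=1)),
     ('conv5x5', nn.Conv2d(3, 8, 5, padding=2))],
    alpha, nn.Softmax(-1), 'demo')

x = torch.randn(1, 3, 32, 32)
print(layer(x).shape)    # softmax(alpha)-weighted sum of both paths: (1, 8, 32, 32)
print(layer.export({}))  # e.g. {'demo': 'conv3x3'}, the op with the largest logit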
nni/retiarii/oneshot/pytorch/supermodule/operation.py  0 → 100644  View file @ 14d2966b

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Operations that support weight sharing at a fine-grained level,
which is commonly known as super-kernel (as in channel search), or weight entanglement.
"""

import inspect
import itertools
from typing import Union, Tuple, Dict, List, Any, Type, Optional, TypeVar

import torch
import torch.nn as nn
import torch.nn.functional as F

import nni.retiarii.nn.pytorch as retiarii_nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import is_traceable
from nni.retiarii.nn.pytorch.api import ValueChoiceX

from .base import BaseSuperNetModule
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices
from ._operation_utils import Slicable as _S, MaybeWeighted as _W, int_or_int_dict, scalar_or_scalar_dict

T = TypeVar('T')


class MixedOperationSamplingPolicy:
    """
    Algo-related part of a mixed Operation.

    :class:`MixedOperation` delegates its resample and export to this policy (or its subclass),
    so that one Operation can be easily combined with different kinds of sampling.

    One SamplingStrategy corresponds to one mixed operation.
    """

    def __init__(self, operation: 'MixedOperation', memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> None:
        """At init, the sampling policy can prepare basic parameters,
        and store them in the operation if they need back propagation.

        This init is called in :meth:`BaseSuperNetModule.mutate`, after the mixed operation is created.
        So, similar to :meth:`BaseSuperNetModule.mutate`,
        memo should also be managed (read and written) by the policy itself.
        """
        pass

    def resample(self, operation: 'MixedOperation', memo: Dict[str, Any] = None) -> Dict[str, Any]:
        """The handler of :meth:`MixedOperation.resample`."""
        raise NotImplementedError()

    def export(self, operation: 'MixedOperation', memo: Dict[str, Any] = None) -> Dict[str, Any]:
        """The handler of :meth:`MixedOperation.export`."""
        raise NotImplementedError()

    def forward_argument(self, operation: 'MixedOperation', name: str) -> Any:
        """Compute the argument named ``name`` used in the operation's forward.
        Usually a value, or a distribution of values.
        """
        raise NotImplementedError()


class MixedOperation(BaseSuperNetModule):
    """This is the base class for all mixed operations.
    It contains commonly used utilities that ease the effort of writing customized mixed operations,
    i.e., operations with ValueChoice in their arguments.

    By design, for a mixed operation to work in a specific algorithm,
    at least two classes are needed.

    1. One class needs to inherit this class, to control operation-related behavior,
       such as how to initialize the operation such that the sampled operation can be its sub-operation.
    2. The other one needs to inherit :class:`MixedOperationSamplingPolicy`,
       which controls algo-related behavior, such as sampling.

    The two classes are linked by the ``sampling_policy`` attribute of :class:`MixedOperation`,
    whose type is set via ``mixed_op_sampling`` in ``mutate_kwargs`` when
    :meth:`MixedOperation.mutate` is called.

    With this design, one mixed-operation (e.g., MixedConv2d) can work in multiple algorithms
    (e.g., both DARTS and ENAS), saving the engineering effort to rewrite all operations for
    each specific algo.

    This class should also define a ``bound_type``, to control the matching type in mutate,
    and an ``argument_list``, to control which arguments can be dynamically used in ``forward``.
    This list will also be used in mutate for sanity checks.
    """

    bound_type: Type[nn.Module]                 # defined in subclass
    argument_list: List[str]                    # defined in subclass

    sampling_policy: MixedOperationSamplingPolicy

    def super_init_argument(self, name: str, value_choice: ValueChoiceX) -> Any:
        """Get the initialization argument when constructing the super-kernel, i.e., calling ``super().__init__()``.
        This is often related to the specific operator, rather than the algo.

        For example::

            def super_init_argument(self, name, value_choice):
                return max(value_choice.candidates)
        """
        raise NotImplementedError()

    def __post_init__(self) -> None:
        """Can be used to validate, or to do extra processing after calling ``__init__``."""
        pass

    def forward_with_args(self, *args, **kwargs):
        """To control the real fprop. The accepted arguments are ``argument_list``,
        appended by the forward arguments of the ``bound_type``."""
        raise NotImplementedError()

    def __init__(self, module_kwargs: Dict[str, Any]) -> None:
        # Concerned arguments
        self.mutable_arguments: Dict[str, ValueChoiceX] = {}
        # Useful when retrieving arguments without ValueChoice
        self.init_arguments: Dict[str, Any] = {**module_kwargs}
        self._fill_missing_init_arguments()  # get init defaults

        super_init_kwargs = {}

        for key, value in module_kwargs.items():
            if isinstance(value, ValueChoiceX):
                if key not in self.argument_list:
                    raise TypeError(f'Unsupported value choice on argument of {self.bound_type}: {key}')
                super_init_kwargs[key] = self.super_init_argument(key, value)
                self.mutable_arguments[key] = value
            else:
                super_init_kwargs[key] = value

        # get all inner leaf value choices
        self._space_spec: Dict[str, ParameterSpec] = dedup_inner_choices(self.mutable_arguments.values())

        super().__init__(**super_init_kwargs)

        self.__post_init__()

    def resample(self, memo):
        """Delegates to :meth:`MixedOperationSamplingPolicy.resample`."""
        return self.sampling_policy.resample(self, memo)

    def export(self, memo):
        """Delegates to :meth:`MixedOperationSamplingPolicy.export`."""
        return self.sampling_policy.export(self, memo)

    def search_space_spec(self):
        return self._space_spec

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        """Find value choices in the module's arguments and replace the whole module."""
        has_valuechoice = False
        if isinstance(module, cls.bound_type) and is_traceable(module):
            for arg in itertools.chain(module.trace_args, module.trace_kwargs.values()):
                if isinstance(arg, ValueChoiceX):
                    has_valuechoice = True

        if has_valuechoice:
            if module.trace_args:
                raise ValueError('ValueChoice on class arguments cannot appear together with ``trace_args``. '
                                 'Please enable ``kw_only`` on nni.trace.')

            # save type and kwargs
            mixed_op = cls(module.trace_kwargs)

            if 'mixed_op_sampling' not in mutate_kwargs:
                raise ValueError('Need a sampling policy for mixed op, but it is not found in `mutate_kwargs`.')
            policy_cls: Type[MixedOperationSamplingPolicy] = mutate_kwargs['mixed_op_sampling']
            # initialize the policy class
            # this is put in mutate because we need to access memo
            mixed_op.sampling_policy = policy_cls(mixed_op, memo, mutate_kwargs)

            return mixed_op

    def forward_argument(self, name: str) -> Any:
        """Get the argument used in forward.
        This is often related to the algo. We redirect this to the sampling policy.
        """
        return self.sampling_policy.forward_argument(self, name)

    def forward(self, *args, **kwargs):
        """First get sampled arguments, then forward with the sampled arguments (by calling ``forward_with_args``)."""
        sampled_args = [self.forward_argument(name) for name in self.argument_list]
        return self.forward_with_args(*sampled_args, *args, **kwargs)

    def _fill_missing_init_arguments(self) -> None:
        """Set the unspecified init arguments in ``self.init_arguments``.
        For example, in the case of Conv2d, when the user didn't specify the argument ``stride``,
        this method adds ``stride = 1`` to ``self.init_arguments``.

        This is implemented by inspecting the init signature of ``bound_type``.
        Arguments in complex cases like ``__new__`` or in super-classes are not supported.
        """

        def unwrap(cls):
            if not hasattr(cls, '__wrapped__'):
                return cls
            return unwrap(cls.__wrapped__)

        for param in inspect.signature(unwrap(self.bound_type).__init__).parameters.values():
            if param.default is not param.empty and param.name not in self.init_arguments:
                self.init_arguments[param.name] = param.default


class MixedLinear(MixedOperation, nn.Linear):
    """Mixed linear operation.

    Supported arguments are:

    - ``in_features``
    - ``out_features``

    Prefix of weight and bias will be sliced.
    """

    bound_type = retiarii_nn.Linear
    argument_list = ['in_features', 'out_features']

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        return max(traverse_all_options(value_choice))

    def forward_with_args(self,
                          in_features: int_or_int_dict,
                          out_features: int_or_int_dict,
                          inputs: torch.Tensor) -> torch.Tensor:
        in_features = _W(in_features)
        out_features = _W(out_features)

        weight = _S(self.weight)[:out_features]
        weight = _S(weight)[:, :in_features]
        if self.bias is None:
            bias = self.bias
        else:
            bias = _S(self.bias)[:out_features]

        return F.linear(inputs, weight, bias)


_int_or_tuple = Union[int, Tuple[int, int]]


class MixedConv2d(MixedOperation, nn.Conv2d):
    """Mixed conv2d op.

    Supported arguments are:

    - ``in_channels``
    - ``out_channels``
    - ``groups`` (only supported in path sampling)
    - ``stride`` (only supported in path sampling)
    - ``kernel_size``
    - ``padding`` (only supported in path sampling)
    - ``dilation`` (only supported in path sampling)

    ``padding`` will be the "max" padding in differentiable mode.

    For channels, the prefix will be sliced.
    For kernels, we take the small kernel from the center and round it to floor (left top). For example ::

        max_kernel = 5*5, sampled_kernel = 3*3, then we take [1: 4]
        max_kernel = 5*5, sampled_kernel = 2*2, then we take [1: 3]
        □ □ □ □ □   □ □ □ □ □
        □ ■ ■ ■ □   □ ■ ■ □ □
        □ ■ ■ ■ □   □ ■ ■ □ □
        □ ■ ■ ■ □   □ □ □ □ □
        □ □ □ □ □   □ □ □ □ □
    """

    bound_type = retiarii_nn.Conv2d
    argument_list = [
        'in_channels', 'out_channels', 'kernel_size', 'stride', 'padding', 'dilation', 'groups'
    ]

    @staticmethod
    def _to_tuple(value: scalar_or_scalar_dict[T]) -> Tuple[T, T]:
        if not isinstance(value, tuple):
            return (value, value)
        return value

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        if name not in ['in_channels', 'out_channels', 'groups', 'stride', 'kernel_size', 'padding', 'dilation']:
            raise NotImplementedError(f'Unsupported value choice on argument: {name}')
        if name in ['kernel_size', 'padding']:
            all_sizes = set(traverse_all_options(value_choice))
            if any(isinstance(sz, tuple) for sz in all_sizes):
                # the maximum kernel should be calculated on every dimension
                return (
                    max(self._to_tuple(sz)[0] for sz in all_sizes),
                    max(self._to_tuple(sz)[1] for sz in all_sizes)
                )
            else:
                return max(all_sizes)
        elif name == 'groups':
            # minimum groups, maximum kernel
            return min(traverse_all_options(value_choice))
        else:
            return max(traverse_all_options(value_choice))

    def forward_with_args(self,
                          in_channels: int_or_int_dict,
                          out_channels: int_or_int_dict,
                          kernel_size: scalar_or_scalar_dict[_int_or_tuple],
                          stride: _int_or_tuple,
                          padding: scalar_or_scalar_dict[_int_or_tuple],
                          dilation: int,
                          groups: int,
                          inputs: torch.Tensor) -> torch.Tensor:

        if any(isinstance(arg, dict) for arg in [stride, dilation, groups]):
            raise ValueError('stride, dilation, groups does not support weighted sampling.')

        in_channels = _W(in_channels)
        out_channels = _W(out_channels)

        # slice prefix
        # For groups > 1, we use groups to slice input weights
        weight = _S(self.weight)[:out_channels]
        weight = _S(weight)[:, :in_channels // groups]

        # slice center
        if isinstance(kernel_size, dict):
            padding = self.padding  # max padding, must be a tuple
        kernel_a, kernel_b = self._to_tuple(kernel_size)
        kernel_a, kernel_b = _W(kernel_a), _W(kernel_b)
        max_kernel_a, max_kernel_b = self.kernel_size  # self.kernel_size must be a tuple
        kernel_a_left, kernel_b_top = (max_kernel_a - kernel_a) // 2, (max_kernel_b - kernel_b) // 2
        weight = _S(weight)[:, :, kernel_a_left:kernel_a_left + kernel_a, kernel_b_top:kernel_b_top + kernel_b]

        bias = _S(self.bias)[:out_channels] if self.bias is not None else None

        # The remaining parameters only need to be converted to tuples
        stride = self._to_tuple(stride)
        dilation = self._to_tuple(dilation)

        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(inputs, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, bias, stride, (0, 0), dilation, groups)
        return F.conv2d(inputs, weight, bias, stride, padding, dilation, groups)


class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
    """
    Mixed BatchNorm2d operation.

    Supported arguments are:

    - ``num_features``
    - ``eps`` (only supported in path sampling)
    - ``momentum`` (only supported in path sampling)

    For path-sampling, prefixes of ``weight``, ``bias``, ``running_mean`` and ``running_var``
    are sliced. For weighted cases, the maximum ``num_features`` is used directly.

    Momentum is required to be float.
    PyTorch BatchNorm supports a case where momentum can be none, which is not supported here.
    """

    bound_type = retiarii_nn.BatchNorm2d
    argument_list = ['num_features', 'eps', 'momentum']

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        return max(traverse_all_options(value_choice))

    def forward_with_args(self,
                          num_features: int_or_int_dict,
                          eps: float,
                          momentum: float,
                          inputs: torch.Tensor) -> torch.Tensor:

        if any(isinstance(arg, dict) for arg in [eps, momentum]):
            raise ValueError('eps, momentum do not support weighted sampling')

        if isinstance(num_features, dict):
            num_features = self.num_features

        weight, bias = self.weight, self.bias
        running_mean, running_var = self.running_mean, self.running_var

        if num_features < self.num_features:
            weight = weight[:num_features]
            bias = bias[:num_features]
            running_mean = running_mean[:num_features]
            running_var = running_var[:num_features]

        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        return F.batch_norm(
            inputs,
            # If buffers are not to be tracked, ensure that they won't be updated
            running_mean if not self.training or self.track_running_stats else None,
            running_var if not self.training or self.track_running_stats else None,
            weight,
            bias,
            bn_training,
            momentum,  # originally exponential_average_factor in pytorch code
            eps,
        )


class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
    """
    Mixed multi-head attention.

    Supported arguments are:

    - ``embed_dim``
    - ``num_heads`` (only supported in path sampling)
    - ``kdim``
    - ``vdim``
    - ``dropout`` (only supported in path sampling)

    At init, it constructs the largest possible Q, K, V dimensions.
    At forward, it slices the prefix of the weight matrices according to the sampled value.
    For ``in_proj_bias`` and ``in_proj_weight``, three parts will be sliced and concatenated together:
    ``[0, embed_dim)``, ``[max_embed_dim, max_embed_dim + embed_dim)``,
    ``[max_embed_dim * 2, max_embed_dim * 2 + embed_dim)``.

    Warnings
    ----------
    All candidates of ``embed_dim`` should be divisible by all candidates of ``num_heads``.
    """

    bound_type = retiarii_nn.MultiheadAttention
    argument_list = ['embed_dim', 'num_heads', 'kdim', 'vdim', 'dropout']

    def __post_init__(self):
        # sometimes the super-class believes q, k, v have the same embed_dim.
        # but actually they do not, because we can have dynamic (mutable) kdim/vdim.
        _qkv_same_embed_dim = True
        for dimension in ['kdim', 'vdim']:
            if self.init_arguments[dimension] is None:
                # must follow embed_dim in this case
                continue
            if getattr(self, dimension) == self.embed_dim and \
                    (dimension in self.mutable_arguments or 'embed_dim' in self.mutable_arguments):
                _qkv_same_embed_dim = False

        if self._qkv_same_embed_dim and not _qkv_same_embed_dim:
            self._qkv_same_embed_dim = _qkv_same_embed_dim

            # adding back missing parameters
            # factory_kwargs could be empty for legacy pytorch versions
            factory_kwargs = {}
            if 'device' in self.init_arguments:
                factory_kwargs['device'] = self.init_arguments['device']
            if 'dtype' in self.init_arguments:
                factory_kwargs['dtype'] = self.init_arguments['dtype']
            self.q_proj_weight = nn.Parameter(torch.empty((self.embed_dim, self.embed_dim), **factory_kwargs))
            self.k_proj_weight = nn.Parameter(torch.empty((self.embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = nn.Parameter(torch.empty((self.embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)

            # reset parameters
            nn.init.xavier_uniform_(self.q_proj_weight)
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        return max(traverse_all_options(value_choice))

    def _to_proj_slice(self, embed_dim: _W) -> List[slice]:
        # slice three parts, corresponding to q, k, v respectively
        return [
            slice(embed_dim),
            slice(self.embed_dim, self.embed_dim + embed_dim),
            slice(self.embed_dim * 2, self.embed_dim * 2 + embed_dim)
        ]

    def forward_with_args(
        self,
        embed_dim: int_or_int_dict, num_heads: int,
        kdim: Optional[int_or_int_dict], vdim: Optional[int_or_int_dict], dropout: float,
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = True, attn_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

        if any(isinstance(arg, dict) for arg in [num_heads, dropout]):
            raise ValueError('num_heads, dropout do not support weighted sampling.')

        # by default, kdim, vdim can be none
        if kdim is None:
            kdim = embed_dim
        if vdim is None:
            vdim = embed_dim

        qkv_same_embed_dim = kdim == embed_dim and vdim == embed_dim

        if getattr(self, 'batch_first', False):
            # for backward compatibility: v1.7 doesn't have batch_first
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if isinstance(embed_dim, dict):
            used_embed_dim = self.embed_dim
        else:
            used_embed_dim = embed_dim

        embed_dim = _W(embed_dim)

        # in-projection weights & biases have q, k, v weights concatenated together
        in_proj_bias = in_proj_weight = None
        if self.in_proj_bias is not None:
            in_proj_bias = _S(self.in_proj_bias)[self._to_proj_slice(embed_dim)]
        if self.in_proj_weight is not None:
            in_proj_weight = _S(self.in_proj_weight)[self._to_proj_slice(embed_dim), :embed_dim]

        bias_k = _S(self.bias_k)[:, :, :embed_dim] if self.bias_k is not None else None
        bias_v = _S(self.bias_v)[:, :, :embed_dim] if self.bias_v is not None else None
        out_proj_weight = _S(self.out_proj.weight)[:embed_dim, :embed_dim]
        out_proj_bias = _S(self.out_proj.bias
)[:
embed_dim
]
if
self
.
out_proj
.
bias
is
not
None
else
None
if
not
qkv_same_embed_dim
:
kdim
=
_W
(
kdim
)
vdim
=
_W
(
vdim
)
q_proj
=
_S
(
self
.
q_proj_weight
)[:
embed_dim
,
:
embed_dim
]
k_proj
=
_S
(
self
.
k_proj_weight
)[:
embed_dim
]
k_proj
=
_S
(
k_proj
)[:,
:
kdim
]
v_proj
=
_S
(
self
.
v_proj_weight
)[:
embed_dim
]
v_proj
=
_S
(
v_proj
)[:,
:
vdim
]
# The rest part is basically same as pytorch
attn_output
,
attn_output_weights
=
F
.
multi_head_attention_forward
(
query
,
key
,
value
,
used_embed_dim
,
num_heads
,
in_proj_weight
,
in_proj_bias
,
bias_k
,
bias_v
,
self
.
add_zero_attn
,
dropout
,
out_proj_weight
,
out_proj_bias
,
training
=
self
.
training
,
key_padding_mask
=
key_padding_mask
,
need_weights
=
need_weights
,
attn_mask
=
attn_mask
,
use_separate_proj_weight
=
True
,
q_proj_weight
=
q_proj
,
k_proj_weight
=
k_proj
,
v_proj_weight
=
v_proj
)
else
:
attn_output
,
attn_output_weights
=
F
.
multi_head_attention_forward
(
query
,
key
,
value
,
used_embed_dim
,
num_heads
,
in_proj_weight
,
in_proj_bias
,
bias_k
,
bias_v
,
self
.
add_zero_attn
,
dropout
,
out_proj_weight
,
out_proj_bias
,
training
=
self
.
training
,
key_padding_mask
=
key_padding_mask
,
need_weights
=
need_weights
,
attn_mask
=
attn_mask
)
if
getattr
(
self
,
'batch_first'
,
False
):
# backward compatibility
return
attn_output
.
transpose
(
1
,
0
),
attn_output_weights
else
:
return
attn_output
,
attn_output_weights
NATIVE_MIXED_OPERATIONS
:
List
[
Type
[
MixedOperation
]]
=
[
MixedLinear
,
MixedConv2d
,
MixedBatchNorm2d
,
MixedMultiHeadAttention
,
]
nni/retiarii/oneshot/pytorch/supermodule/proxyless.py
0 → 100644
View file @
14d2966b
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Implementation of ProxylessNAS: a hyrbid approach between differentiable and sampling.
The support remains limited. Known limitations include:
- No support for multiple arguments in forward.
- No support for mixed-operation (value choice).
- The code contains duplicates. Needs refactor.
"""
from
typing
import
List
,
Tuple
,
Optional
import
torch
import
torch.nn
as
nn
from
.differentiable
import
DifferentiableMixedLayer
,
DifferentiableMixedInput
class
_ArchGradientFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
x
,
binary_gates
,
run_func
,
backward_func
):
ctx
.
run_func
=
run_func
ctx
.
backward_func
=
backward_func
detached_x
=
x
.
detach
()
detached_x
.
requires_grad
=
x
.
requires_grad
with
torch
.
enable_grad
():
output
=
run_func
(
detached_x
)
ctx
.
save_for_backward
(
detached_x
,
output
)
return
output
.
data
@
staticmethod
def
backward
(
ctx
,
grad_output
):
detached_x
,
output
=
ctx
.
saved_tensors
grad_x
=
torch
.
autograd
.
grad
(
output
,
detached_x
,
grad_output
,
only_inputs
=
True
)
# compute gradients w.r.t. binary_gates
binary_grads
=
ctx
.
backward_func
(
detached_x
.
data
,
output
.
data
,
grad_output
.
data
)
return
grad_x
[
0
],
binary_grads
,
None
,
None
class
ProxylessMixedLayer
(
DifferentiableMixedLayer
):
"""Proxyless version of differentiable mixed layer.
It resamples a single-path every time, rather than go through the softmax.
"""
_arch_parameter_names
=
[
'_arch_alpha'
,
'_binary_gates'
]
def
__init__
(
self
,
paths
:
List
[
Tuple
[
str
,
nn
.
Module
]],
alpha
:
torch
.
Tensor
,
softmax
:
nn
.
Module
,
label
:
str
):
super
().
__init__
(
paths
,
alpha
,
softmax
,
label
)
self
.
_binary_gates
=
nn
.
Parameter
(
torch
.
randn
(
len
(
paths
))
*
1E-3
)
# like sampling-based methods, it has a ``_sampled``.
self
.
_sampled
:
Optional
[
str
]
=
None
self
.
_sample_idx
:
Optional
[
int
]
=
None
def
forward
(
self
,
*
args
,
**
kwargs
):
def
run_function
(
ops
,
active_id
,
**
kwargs
):
def
forward
(
_x
):
return
ops
[
active_id
](
_x
,
**
kwargs
)
return
forward
def
backward_function
(
ops
,
active_id
,
binary_gates
,
**
kwargs
):
def
backward
(
_x
,
_output
,
grad_output
):
binary_grads
=
torch
.
zeros_like
(
binary_gates
.
data
)
with
torch
.
no_grad
():
for
k
in
range
(
len
(
ops
)):
if
k
!=
active_id
:
out_k
=
ops
[
k
](
_x
.
data
,
**
kwargs
)
else
:
out_k
=
_output
.
data
grad_k
=
torch
.
sum
(
out_k
*
grad_output
)
binary_grads
[
k
]
=
grad_k
return
binary_grads
return
backward
assert
len
(
args
)
==
1
,
'ProxylessMixedLayer only supports exactly one input argument.'
x
=
args
[
0
]
assert
self
.
_sampled
is
not
None
,
'Need to call resample() before running fprop.'
list_ops
=
[
getattr
(
self
,
op
)
for
op
in
self
.
op_names
]
return
_ArchGradientFunction
.
apply
(
x
,
self
.
_binary_gates
,
run_function
(
list_ops
,
self
.
_sample_idx
,
**
kwargs
),
backward_function
(
list_ops
,
self
.
_sample_idx
,
self
.
_binary_gates
,
**
kwargs
)
)
def
resample
(
self
,
memo
):
"""Sample one path based on alpha if label is not found in memo."""
if
self
.
label
in
memo
:
self
.
_sampled
=
memo
[
self
.
label
]
self
.
_sample_idx
=
self
.
op_names
.
index
(
self
.
_sampled
)
else
:
probs
=
self
.
_softmax
(
self
.
_arch_alpha
)
self
.
_sample_idx
=
torch
.
multinomial
(
probs
,
1
)[
0
].
item
()
self
.
_sampled
=
self
.
op_names
[
self
.
_sample_idx
]
# set binary gates
with
torch
.
no_grad
():
self
.
_binary_gates
.
zero_
()
self
.
_binary_gates
.
grad
=
torch
.
zeros_like
(
self
.
_binary_gates
.
data
)
self
.
_binary_gates
.
data
[
self
.
_sample_idx
]
=
1.0
return
{
self
.
label
:
self
.
_sampled
}
def
export
(
self
,
memo
):
"""Chose the argmax if label isn't found in memo."""
if
self
.
label
in
memo
:
return
{}
# nothing new to export
return
{
self
.
label
:
self
.
op_names
[
torch
.
argmax
(
self
.
_arch_alpha
).
item
()]}
def
finalize_grad
(
self
):
binary_grads
=
self
.
_binary_gates
.
grad
with
torch
.
no_grad
():
if
self
.
_arch_alpha
.
grad
is
None
:
self
.
_arch_alpha
.
grad
=
torch
.
zeros_like
(
self
.
_arch_alpha
.
data
)
probs
=
self
.
_softmax
(
self
.
_arch_alpha
)
for
i
in
range
(
len
(
self
.
_arch_alpha
)):
for
j
in
range
(
len
(
self
.
_arch_alpha
)):
self
.
_arch_alpha
.
grad
[
i
]
+=
binary_grads
[
j
]
*
probs
[
j
]
*
(
int
(
i
==
j
)
-
probs
[
i
])
class
ProxylessMixedInput
(
DifferentiableMixedInput
):
"""Proxyless version of differentiable input choice.
See :class:`ProxylessLayerChoice` for implementation details.
"""
_arch_parameter_names
=
[
'_arch_alpha'
,
'_binary_gates'
]
def
__init__
(
self
,
n_candidates
:
int
,
n_chosen
:
Optional
[
int
],
alpha
:
torch
.
Tensor
,
softmax
:
nn
.
Module
,
label
:
str
):
super
().
__init__
(
n_candidates
,
n_chosen
,
alpha
,
softmax
,
label
)
self
.
_binary_gates
=
nn
.
Parameter
(
torch
.
randn
(
n_candidates
)
*
1E-3
)
self
.
_sampled
:
Optional
[
int
]
=
None
def
forward
(
self
,
inputs
):
def
run_function
(
active_sample
):
return
lambda
x
:
x
[
active_sample
]
def
backward_function
(
binary_gates
):
def
backward
(
_x
,
_output
,
grad_output
):
binary_grads
=
torch
.
zeros_like
(
binary_gates
.
data
)
with
torch
.
no_grad
():
for
k
in
range
(
self
.
n_candidates
):
out_k
=
_x
[
k
].
data
grad_k
=
torch
.
sum
(
out_k
*
grad_output
)
binary_grads
[
k
]
=
grad_k
return
binary_grads
return
backward
inputs
=
torch
.
stack
(
inputs
,
0
)
assert
self
.
_sampled
is
not
None
,
'Need to call resample() before running fprop.'
return
_ArchGradientFunction
.
apply
(
inputs
,
self
.
_binary_gates
,
run_function
(
self
.
_sampled
),
backward_function
(
self
.
_binary_gates
)
)
def
resample
(
self
,
memo
):
"""Sample one path based on alpha if label is not found in memo."""
if
self
.
label
in
memo
:
self
.
_sampled
=
memo
[
self
.
label
]
else
:
probs
=
self
.
_softmax
(
self
.
_arch_alpha
)
sample
=
torch
.
multinomial
(
probs
,
1
)[
0
].
item
()
self
.
_sampled
=
sample
# set binary gates
with
torch
.
no_grad
():
self
.
_binary_gates
.
zero_
()
self
.
_binary_gates
.
grad
=
torch
.
zeros_like
(
self
.
_binary_gates
.
data
)
self
.
_binary_gates
.
data
[
sample
]
=
1.0
return
{
self
.
label
:
self
.
_sampled
}
def
export
(
self
,
memo
):
"""Chose the argmax if label isn't found in memo."""
if
self
.
label
in
memo
:
return
{}
# nothing new to export
return
{
self
.
label
:
torch
.
argmax
(
self
.
_arch_alpha
).
item
()}
def
finalize_grad
(
self
):
binary_grads
=
self
.
_binary_gates
.
grad
with
torch
.
no_grad
():
if
self
.
_arch_alpha
.
grad
is
None
:
self
.
_arch_alpha
.
grad
=
torch
.
zeros_like
(
self
.
_arch_alpha
.
data
)
probs
=
self
.
_softmax
(
self
.
_arch_alpha
)
for
i
in
range
(
self
.
n_candidates
):
for
j
in
range
(
self
.
n_candidates
):
self
.
_arch_alpha
.
grad
[
i
]
+=
binary_grads
[
j
]
*
probs
[
j
]
*
(
int
(
i
==
j
)
-
probs
[
i
])
nni/retiarii/oneshot/pytorch/supermodule/sampling.py
0 → 100644
View file @
14d2966b
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
random
from
typing
import
Optional
,
List
,
Tuple
,
Union
,
Dict
,
Any
import
torch
import
torch.nn
as
nn
from
nni.common.hpo_utils
import
ParameterSpec
from
nni.retiarii.nn.pytorch
import
LayerChoice
,
InputChoice
from
.base
import
BaseSuperNetModule
from
._valuechoice_utils
import
evaluate_value_choice_with_dict
from
.operation
import
MixedOperationSamplingPolicy
,
MixedOperation
class
PathSamplingLayer
(
BaseSuperNetModule
):
"""
Mixed layer, in which fprop is decided by exactly one inner layer or sum of multiple (sampled) layers.
If multiple modules are selected, the result will be summed and returned.
Attributes
----------
_sampled : int or list of str
Sampled module indices.
label : str
Name of the choice.
"""
def
__init__
(
self
,
paths
:
List
[
Tuple
[
str
,
nn
.
Module
]],
label
:
str
):
super
().
__init__
()
self
.
op_names
=
[]
for
name
,
module
in
paths
:
self
.
add_module
(
name
,
module
)
self
.
op_names
.
append
(
name
)
assert
self
.
op_names
,
'There has to be at least one op to choose from.'
self
.
_sampled
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
# sampled can be either a list of indices or an index
self
.
label
=
label
def
resample
(
self
,
memo
):
"""Random choose one path if label is not found in memo."""
if
self
.
label
in
memo
:
self
.
_sampled
=
memo
[
self
.
label
]
else
:
self
.
_sampled
=
random
.
choice
(
self
.
op_names
)
return
{
self
.
label
:
self
.
_sampled
}
def
export
(
self
,
memo
):
"""Random choose one name if label isn't found in memo."""
if
self
.
label
in
memo
:
return
{}
# nothing new to export
return
{
self
.
label
:
random
.
choice
(
self
.
op_names
)}
def
search_space_spec
(
self
):
return
{
self
.
label
:
ParameterSpec
(
self
.
label
,
'choice'
,
self
.
op_names
,
(
self
.
label
,
),
True
,
size
=
len
(
self
.
op_names
))}
@
classmethod
def
mutate
(
cls
,
module
,
name
,
memo
,
mutate_kwargs
):
if
isinstance
(
module
,
LayerChoice
):
return
cls
(
list
(
module
.
named_children
()),
module
.
label
)
def
forward
(
self
,
*
args
,
**
kwargs
):
if
self
.
_sampled
is
None
:
raise
RuntimeError
(
'At least one path needs to be sampled before fprop.'
)
sampled
=
[
self
.
_sampled
]
if
not
isinstance
(
self
.
_sampled
,
list
)
else
self
.
_sampled
# str(samp) is needed here because samp can sometimes be integers, but attr are always str
res
=
[
getattr
(
self
,
str
(
samp
))(
*
args
,
**
kwargs
)
for
samp
in
sampled
]
if
len
(
res
)
==
1
:
return
res
[
0
]
else
:
return
sum
(
res
)
class
PathSamplingInput
(
BaseSuperNetModule
):
"""
Mixed input. Take a list of tensor as input, select some of them and return the sum.
Attributes
----------
_sampled : int or list of int
Sampled input indices.
"""
def
__init__
(
self
,
n_candidates
:
int
,
n_chosen
:
int
,
reduction
:
str
,
label
:
str
):
super
().
__init__
()
self
.
n_candidates
=
n_candidates
self
.
n_chosen
=
n_chosen
self
.
reduction
=
reduction
self
.
_sampled
:
Optional
[
Union
[
List
[
int
],
int
]]
=
None
self
.
label
=
label
def
_random_choose_n
(
self
):
sampling
=
list
(
range
(
self
.
n_candidates
))
random
.
shuffle
(
sampling
)
sampling
=
sorted
(
sampling
[:
self
.
n_chosen
])
if
len
(
sampling
)
==
1
:
return
sampling
[
0
]
else
:
return
sampling
def
resample
(
self
,
memo
):
"""Random choose one path / multiple paths if label is not found in memo.
If one path is selected, only one integer will be in ``self._sampled``.
If multiple paths are selected, a list will be in ``self._sampled``.
"""
if
self
.
label
in
memo
:
self
.
_sampled
=
memo
[
self
.
label
]
else
:
self
.
_sampled
=
self
.
_random_choose_n
()
return
{
self
.
label
:
self
.
_sampled
}
def
export
(
self
,
memo
):
"""Random choose one name if label isn't found in memo."""
if
self
.
label
in
memo
:
return
{}
# nothing new to export
return
{
self
.
label
:
self
.
_random_choose_n
()}
def
search_space_spec
(
self
):
return
{
self
.
label
:
ParameterSpec
(
self
.
label
,
'choice'
,
list
(
range
(
self
.
n_candidates
)),
(
self
.
label
,
),
True
,
size
=
self
.
n_candidates
,
chosen_size
=
self
.
n_chosen
)
}
@
classmethod
def
mutate
(
cls
,
module
,
name
,
memo
,
mutate_kwargs
):
if
isinstance
(
module
,
InputChoice
):
if
module
.
reduction
not
in
[
'sum'
,
'mean'
,
'concat'
]:
raise
ValueError
(
'Only input choice of sum/mean/concat reduction is supported.'
)
return
cls
(
module
.
n_candidates
,
module
.
n_chosen
,
module
.
reduction
,
module
.
label
)
def
forward
(
self
,
input_tensors
):
if
self
.
_sampled
is
None
:
raise
RuntimeError
(
'At least one path needs to be sampled before fprop.'
)
if
len
(
input_tensors
)
!=
self
.
n_candidates
:
raise
ValueError
(
f
'Expect
{
self
.
n_candidates
}
input tensors, found
{
len
(
input_tensors
)
}
.'
)
sampled
=
[
self
.
_sampled
]
if
not
isinstance
(
self
.
_sampled
,
list
)
else
self
.
_sampled
res
=
[
input_tensors
[
samp
]
for
samp
in
sampled
]
if
len
(
res
)
==
1
:
return
res
[
0
]
else
:
if
self
.
reduction
==
'sum'
:
return
sum
(
res
)
elif
self
.
reduction
==
'mean'
:
return
sum
(
res
)
/
len
(
res
)
elif
self
.
reduction
==
'concat'
:
return
torch
.
cat
(
res
,
1
)
class
MixedOpPathSamplingPolicy
(
MixedOperationSamplingPolicy
):
"""Implementes the path sampling in mixed operation.
One mixed operation can have multiple value choices in its arguments.
Each value choice can be further decomposed into "leaf value choices".
We sample the leaf nodes, and composits them into the values on arguments.
"""
def
__init__
(
self
,
operation
:
MixedOperation
,
memo
:
Dict
[
str
,
Any
],
mutate_kwargs
:
Dict
[
str
,
Any
])
->
None
:
# Sampling arguments. This should have the same keys with `operation.mutable_arguments`
self
.
_sampled
:
Optional
[
Dict
[
str
,
Any
]]
=
None
def
resample
(
self
,
operation
:
MixedOperation
,
memo
:
Dict
[
str
,
Any
]
=
None
)
->
Dict
[
str
,
Any
]:
"""Random sample for each leaf value choice."""
result
=
{}
space_spec
=
operation
.
search_space_spec
()
for
label
in
space_spec
:
if
label
in
memo
:
result
[
label
]
=
memo
[
label
]
else
:
result
[
label
]
=
random
.
choice
(
space_spec
[
label
].
values
)
# composits to kwargs
# example: result = {"exp_ratio": 3}, self._sampled = {"in_channels": 48, "out_channels": 96}
self
.
_sampled
=
{}
for
key
,
value
in
operation
.
mutable_arguments
.
items
():
self
.
_sampled
[
key
]
=
evaluate_value_choice_with_dict
(
value
,
result
)
return
result
def
export
(
self
,
operation
:
MixedOperation
,
memo
:
Dict
[
str
,
Any
]
=
None
)
->
Dict
[
str
,
Any
]:
"""Export is also random for each leaf value choice."""
result
=
{}
space_spec
=
operation
.
search_space_spec
()
for
label
in
space_spec
:
if
label
not
in
memo
:
result
[
label
]
=
random
.
choice
(
space_spec
[
label
].
values
)
return
result
def
forward_argument
(
self
,
operation
:
MixedOperation
,
name
:
str
)
->
Any
:
if
self
.
_sampled
is
None
:
raise
ValueError
(
'Need to call resample() before running forward'
)
if
name
in
operation
.
mutable_arguments
:
return
self
.
_sampled
[
name
]
return
operation
.
init_arguments
[
name
]
nni/retiarii/strategy/__init__.py
View file @
14d2966b
...
@@ -7,4 +7,4 @@ from .evolution import RegularizedEvolution
...
@@ -7,4 +7,4 @@ from .evolution import RegularizedEvolution
from
.tpe_strategy
import
TPEStrategy
from
.tpe_strategy
import
TPEStrategy
from
.local_debug_strategy
import
_LocalDebugStrategy
from
.local_debug_strategy
import
_LocalDebugStrategy
from
.rl
import
PolicyBasedRL
from
.rl
import
PolicyBasedRL
from
.oneshot
import
DARTS
,
Proxyless
,
SNA
S
,
ENAS
,
RandomOneShot
from
.oneshot
import
DARTS
,
Proxyless
,
GumbelDART
S
,
ENAS
,
RandomOneShot
nni/retiarii/strategy/oneshot.py
View file @
14d2966b
...
@@ -5,7 +5,7 @@ from .base import BaseStrategy
...
@@ -5,7 +5,7 @@ from .base import BaseStrategy
try
:
try
:
from
nni.retiarii.oneshot.pytorch.strategy
import
(
# pylint: disable=unused-import
from
nni.retiarii.oneshot.pytorch.strategy
import
(
# pylint: disable=unused-import
DARTS
,
SNA
S
,
Proxyless
,
ENAS
,
RandomOneShot
DARTS
,
GumbelDART
S
,
Proxyless
,
ENAS
,
RandomOneShot
)
)
except
ImportError
as
import_err
:
except
ImportError
as
import_err
:
_import_err
=
import_err
_import_err
=
import_err
...
@@ -16,7 +16,7 @@ except ImportError as import_err:
...
@@ -16,7 +16,7 @@ except ImportError as import_err:
# otherwise typing check will pointing to the wrong location
# otherwise typing check will pointing to the wrong location
globals
()[
'DARTS'
]
=
ImportFailedStrategy
globals
()[
'DARTS'
]
=
ImportFailedStrategy
globals
()[
'
SNA
S'
]
=
ImportFailedStrategy
globals
()[
'
GumbelDART
S'
]
=
ImportFailedStrategy
globals
()[
'Proxyless'
]
=
ImportFailedStrategy
globals
()[
'Proxyless'
]
=
ImportFailedStrategy
globals
()[
'ENAS'
]
=
ImportFailedStrategy
globals
()[
'ENAS'
]
=
ImportFailedStrategy
globals
()[
'RandomOneShot'
]
=
ImportFailedStrategy
globals
()[
'RandomOneShot'
]
=
ImportFailedStrategy
test/ut/retiarii/test_oneshot.py
View file @
14d2966b
import
argparse
import
argparse
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
import
pytest
import
pytest
from
torchvision
import
transforms
from
torchvision
import
transforms
from
torchvision.datasets
import
MNIST
from
torchvision.datasets
import
MNIST
from
torch.utils.data
.sampler
import
RandomSampler
from
torch.utils.data
import
Dataset
,
RandomSampler
from
nni.retiarii
import
strategy
,
model_wrapper
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
strategy
,
model_wrapper
,
basic_unit
from
nni.retiarii.experiment.pytorch
import
RetiariiExeConfig
,
RetiariiExperiment
from
nni.retiarii.experiment.pytorch
import
RetiariiExeConfig
,
RetiariiExperiment
from
nni.retiarii.evaluator.pytorch.lightning
import
Classification
,
DataLoader
from
nni.retiarii.evaluator.pytorch.lightning
import
Classification
,
Regression
,
DataLoader
from
nni.retiarii.nn.pytorch
import
LayerChoice
,
InputChoice
from
nni.retiarii.nn.pytorch
import
LayerChoice
,
InputChoice
,
ValueChoice
class
DepthwiseSeparableConv
(
nn
.
Module
):
class
DepthwiseSeparableConv
(
nn
.
Module
):
...
@@ -25,107 +25,261 @@ class DepthwiseSeparableConv(nn.Module):
...
@@ -25,107 +25,261 @@ class DepthwiseSeparableConv(nn.Module):
@
model_wrapper
@
model_wrapper
class
Net
(
pl
.
Lightning
Module
):
class
SimpleNet
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
,
value_choice
=
True
):
super
().
__init__
()
super
().
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
1
,
32
,
3
,
1
)
self
.
conv1
=
nn
.
Conv2d
(
1
,
32
,
3
,
1
)
self
.
conv2
=
LayerChoice
([
self
.
conv2
=
LayerChoice
([
nn
.
Conv2d
(
32
,
64
,
3
,
1
),
nn
.
Conv2d
(
32
,
64
,
3
,
1
),
DepthwiseSeparableConv
(
32
,
64
)
DepthwiseSeparableConv
(
32
,
64
)
])
])
self
.
dropout1
=
nn
.
Dropout
(.
25
)
self
.
dropout1
=
LayerChoice
([
self
.
dropout2
=
nn
.
Dropout
(
0.5
)
nn
.
Dropout
(.
25
),
self
.
dropout_choice
=
InputChoice
(
2
,
1
)
nn
.
Dropout
(.
5
),
self
.
fc
=
LayerChoice
([
nn
.
Dropout
(.
75
)
nn
.
Sequential
(
nn
.
Linear
(
9216
,
64
),
nn
.
ReLU
(),
nn
.
Linear
(
64
,
10
),
),
nn
.
Sequential
(
nn
.
Linear
(
9216
,
128
),
nn
.
ReLU
(),
nn
.
Linear
(
128
,
10
),
),
nn
.
Sequential
(
nn
.
Linear
(
9216
,
256
),
nn
.
ReLU
(),
nn
.
Linear
(
256
,
10
),
)
])
])
self
.
dropout2
=
nn
.
Dropout
(
0.5
)
if
value_choice
:
hidden
=
nn
.
ValueChoice
([
32
,
64
,
128
])
else
:
hidden
=
64
self
.
fc1
=
nn
.
Linear
(
9216
,
hidden
)
self
.
fc2
=
nn
.
Linear
(
hidden
,
10
)
self
.
rpfc
=
nn
.
Linear
(
10
,
10
)
self
.
rpfc
=
nn
.
Linear
(
10
,
10
)
self
.
input_ch
=
InputChoice
(
2
,
1
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
F
.
relu
(
self
.
conv1
(
x
))
x
=
F
.
relu
(
self
.
conv1
(
x
))
x
=
F
.
max_pool2d
(
self
.
conv2
(
x
),
2
)
x
=
F
.
max_pool2d
(
self
.
conv2
(
x
),
2
)
x1
=
torch
.
flatten
(
self
.
dropout1
(
x
),
1
)
x
=
torch
.
flatten
(
self
.
dropout1
(
x
),
1
)
x2
=
torch
.
flatten
(
self
.
dropout2
(
x
),
1
)
x
=
self
.
fc1
(
x
)
x
=
self
.
dropout_choice
([
x1
,
x2
])
x
=
F
.
relu
(
x
)
x
=
self
.
fc
(
x
)
x
=
self
.
dropout2
(
x
)
x
=
self
.
rpfc
(
x
)
x
=
self
.
fc2
(
x
)
x1
=
self
.
rpfc
(
x
)
x
=
self
.
input_ch
([
x
,
x1
])
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
output
=
F
.
log_softmax
(
x
,
dim
=
1
)
return
output
return
output
def
prepare_model_data
():
@
model_wrapper
base_model
=
Net
()
class
MultiHeadAttentionNet
(
nn
.
Module
):
def
__init__
(
self
,
head_count
):
super
().
__init__
()
embed_dim
=
ValueChoice
(
candidates
=
[
32
,
64
])
self
.
linear1
=
nn
.
Linear
(
128
,
embed_dim
)
self
.
mhatt
=
nn
.
MultiheadAttention
(
embed_dim
,
head_count
)
self
.
linear2
=
nn
.
Linear
(
embed_dim
,
1
)
def
forward
(
self
,
batch
):
query
,
key
,
value
=
batch
q
,
k
,
v
=
self
.
linear1
(
query
),
self
.
linear1
(
key
),
self
.
linear1
(
value
)
output
,
_
=
self
.
mhatt
(
q
,
k
,
v
,
need_weights
=
False
)
y
=
self
.
linear2
(
output
)
return
F
.
relu
(
y
)
@
model_wrapper
class
ValueChoiceConvNet
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
ch1
=
ValueChoice
([
16
,
32
])
kernel
=
ValueChoice
([
3
,
5
])
self
.
conv1
=
nn
.
Conv2d
(
1
,
ch1
,
kernel
,
padding
=
kernel
//
2
)
self
.
batch_norm
=
nn
.
BatchNorm2d
(
ch1
)
self
.
conv2
=
nn
.
Conv2d
(
ch1
,
64
,
3
)
self
.
dropout1
=
LayerChoice
([
nn
.
Dropout
(.
25
),
nn
.
Dropout
(.
5
),
nn
.
Dropout
(.
75
)
])
self
.
fc
=
nn
.
Linear
(
64
,
10
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
batch_norm
(
x
)
x
=
F
.
relu
(
x
)
x
=
F
.
max_pool2d
(
self
.
conv2
(
x
),
2
)
x
=
torch
.
mean
(
x
,
(
2
,
3
))
x
=
self
.
fc
(
x
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
@
model_wrapper
class
RepeatNet
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
ch1
=
ValueChoice
([
16
,
32
])
kernel
=
ValueChoice
([
3
,
5
])
self
.
conv1
=
nn
.
Conv2d
(
1
,
ch1
,
kernel
,
padding
=
kernel
//
2
)
self
.
batch_norm
=
nn
.
BatchNorm2d
(
ch1
)
self
.
conv2
=
nn
.
Conv2d
(
ch1
,
64
,
3
,
padding
=
1
)
self
.
dropout1
=
LayerChoice
([
nn
.
Dropout
(.
25
),
nn
.
Dropout
(.
5
),
nn
.
Dropout
(.
75
)
])
self
.
fc
=
nn
.
Linear
(
64
,
10
)
self
.
rpfc
=
nn
.
Repeat
(
nn
.
Linear
(
10
,
10
),
(
1
,
4
))
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
batch_norm
(
x
)
x
=
F
.
relu
(
x
)
x
=
F
.
max_pool2d
(
self
.
conv2
(
x
),
2
)
x
=
torch
.
mean
(
x
,
(
2
,
3
))
x
=
self
.
fc
(
x
)
x
=
self
.
rpfc
(
x
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
@
basic_unit
class
MyOp
(
nn
.
Module
):
def
__init__
(
self
,
some_ch
):
super
().
__init__
()
self
.
some_ch
=
some_ch
self
.
batch_norm
=
nn
.
BatchNorm2d
(
some_ch
)
def
forward
(
self
,
x
):
return
self
.
batch_norm
(
x
)
@
model_wrapper
class
CustomOpValueChoiceNet
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
ch1
=
ValueChoice
([
16
,
32
])
kernel
=
ValueChoice
([
3
,
5
])
self
.
conv1
=
nn
.
Conv2d
(
1
,
ch1
,
kernel
,
padding
=
kernel
//
2
)
self
.
batch_norm
=
MyOp
(
ch1
)
self
.
conv2
=
nn
.
Conv2d
(
ch1
,
64
,
3
,
padding
=
1
)
self
.
dropout1
=
LayerChoice
([
nn
.
Dropout
(.
25
),
nn
.
Dropout
(.
5
),
nn
.
Dropout
(.
75
)
])
self
.
fc
=
nn
.
Linear
(
64
,
10
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
batch_norm
(
x
)
x
=
F
.
relu
(
x
)
x
=
F
.
max_pool2d
(
self
.
conv2
(
x
),
2
)
x
=
torch
.
mean
(
x
,
(
2
,
3
))
x
=
self
.
fc
(
x
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
def
_mnist_net
(
type_
):
if
type_
==
'simple'
:
base_model
=
SimpleNet
(
False
)
elif
type_
==
'simple_value_choice'
:
base_model
=
SimpleNet
()
elif
type_
==
'value_choice'
:
base_model
=
ValueChoiceConvNet
()
elif
type_
==
'repeat'
:
base_model
=
RepeatNet
()
elif
type_
==
'custom_op'
:
base_model
=
CustomOpValueChoiceNet
()
else
:
raise
ValueError
(
f
'Unsupported type:
{
type_
}
'
)
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.1307
,),
(
0.3081
,))])
train_dataset
=
MNIST
(
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
train_dataset
=
MNIST
(
'data/mnist'
,
train
=
True
,
download
=
True
,
transform
=
transform
)
train_random_sampler
=
RandomSampler
(
train_dataset
,
True
,
int
(
len
(
train_dataset
)
/
10
))
train_random_sampler
=
RandomSampler
(
train_dataset
,
True
,
int
(
len
(
train_dataset
)
/
20
))
train_loader
=
DataLoader
(
train_dataset
,
64
,
sampler
=
train_random_sampler
)
train_loader
=
DataLoader
(
train_dataset
,
64
,
sampler
=
train_random_sampler
)
valid_dataset
=
MNIST
(
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
valid_dataset
=
MNIST
(
'data/mnist'
,
train
=
False
,
download
=
True
,
transform
=
transform
)
valid_random_sampler
=
RandomSampler
(
valid_dataset
,
True
,
int
(
len
(
valid_dataset
)
/
10
))
valid_random_sampler
=
RandomSampler
(
valid_dataset
,
True
,
int
(
len
(
valid_dataset
)
/
20
))
valid_loader
=
DataLoader
(
valid_dataset
,
64
,
sampler
=
valid_random_sampler
)
valid_loader
=
DataLoader
(
valid_dataset
,
64
,
sampler
=
valid_random_sampler
)
evaluator
=
Classification
(
train_dataloader
=
train_loader
,
val_dataloaders
=
valid_loader
,
max_epochs
=
1
)
trainer_kwargs
=
{
return
base_model
,
evaluator
'max_epochs'
:
1
}
def
_multihead_attention_net
():
base_model
=
MultiHeadAttentionNet
(
1
)
class
AttentionRandDataset
(
Dataset
):
def
__init__
(
self
,
data_shape
,
gt_shape
,
len
)
->
None
:
super
().
__init__
()
self
.
datashape
=
data_shape
self
.
gtshape
=
gt_shape
self
.
len
=
len
return
base_model
,
train_loader
,
valid_loader
,
trainer_kwargs
def
__getitem__
(
self
,
index
):
q
=
torch
.
rand
(
self
.
datashape
)
k
=
torch
.
rand
(
self
.
datashape
)
v
=
torch
.
rand
(
self
.
datashape
)
gt
=
torch
.
rand
(
self
.
gtshape
)
return
(
q
,
k
,
v
),
gt
def
__len__
(
self
):
return
self
.
len
def
_test_strategy
(
strategy_
):
train_set
=
AttentionRandDataset
((
1
,
128
),
(
1
,
1
),
1000
)
base_model
,
train_loader
,
valid_loader
,
trainer_kwargs
=
prepare_model_data
()
val_set
=
AttentionRandDataset
((
1
,
128
),
(
1
,
1
),
500
)
cls
=
Classification
(
train_dataloader
=
train_loader
,
val_dataloaders
=
valid_loader
,
**
trainer_kwargs
)
train_loader
=
DataLoader
(
train_set
,
batch_size
=
32
)
experiment
=
RetiariiExperiment
(
base_model
,
cls
,
strategy
=
strategy_
)
val_loader
=
DataLoader
(
val_set
,
batch_size
=
32
)
evaluator
=
Regression
(
train_dataloader
=
train_loader
,
val_dataloaders
=
val_loader
,
max_epochs
=
1
)
return
base_model
,
evaluator
def
_test_strategy
(
strategy_
,
support_value_choice
=
True
):
to_test
=
[
# (model, evaluator), support_or_net
(
_mnist_net
(
'simple'
),
True
),
(
_mnist_net
(
'simple_value_choice'
),
support_value_choice
),
(
_mnist_net
(
'value_choice'
),
support_value_choice
),
(
_mnist_net
(
'repeat'
),
False
),
# no strategy supports repeat currently
(
_mnist_net
(
'custom_op'
),
False
),
# this is definitely a NO
(
_multihead_attention_net
(),
support_value_choice
),
]
for
(
base_model
,
evaluator
),
support_or_not
in
to_test
:
print
(
'Testing:'
,
type
(
strategy_
).
__name__
,
type
(
base_model
).
__name__
,
type
(
evaluator
).
__name__
,
support_or_not
)
experiment
=
RetiariiExperiment
(
base_model
,
evaluator
,
strategy
=
strategy_
)
config
=
RetiariiExeConfig
()
config
=
RetiariiExeConfig
()
config
.
execution_engine
=
'oneshot'
config
.
execution_engine
=
'oneshot'
if
support_or_not
:
experiment
.
run
(
config
)
experiment
.
run
(
config
)
assert
isinstance
(
experiment
.
export_top_models
()[
0
],
dict
)
assert
isinstance
(
experiment
.
export_top_models
()[
0
],
dict
)
else
:
with
pytest
.
raises
(
TypeError
,
match
=
'not supported'
):
experiment
.
run
(
config
)
@
pytest
.
mark
.
skipif
(
pl
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs'
)
@
pytest
.
mark
.
skipif
(
pl
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs'
)
def
test_darts
():
def
test_darts
():
_test_strategy
(
strategy
.
DARTS
())
_test_strategy
(
strategy
.
DARTS
())
@
pytest
.
mark
.
skipif
(
pl
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs'
)
@
pytest
.
mark
.
skipif
(
pl
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs'
)
def
test_proxyless
():
def
test_proxyless
():
_test_strategy
(
strategy
.
Proxyless
())
_test_strategy
(
strategy
.
Proxyless
()
,
False
)
@
pytest
.
mark
.
skipif
(
pl
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs'
)
@
pytest
.
mark
.
skipif
(
pl
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs'
)
def
test_enas
():
def
test_enas
():
_test_strategy
(
strategy
.
ENAS
())
_test_strategy
(
strategy
.
ENAS
())
@
pytest
.
mark
.
skipif
(
pl
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs'
)
@
pytest
.
mark
.
skipif
(
pl
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs'
)
def
test_random
():
def
test_random
():
_test_strategy
(
strategy
.
RandomOneShot
())
_test_strategy
(
strategy
.
RandomOneShot
())
@
pytest
.
mark
.
skipif
(
pl
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs'
)
@
pytest
.
mark
.
skipif
(
pl
.
__version__
<
'1.0'
,
reason
=
'Incompatible APIs'
)
def
test_
sna
s
():
def
test_
gumbel_dart
s
():
_test_strategy
(
strategy
.
SNA
S
())
_test_strategy
(
strategy
.
GumbelDART
S
())
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--exp'
,
type
=
str
,
default
=
'all'
,
metavar
=
'E'
,
parser
.
add_argument
(
'--exp'
,
type
=
str
,
default
=
'all'
,
metavar
=
'E'
,
help
=
'exp to run, default = all'
)
help
=
'exp
eriment
to run, default = all'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
args
.
exp
==
'all'
:
if
args
.
exp
==
'all'
:
...
@@ -133,6 +287,6 @@ if __name__ == '__main__':
...
@@ -133,6 +287,6 @@ if __name__ == '__main__':
test_proxyless
()
test_proxyless
()
test_enas
()
test_enas
()
test_random
()
test_random
()
test_
sna
s
()
test_
gumbel_dart
s
()
else
:
else
:
globals
()[
f
'test_
{
args
.
exp
}
'
]()
globals
()[
f
'test_
{
args
.
exp
}
'
]()
test/ut/retiarii/test_oneshot_supermodules.py
0 → 100644
View file @
14d2966b
import
pytest
import
numpy
as
np
import
torch
import
torch.nn
as
nn
from
nni.retiarii.nn.pytorch
import
ValueChoice
,
Conv2d
,
BatchNorm2d
,
Linear
,
MultiheadAttention
from
nni.retiarii.oneshot.pytorch.supermodule.differentiable
import
(
MixedOpDifferentiablePolicy
,
DifferentiableMixedLayer
,
DifferentiableMixedInput
,
GumbelSoftmax
)
from
nni.retiarii.oneshot.pytorch.supermodule.sampling
import
(
MixedOpPathSamplingPolicy
,
PathSamplingLayer
,
PathSamplingInput
)
from
nni.retiarii.oneshot.pytorch.supermodule.operation
import
MixedConv2d
,
NATIVE_MIXED_OPERATIONS
from
nni.retiarii.oneshot.pytorch.supermodule.proxyless
import
ProxylessMixedLayer
,
ProxylessMixedInput
from
nni.retiarii.oneshot.pytorch.supermodule._operation_utils
import
Slicable
as
S
,
MaybeWeighted
as
W
from
nni.retiarii.oneshot.pytorch.supermodule._valuechoice_utils
import
*
def
test_slice
():
weight
=
np
.
ones
((
3
,
7
,
24
,
23
))
assert
S
(
weight
)[:,
1
:
3
,
:,
9
:
13
].
shape
==
(
3
,
2
,
24
,
4
)
assert
S
(
weight
)[:,
1
:
W
(
3
)
*
2
+
1
,
:,
9
:
13
].
shape
==
(
3
,
6
,
24
,
4
)
assert
S
(
weight
)[:,
1
:
W
(
3
)
*
2
+
1
].
shape
==
(
3
,
6
,
24
,
23
)
# no effect
assert
S
(
weight
)[:]
is
weight
# list
assert
S
(
weight
)[[
slice
(
1
),
slice
(
2
,
3
)]].
shape
==
(
2
,
7
,
24
,
23
)
assert
S
(
weight
)[[
slice
(
1
),
slice
(
2
,
W
(
2
)
+
1
)],
W
(
2
):].
shape
==
(
2
,
5
,
24
,
23
)
# weighted
weight
=
S
(
weight
)[:
W
({
1
:
0.5
,
2
:
0.3
,
3
:
0.2
})]
weight
=
weight
[:,
0
,
0
,
0
]
assert
weight
[
0
]
==
1
and
weight
[
1
]
==
0.5
and
weight
[
2
]
==
0.2
weight
=
np
.
ones
((
3
,
6
,
6
))
value
=
W
({
1
:
0.5
,
3
:
0.5
})
weight
=
S
(
weight
)[:,
3
-
value
:
3
+
value
,
3
-
value
:
3
+
value
]
for
i
in
range
(
0
,
6
):
for
j
in
range
(
0
,
6
):
if
2
<=
i
<=
3
and
2
<=
j
<=
3
:
assert
weight
[
0
,
i
,
j
]
==
1
else
:
assert
weight
[
1
,
i
,
j
]
==
0.5
# weighted + list
value
=
W
({
1
:
0.5
,
3
:
0.5
})
weight
=
np
.
ones
((
8
,
4
))
weight
=
S
(
weight
)[[
slice
(
value
),
slice
(
4
,
value
+
4
)]]
assert
weight
.
sum
(
1
).
tolist
()
==
[
4
,
2
,
2
,
0
,
4
,
2
,
2
,
0
]
with
pytest
.
raises
(
ValueError
,
match
=
'one distinct'
):
# has to be exactly the same instance, equal is not enough
weight
=
S
(
weight
)[:
W
({
1
:
0.5
}),
:
W
({
1
:
0.5
})]
def
test_valuechoice_utils
():
chosen
=
{
"exp"
:
3
,
"add"
:
1
}
vc0
=
ValueChoice
([
3
,
4
,
6
],
label
=
'exp'
)
*
2
+
ValueChoice
([
0
,
1
],
label
=
'add'
)
assert
evaluate_value_choice_with_dict
(
vc0
,
chosen
)
==
7
vc
=
vc0
+
ValueChoice
([
3
,
4
,
6
],
label
=
'exp'
)
assert
evaluate_value_choice_with_dict
(
vc
,
chosen
)
==
10
assert
list
(
dedup_inner_choices
([
vc0
,
vc
]).
keys
())
==
[
'exp'
,
'add'
]
assert
traverse_all_options
(
vc
)
==
[
9
,
10
,
12
,
13
,
18
,
19
]
weights
=
dict
(
traverse_all_options
(
vc
,
weights
=
{
'exp'
:
[
0.5
,
0.3
,
0.2
],
'add'
:
[
0.4
,
0.6
]}))
ans
=
dict
([(
9
,
0.2
),
(
10
,
0.3
),
(
12
,
0.12
),
(
13
,
0.18
),
(
18
,
0.08
),
(
19
,
0.12
)])
assert
len
(
weights
)
==
len
(
ans
)
for
value
,
weight
in
ans
.
items
():
assert
abs
(
weight
-
weights
[
value
])
<
1e-6
def
test_pathsampling_valuechoice
():
orig_conv
=
Conv2d
(
3
,
ValueChoice
([
3
,
5
,
7
],
label
=
'123'
),
kernel_size
=
3
)
conv
=
MixedConv2d
.
mutate
(
orig_conv
,
'dummy'
,
{},
{
'mixed_op_sampling'
:
MixedOpPathSamplingPolicy
})
conv
.
resample
(
memo
=
{
'123'
:
5
})
assert
conv
(
torch
.
zeros
((
1
,
3
,
5
,
5
))).
size
(
1
)
==
5
conv
.
resample
(
memo
=
{
'123'
:
7
})
assert
conv
(
torch
.
zeros
((
1
,
3
,
5
,
5
))).
size
(
1
)
==
7
assert
conv
.
export
({})[
'123'
]
in
[
3
,
5
,
7
]
def
test_differentiable_valuechoice
():
orig_conv
=
Conv2d
(
3
,
ValueChoice
([
3
,
5
,
7
],
label
=
'456'
),
kernel_size
=
ValueChoice
(
[
3
,
5
,
7
],
label
=
'123'
),
padding
=
ValueChoice
([
3
,
5
,
7
],
label
=
'123'
)
//
2
)
conv
=
MixedConv2d
.
mutate
(
orig_conv
,
'dummy'
,
{},
{
'mixed_op_sampling'
:
MixedOpDifferentiablePolicy
})
assert
conv
(
torch
.
zeros
((
1
,
3
,
7
,
7
))).
size
(
2
)
==
7
assert
set
(
conv
.
export
({}).
keys
())
==
{
'123'
,
'456'
}
def
_mixed_operation_sampling_sanity_check
(
operation
,
memo
,
*
input
):
for
native_op
in
NATIVE_MIXED_OPERATIONS
:
if
native_op
.
bound_type
==
type
(
operation
):
mutate_op
=
native_op
.
mutate
(
operation
,
'dummy'
,
{},
{
'mixed_op_sampling'
:
MixedOpPathSamplingPolicy
})
break
mutate_op
.
resample
(
memo
=
memo
)
return
mutate_op
(
*
input
)
def
_mixed_operation_differentiable_sanity_check
(
operation
,
*
input
):
for
native_op
in
NATIVE_MIXED_OPERATIONS
:
if
native_op
.
bound_type
==
type
(
operation
):
mutate_op
=
native_op
.
mutate
(
operation
,
'dummy'
,
{},
{
'mixed_op_sampling'
:
MixedOpDifferentiablePolicy
})
break
return
mutate_op
(
*
input
)
def
test_mixed_linear
():
linear
=
Linear
(
ValueChoice
([
3
,
6
,
9
],
label
=
'shared'
),
ValueChoice
([
2
,
4
,
8
]))
_mixed_operation_sampling_sanity_check
(
linear
,
{
'shared'
:
3
},
torch
.
randn
(
2
,
3
))
_mixed_operation_sampling_sanity_check
(
linear
,
{
'shared'
:
9
},
torch
.
randn
(
2
,
9
))
_mixed_operation_differentiable_sanity_check
(
linear
,
torch
.
randn
(
2
,
9
))
linear
=
Linear
(
ValueChoice
([
3
,
6
,
9
],
label
=
'shared'
),
ValueChoice
([
2
,
4
,
8
]),
bias
=
False
)
_mixed_operation_sampling_sanity_check
(
linear
,
{
'shared'
:
3
},
torch
.
randn
(
2
,
3
))
with
pytest
.
raises
(
TypeError
):
linear
=
Linear
(
ValueChoice
([
3
,
6
,
9
],
label
=
'shared'
),
ValueChoice
([
2
,
4
,
8
]),
bias
=
ValueChoice
([
False
,
True
]))
_mixed_operation_sampling_sanity_check
(
linear
,
{
'shared'
:
3
},
torch
.
randn
(
2
,
3
))
def
test_mixed_conv2d
():
conv
=
Conv2d
(
ValueChoice
([
3
,
6
,
9
],
label
=
'in'
),
ValueChoice
([
2
,
4
,
8
],
label
=
'out'
)
*
2
,
1
)
assert
_mixed_operation_sampling_sanity_check
(
conv
,
{
'in'
:
3
,
'out'
:
4
},
torch
.
randn
(
2
,
3
,
9
,
9
)).
size
(
1
)
==
8
_mixed_operation_differentiable_sanity_check
(
conv
,
torch
.
randn
(
2
,
9
,
3
,
3
))
# stride
conv
=
Conv2d
(
ValueChoice
([
3
,
6
,
9
],
label
=
'in'
),
ValueChoice
([
2
,
4
,
8
],
label
=
'out'
),
1
,
stride
=
ValueChoice
([
1
,
2
],
label
=
'stride'
))
assert
_mixed_operation_sampling_sanity_check
(
conv
,
{
'in'
:
3
,
'stride'
:
2
},
torch
.
randn
(
2
,
3
,
10
,
10
)).
size
(
2
)
==
5
assert
_mixed_operation_sampling_sanity_check
(
conv
,
{
'in'
:
3
,
'stride'
:
1
},
torch
.
randn
(
2
,
3
,
10
,
10
)).
size
(
2
)
==
10
# groups, dw conv
conv
=
Conv2d
(
ValueChoice
([
3
,
6
,
9
],
label
=
'in'
),
ValueChoice
([
3
,
6
,
9
],
label
=
'in'
),
1
,
groups
=
ValueChoice
([
3
,
6
,
9
],
label
=
'in'
))
assert
_mixed_operation_sampling_sanity_check
(
conv
,
{
'in'
:
6
},
torch
.
randn
(
2
,
6
,
10
,
10
)).
size
()
==
torch
.
Size
([
2
,
6
,
10
,
10
])
# make sure kernel is sliced correctly
conv
=
Conv2d
(
1
,
1
,
ValueChoice
([
1
,
3
],
label
=
'k'
),
bias
=
False
)
conv
=
MixedConv2d
.
mutate
(
conv
,
'dummy'
,
{},
{
'mixed_op_sampling'
:
MixedOpPathSamplingPolicy
})
with
torch
.
no_grad
():
conv
.
weight
.
zero_
()
# only center is 1, must pick center to pass this test
conv
.
weight
[
0
,
0
,
1
,
1
]
=
1
conv
.
resample
({
'k'
:
1
})
assert
conv
(
torch
.
ones
((
1
,
1
,
3
,
3
))).
sum
().
item
()
==
9
def
test_mixed_batchnorm2d
():
bn
=
BatchNorm2d
(
ValueChoice
([
32
,
64
],
label
=
'dim'
))
assert
_mixed_operation_sampling_sanity_check
(
bn
,
{
'dim'
:
32
},
torch
.
randn
(
2
,
32
,
3
,
3
)).
size
(
1
)
==
32
assert
_mixed_operation_sampling_sanity_check
(
bn
,
{
'dim'
:
64
},
torch
.
randn
(
2
,
64
,
3
,
3
)).
size
(
1
)
==
64
_mixed_operation_differentiable_sanity_check
(
bn
,
torch
.
randn
(
2
,
64
,
3
,
3
))
def
test_mixed_mhattn
():
mhattn
=
MultiheadAttention
(
ValueChoice
([
4
,
8
],
label
=
'emb'
),
4
)
assert
_mixed_operation_sampling_sanity_check
(
mhattn
,
{
'emb'
:
4
},
torch
.
randn
(
7
,
2
,
4
),
torch
.
randn
(
7
,
2
,
4
),
torch
.
randn
(
7
,
2
,
4
))[
0
].
size
(
-
1
)
==
4
assert
_mixed_operation_sampling_sanity_check
(
mhattn
,
{
'emb'
:
8
},
torch
.
randn
(
7
,
2
,
8
),
torch
.
randn
(
7
,
2
,
8
),
torch
.
randn
(
7
,
2
,
8
))[
0
].
size
(
-
1
)
==
8
_mixed_operation_differentiable_sanity_check
(
mhattn
,
torch
.
randn
(
7
,
2
,
8
),
torch
.
randn
(
7
,
2
,
8
),
torch
.
randn
(
7
,
2
,
8
))
mhattn
=
MultiheadAttention
(
ValueChoice
([
4
,
8
],
label
=
'emb'
),
ValueChoice
([
2
,
3
,
4
],
label
=
'heads'
))
assert
_mixed_operation_sampling_sanity_check
(
mhattn
,
{
'emb'
:
4
,
'heads'
:
2
},
torch
.
randn
(
7
,
2
,
4
),
torch
.
randn
(
7
,
2
,
4
),
torch
.
randn
(
7
,
2
,
4
))[
0
].
size
(
-
1
)
==
4
with
pytest
.
raises
(
AssertionError
,
match
=
'divisible'
):
assert
_mixed_operation_sampling_sanity_check
(
mhattn
,
{
'emb'
:
4
,
'heads'
:
3
},
torch
.
randn
(
7
,
2
,
4
),
torch
.
randn
(
7
,
2
,
4
),
torch
.
randn
(
7
,
2
,
4
))[
0
].
size
(
-
1
)
==
4
mhattn
=
MultiheadAttention
(
ValueChoice
([
4
,
8
],
label
=
'emb'
),
4
,
kdim
=
ValueChoice
([
5
,
7
],
label
=
'kdim'
))
assert
_mixed_operation_sampling_sanity_check
(
mhattn
,
{
'emb'
:
4
,
'kdim'
:
7
},
torch
.
randn
(
7
,
2
,
4
),
torch
.
randn
(
7
,
2
,
7
),
torch
.
randn
(
7
,
2
,
4
))[
0
].
size
(
-
1
)
==
4
assert
_mixed_operation_sampling_sanity_check
(
mhattn
,
{
'emb'
:
8
,
'kdim'
:
5
},
torch
.
randn
(
7
,
2
,
8
),
torch
.
randn
(
7
,
2
,
5
),
torch
.
randn
(
7
,
2
,
8
))[
0
].
size
(
-
1
)
==
8
mhattn
=
MultiheadAttention
(
ValueChoice
([
4
,
8
],
label
=
'emb'
),
4
,
vdim
=
ValueChoice
([
5
,
8
],
label
=
'vdim'
))
assert
_mixed_operation_sampling_sanity_check
(
mhattn
,
{
'emb'
:
4
,
'vdim'
:
8
},
torch
.
randn
(
7
,
2
,
4
),
torch
.
randn
(
7
,
2
,
4
),
torch
.
randn
(
7
,
2
,
8
))[
0
].
size
(
-
1
)
==
4
assert
_mixed_operation_sampling_sanity_check
(
mhattn
,
{
'emb'
:
8
,
'vdim'
:
5
},
torch
.
randn
(
7
,
2
,
8
),
torch
.
randn
(
7
,
2
,
8
),
torch
.
randn
(
7
,
2
,
5
))[
0
].
size
(
-
1
)
==
8
_mixed_operation_differentiable_sanity_check
(
mhattn
,
torch
.
randn
(
5
,
3
,
8
),
torch
.
randn
(
5
,
3
,
8
),
torch
.
randn
(
5
,
3
,
8
))
@
pytest
.
mark
.
skipif
(
torch
.
__version__
.
startswith
(
'1.7'
),
reason
=
'batch_first is not supported for legacy PyTorch'
)
def
test_mixed_mhattn_batch_first
():
# batch_first is not supported for legacy pytorch versions
# mark 1.7 because 1.7 is used on legacy pipeline
mhattn
=
MultiheadAttention
(
ValueChoice
([
4
,
8
],
label
=
'emb'
),
2
,
kdim
=
(
ValueChoice
([
3
,
7
],
label
=
'kdim'
)),
vdim
=
ValueChoice
([
5
,
8
],
label
=
'vdim'
),
bias
=
False
,
add_bias_kv
=
True
,
batch_first
=
True
)
assert
_mixed_operation_sampling_sanity_check
(
mhattn
,
{
'emb'
:
4
,
'kdim'
:
7
,
'vdim'
:
8
},
torch
.
randn
(
2
,
7
,
4
),
torch
.
randn
(
2
,
7
,
7
),
torch
.
randn
(
2
,
7
,
8
))[
0
].
size
(
-
1
)
==
4
assert
_mixed_operation_sampling_sanity_check
(
mhattn
,
{
'emb'
:
8
,
'kdim'
:
3
,
'vdim'
:
5
},
torch
.
randn
(
2
,
7
,
8
),
torch
.
randn
(
2
,
7
,
3
),
torch
.
randn
(
2
,
7
,
5
))[
0
].
size
(
-
1
)
==
8
_mixed_operation_differentiable_sanity_check
(
mhattn
,
torch
.
randn
(
1
,
7
,
8
),
torch
.
randn
(
1
,
7
,
7
),
torch
.
randn
(
1
,
7
,
8
))
def
test_pathsampling_layer_input
():
op
=
PathSamplingLayer
([(
'a'
,
Linear
(
2
,
3
,
bias
=
False
)),
(
'b'
,
Linear
(
2
,
3
,
bias
=
True
))],
label
=
'ccc'
)
with
pytest
.
raises
(
RuntimeError
,
match
=
'sample'
):
op
(
torch
.
randn
(
4
,
2
))
op
.
resample
({})
assert
op
(
torch
.
randn
(
4
,
2
)).
size
(
-
1
)
==
3
assert
op
.
search_space_spec
()[
'ccc'
].
values
==
[
'a'
,
'b'
]
assert
op
.
export
({})[
'ccc'
]
in
[
'a'
,
'b'
]
input
=
PathSamplingInput
(
5
,
2
,
'concat'
,
'ddd'
)
sample
=
input
.
resample
({})
assert
'ddd'
in
sample
assert
len
(
sample
[
'ddd'
])
==
2
assert
input
([
torch
.
randn
(
4
,
2
)
for
_
in
range
(
5
)]).
size
(
-
1
)
==
4
assert
len
(
input
.
export
({})[
'ddd'
])
==
2
def
test_differentiable_layer_input
():
op
=
DifferentiableMixedLayer
([(
'a'
,
Linear
(
2
,
3
,
bias
=
False
)),
(
'b'
,
Linear
(
2
,
3
,
bias
=
True
))],
nn
.
Parameter
(
torch
.
randn
(
2
)),
nn
.
Softmax
(
-
1
),
'eee'
)
assert
op
(
torch
.
randn
(
4
,
2
)).
size
(
-
1
)
==
3
assert
op
.
export
({})[
'eee'
]
in
[
'a'
,
'b'
]
assert
len
(
list
(
op
.
parameters
()))
==
3
input
=
DifferentiableMixedInput
(
5
,
2
,
nn
.
Parameter
(
torch
.
zeros
(
5
)),
GumbelSoftmax
(
-
1
),
'ddd'
)
assert
input
([
torch
.
randn
(
4
,
2
)
for
_
in
range
(
5
)]).
size
(
-
1
)
==
2
assert
len
(
input
.
export
({})[
'ddd'
])
==
2
def
test_proxyless_layer_input
():
op
=
ProxylessMixedLayer
([(
'a'
,
Linear
(
2
,
3
,
bias
=
False
)),
(
'b'
,
Linear
(
2
,
3
,
bias
=
True
))],
nn
.
Parameter
(
torch
.
randn
(
2
)),
nn
.
Softmax
(
-
1
),
'eee'
)
assert
op
.
resample
({})[
'eee'
]
in
[
'a'
,
'b'
]
assert
op
(
torch
.
randn
(
4
,
2
)).
size
(
-
1
)
==
3
assert
op
.
export
({})[
'eee'
]
in
[
'a'
,
'b'
]
assert
len
(
list
(
op
.
parameters
()))
==
3
input
=
ProxylessMixedInput
(
5
,
2
,
nn
.
Parameter
(
torch
.
zeros
(
5
)),
GumbelSoftmax
(
-
1
),
'ddd'
)
assert
input
.
resample
({})[
'ddd'
]
in
list
(
range
(
5
))
assert
input
([
torch
.
randn
(
4
,
2
)
for
_
in
range
(
5
)]).
size
()
==
torch
.
Size
([
4
,
2
])
assert
input
.
export
({})[
'ddd'
]
in
list
(
range
(
5
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment