Commit a0fd0036 (unverified), authored Aug 01, 2022 by Yuge Zhang, committed by GitHub on Aug 01, 2022
Merge pull request #5036 from microsoft/promote-retiarii-to-nas

[DO NOT SQUASH] Promote retiarii to NAS

Parents: d6dcb483, bc6d8796
Changes: 239
Showing 20 changed files with 40 additions and 5033 deletions (+40, -5033):
nni/retiarii/nn/pytorch/nn.py                                     +2   -112
nni/retiarii/nn/tensorflow/api.py                                 +2     -7
nni/retiarii/oneshot/pytorch/base_lightning.py                    +2   -526
nni/retiarii/oneshot/pytorch/dataloader.py                        +2    -75
nni/retiarii/oneshot/pytorch/differentiable.py                    +2   -252
nni/retiarii/oneshot/pytorch/enas.py                              +2   -144
nni/retiarii/oneshot/pytorch/sampling.py                          +2   -259
nni/retiarii/oneshot/pytorch/strategy.py                          +2   -142
nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py      +2   -289
nni/retiarii/oneshot/pytorch/supermodule/_singlepathnas.py        +2   -258
nni/retiarii/oneshot/pytorch/supermodule/_valuechoice_utils.py    +2   -240
nni/retiarii/oneshot/pytorch/supermodule/base.py                  +2    -91
nni/retiarii/oneshot/pytorch/supermodule/differentiable.py        +2   -536
nni/retiarii/oneshot/pytorch/supermodule/operation.py             +2   -701
nni/retiarii/oneshot/pytorch/supermodule/proxyless.py             +2   -193
nni/retiarii/oneshot/pytorch/supermodule/sampling.py              +2   -412
nni/retiarii/oneshot/pytorch/utils.py                             +2     -3
nni/retiarii/operation.py                                         +2   -242
nni/retiarii/operation_def/tf_op_def.py                           +2     -7
nni/retiarii/operation_def/torch_op_def.py                        +2   -544
nni/retiarii/nn/pytorch/nn.py  (view file @ a0fd0036)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pathlib import Path

# pylint: disable=wildcard-import,unused-wildcard-import

# To make auto-completion happy, we generate a _nn.py that lists out all the classes.
nn_cache_file_path = Path(__file__).parent / '_nn.py'

# Update this when cache format changes, to enforce an update.
cache_version = 2


def validate_cache() -> bool:
    import torch

    cache_valid = []

    if nn_cache_file_path.exists():
        lines = nn_cache_file_path.read_text().splitlines()
        for line in lines:
            if line.startswith('# _torch_version'):
                _cached_torch_version = line[line.find('=') + 1:].strip()
                if _cached_torch_version == torch.__version__:
                    cache_valid.append(True)
            if line.startswith('# _torch_nn_cache_version'):
                _cached_cache_version = int(line[line.find('=') + 1:].strip())
                if _cached_cache_version == cache_version:
                    cache_valid.append(True)

    return len(cache_valid) >= 2 and all(cache_valid)


def generate_stub_file() -> str:
    import inspect
    import warnings

    import torch
    import torch.nn as nn

    _NO_WRAP_CLASSES = [
        # not an nn.Module
        'Parameter',
        'ParameterList',
        'UninitializedBuffer',
        'UninitializedParameter',

        # arguments are special
        'Module',
        'Sequential',

        # utilities
        'Container',
        'DataParallel',
    ]

    _WRAP_WITHOUT_TAG_CLASSES = [
        # special support on graph engine
        'ModuleList',
        'ModuleDict',
    ]

    code = [
        '# Copyright (c) Microsoft Corporation.',
        '# Licensed under the MIT license.',
        '# This file is auto-generated to make auto-completion work.',
        '# When pytorch version does not match, it will get automatically updated.',
        '# pylint: skip-file',
        '# pyright: reportGeneralTypeIssues=false',
        f'# _torch_version = {torch.__version__}',
        f'# _torch_nn_cache_version = {cache_version}',
        'import typing',
        'import torch.nn as nn',
        'from nni.retiarii.serializer import basic_unit',
    ]

    all_names = []

    # Add modules, classes, functions in torch.nn into this module.
    for name, obj in inspect.getmembers(torch.nn):
        if inspect.isclass(obj):
            if name in _NO_WRAP_CLASSES:
                code.append(f'{name} = nn.{name}')
            elif not issubclass(obj, nn.Module):
                # It should never go here
                # We did it to play safe
                warnings.warn(f'{obj} is found to be not a nn.Module, which is unexpected. '
                              'It means your PyTorch version might not be supported.', RuntimeWarning)
                code.append(f'{name} = nn.{name}')
            elif name in _WRAP_WITHOUT_TAG_CLASSES:
                code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}, basic_unit_tag=False))')
            else:
                code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}))')

            all_names.append(name)

        elif inspect.isfunction(obj) or inspect.ismodule(obj):
            code.append(f'{name} = nn.{name}')  # no modification
            all_names.append(name)

    code.append(f'__all__ = {all_names}')

    return '\n'.join(code)


def write_cache(code: str) -> None:
    with nn_cache_file_path.open('w') as fp:
        fp.write(code)


code = generate_stub_file()
if not validate_cache():
    write_cache(code)

del Path, validate_cache, write_cache, cache_version, nn_cache_file_path, code

from ._nn import *  # pylint: disable=import-error, wildcard-import, unused-wildcard-import

from nni.nas.nn.pytorch.layers import *
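
The version-keyed caching above is easier to see in isolation. Below is a minimal, hedged sketch of the same idea (not NNI's implementation): a generated file records the torch version and a format version in comment lines, and is regenerated whenever either no longer matches. The names stub_path and CACHE_VERSION are illustrative only.

# Minimal sketch of version-keyed stub caching (illustrative, not NNI's code).
from pathlib import Path

import torch

CACHE_VERSION = 2                     # bump when the generated format changes
stub_path = Path('/tmp/_nn_stub.py')  # hypothetical cache location


def cache_is_valid() -> bool:
    """The cache is valid only if both recorded versions match the current environment."""
    if not stub_path.exists():
        return False
    versions = dict(
        line.lstrip('# ').split(' = ', 1)
        for line in stub_path.read_text().splitlines()
        if line.startswith('# _') and ' = ' in line
    )
    return (versions.get('_torch_version') == torch.__version__
            and versions.get('_torch_nn_cache_version') == str(CACHE_VERSION))


def regenerate() -> None:
    """Write a fresh stub with the current version markers in its header."""
    header = [
        f'# _torch_version = {torch.__version__}',
        f'# _torch_nn_cache_version = {CACHE_VERSION}',
    ]
    stub_path.write_text('\n'.join(header) + '\n')


if not cache_is_valid():
    regenerate()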
nni/retiarii/nn/tensorflow/api.py  (view file @ a0fd0036)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

# pylint: disable=wildcard-import,unused-wildcard-import

import tensorflow as tf


class LayerChoice(tf.keras.Layer):
    # FIXME: This is only a draft to test multi-framework support, it's not implemented at all.
    pass

from nni.nas.nn.tensorflow.api import *
nni/retiarii/oneshot/pytorch/base_lightning.py  (view file @ a0fd0036)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

# pylint: disable=wildcard-import,unused-wildcard-import

import warnings
from itertools import chain
from typing import Callable, Any, Dict, Union, Tuple, List, cast

import pytorch_lightning as pl
import torch.optim as optim
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

import nni.retiarii.nn.pytorch as nas_nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import is_traceable
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.typehint import Literal
from .supermodule.base import BaseSuperNetModule

__all__ = [
    'MutationHook',
    'BaseSuperNetModule',
    'BaseOneShotLightningModule',
    'traverse_and_mutate_submodules',
    'no_default_hook'
]

MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]


def traverse_and_mutate_submodules(
    root_module: nn.Module, hooks: list[MutationHook], mutate_kwargs: dict[str, Any], topdown: bool = True
) -> list[BaseSuperNetModule]:
    """
    Traverse the module-tree of ``root_module``, and call ``hooks`` on every tree node.

    Parameters
    ----------
    root_module : nn.Module
        User-defined model space.
        Since this method is called in the ``__init__`` of :class:`BaseOneShotLightningModule`,
        it's usually a ``pytorch_lightning.LightningModule``.
        The mutation will be in-place on ``root_module``.
    hooks : list[MutationHook]
        List of mutation hooks. See :class:`BaseOneShotLightningModule` for how to write hooks.
        When a hook returns a module, the module will be replaced (mutated) with the new module.
    mutate_kwargs : dict
        Extra keyword arguments passed to hooks.
    topdown : bool, default = True
        If topdown is true, hooks are called first, before traversing the sub-modules (i.e., pre-order DFS).
        Otherwise, sub-modules are traversed first, before calling hooks on this node (i.e., post-order DFS).

    Returns
    -------
    modules : list[BaseSuperNetModule]
        The list of replaced modules.
    """
    memo = {}

    module_list = []

    def apply(m):
        # Need to call list() here because the loop body might replace some children in-place.
        for name, child in list(m.named_children()):
            # post-order DFS
            if not topdown:
                apply(child)

            mutate_result = None

            for hook in hooks:
                hook_suggest = hook(child, name, memo, mutate_kwargs)

                # parse the mutate result
                if isinstance(hook_suggest, tuple):
                    hook_suggest, suppress = hook_suggest
                elif hook_suggest is True:
                    hook_suggest, suppress = None, True
                elif not hook_suggest:  # none / false
                    hook_suggest, suppress = None, False
                elif isinstance(hook_suggest, nn.Module):
                    suppress = True
                else:
                    raise TypeError(f'Mutation hook returned {hook_suggest} of unsupported type: {type(hook_suggest)}.')

                if hook_suggest is not None:
                    if not isinstance(hook_suggest, BaseSuperNetModule):
                        warnings.warn("Mutation hook didn't return a BaseSuperNetModule. It will be ignored in hooked module list.",
                                      RuntimeWarning)
                    setattr(m, name, hook_suggest)

                    mutate_result = hook_suggest

                # if suppress, no further mutation hooks are called
                if suppress:
                    break

            if isinstance(mutate_result, BaseSuperNetModule):
                # Replace child with the mutate result, and DFS this one
                child = mutate_result
                module_list.append(mutate_result)

            # pre-order DFS
            if topdown:
                apply(child)

    apply(root_module)

    return module_list


def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> bool:
    """Add this hook at the end of your hook list to raise an error for unsupported mutation primitives."""

    # Forward IS NOT supernet
    primitive_list = (
        nas_nn.LayerChoice,
        nas_nn.InputChoice,
        nas_nn.Repeat,
        nas_nn.NasBench101Cell,
        # nas_nn.ValueChoice,       # could be false positive
        # nas_nn.Cell,              # later
        # nas_nn.NasBench201Cell,   # forward = supernet
    )

    if isinstance(module, primitive_list):
        raise TypeError(f'{type(module).__name__} is not supported')

    if isinstance(module, nas_nn.Cell) and module.merge_op != 'all':
        # need output_node_indices, which depends on super-net
        raise TypeError(f'Cell with merge_op `{module.merge_op}` is not supported')

    if is_traceable(module):
        # check whether there is a value-choice in its arguments
        has_valuechoice = False
        for arg in chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()):
            if isinstance(arg, ValueChoiceX):
                has_valuechoice = True
                break

        if has_valuechoice:
            raise TypeError(f'`basic_unit` {type(module).__name__} with value choice in its arguments is not supported. '
                            'Please try to remove `basic_unit` to see if that works, or support this type with value choice manually.')

    return True  # suppress all other hooks


class BaseOneShotLightningModule(pl.LightningModule):

    _mutation_hooks_note = """mutation_hooks : list[MutationHook]
        Extra mutation hooks to support customized mutation on primitives other than built-ins.

        Mutation hooks are callables that take a module and return a
        :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`.
        They are invoked in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.traverse_and_mutate_submodules`, on each submodule.
        For each submodule, the hooks in the list are invoked in order;
        later hooks can see the result from previous hooks.
        The modules that are processed by ``mutation_hooks`` will be replaced by the returned module,
        stored in :attr:`nas_modules`, and be the focus of the NAS algorithm.

        The hook list will be appended with ``default_mutation_hooks`` of each one-shot module.

        To be more specific, the input arguments are four:

        1. a module that might be processed,
        2. name of the module in its parent module,
        3. a memo dict whose usage depends on the particular algorithm,
        4. keyword arguments (configurations).

        Note that the memo is meant to be read/written by hooks.
        There won't be any hooks called on the root module.

        The return value can be one of three kinds:

        1. tuple of: :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None, and boolean,
        2. boolean,
        3. :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None.

        The boolean value is ``suppress``, which indicates whether the following hooks should be called.
        When it's true, it suppresses the subsequent hooks, and they will never be invoked.
        Without a boolean value specified, it's assumed to be false.
        If a none value appears in the place of
        :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
        it means the hook suggests to keep the module unchanged, and nothing will happen.

        An example of a mutation hook is given in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.no_default_hook`.
        However, it's recommended to implement mutation hooks by deriving
        :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
        and add its classmethod ``mutate`` to this list.
    """

    _inner_module_note = """inner_module : pytorch_lightning.LightningModule
        It's a `LightningModule <https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html>`__
        that defines computations, train/val loops, optimizers in a single class.
        When used in NNI, the ``inner_module`` is the combination of instances of evaluator + base model
        (to be precise, a base model wrapped with LightningModule in evaluator).
    """

    __doc__ = """
    The base class for all one-shot NAS modules.

    In NNI, we try to separate the "search" part and the "training" part in one-shot NAS.
    The "training" part is defined with the evaluator interface (it has to be a lightning evaluator interface to work with one-shot).
    Since the lightning evaluator has already broken down the training into minimal building blocks,
    we can re-assemble them after combining them with the "search" part of a particular algorithm.

    After the re-assembling, this module has defined all the search + training. The experiment can use a lightning trainer
    (which is another part in the evaluator) to train this module, so as to complete the search process.

    Essential functions such as preprocessing the user's model, redirecting lightning hooks for the user's model,
    configuring optimizers and exporting the NAS result are implemented in this class.

    Attributes
    ----------
    nas_modules : list[BaseSuperNetModule]
        Modules that have been mutated, which the search algorithms should care about.
    model : pl.LightningModule
        PyTorch lightning module. A model space with training recipe defined (wrapped by LightningModule in evaluator).

    Parameters
    ----------
    """ + _inner_module_note + _mutation_hooks_note

    trainer: pl.Trainer

    @property
    def automatic_optimization(self) -> bool:
        return False

    def default_mutation_hooks(self) -> list[MutationHook]:
        """Override this to define class-default mutation hooks."""
        return [no_default_hook]

    def mutate_kwargs(self) -> dict[str, Any]:
        """Extra keyword arguments passed to mutation hooks. Usually algo-specific."""
        return {}

    def __init__(self, model: pl.LightningModule, mutation_hooks: list[MutationHook] | None = None):
        super().__init__()
        assert isinstance(model, pl.LightningModule)
        self.model = model

        # append the default hooks
        mutation_hooks = (mutation_hooks or []) + self.default_mutation_hooks()

        # traverse the model, calling hooks on every submodule
        self.nas_modules: list[BaseSuperNetModule] = traverse_and_mutate_submodules(
            self.model, mutation_hooks, self.mutate_kwargs(), topdown=True)

    def search_space_spec(self) -> dict[str, ParameterSpec]:
        """Get the search space specification from :attr:`nas_modules`.

        Returns
        -------
        dict
            Key is the name of the choice, value is the corresponding :class:`ParameterSpec`.
        """
        result = {}
        for module in self.nas_modules:
            result.update(module.search_space_spec())
        return result

    def resample(self) -> dict[str, Any]:
        """Trigger the resample for each :attr:`nas_modules`.
        Sometimes (e.g., in differentiable cases), it does nothing.

        Returns
        -------
        dict
            Sampled architecture.
        """
        result = {}
        for module in self.nas_modules:
            result.update(module.resample(memo=result))
        return result

    def export(self) -> dict[str, Any]:
        """
        Export the NAS result, ideally the best choice of each :attr:`nas_modules`.
        You may implement an ``export`` method for your customized :attr:`nas_modules`.

        Returns
        -------
        dict
            Keys are names of ``nas_modules``, and values are the choice indices of them.
        """
        result = {}
        for module in self.nas_modules:
            result.update(module.export(memo=result))
        return result

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """This is the implementation of what happens in training loops of one-shot algos.
        It usually calls ``self.model.training_step`` which implements the real training recipe of the users' model.
        """
        return self.model.training_step(batch, batch_idx)

    def configure_optimizers(self):
        """
        Combine architecture optimizers and user's model optimizers.
        You can overwrite :meth:`configure_architecture_optimizers` if architecture optimizers are needed in your NAS algorithm.

        For now :attr:`model` is tested against evaluators in :mod:`nni.retiarii.evaluator.pytorch.lightning`
        and it only returns 1 optimizer.
        But for extensibility, code for other return value types is also implemented.
        """
        # pylint: disable=assignment-from-none
        arc_optimizers = self.configure_architecture_optimizers()
        if arc_optimizers is None:
            return self.model.configure_optimizers()

        if isinstance(arc_optimizers, optim.Optimizer):
            arc_optimizers = [arc_optimizers]
        self.arc_optim_count = len(arc_optimizers)

        # FIXME: this part uses non-official lightning API.
        # The return values ``frequency`` and ``monitor`` are ignored because lightning requires
        # ``len(optimizers) == len(frequency)``, and gradient backward is handled manually.
        # For data structure of variables below, please see pytorch lightning docs of ``configure_optimizers``.
        try:
            # above v1.6
            from pytorch_lightning.core.optimizer import (  # pylint: disable=import-error
                _configure_optimizers,                      # type: ignore
                _configure_schedulers_automatic_opt,        # type: ignore
                _configure_schedulers_manual_opt            # type: ignore
            )
            w_optimizers, lr_schedulers, self.frequencies, monitor = \
                _configure_optimizers(self.model.configure_optimizers())  # type: ignore
            lr_schedulers = (
                _configure_schedulers_automatic_opt(lr_schedulers, monitor)
                if self.automatic_optimization
                else _configure_schedulers_manual_opt(lr_schedulers)
            )
        except ImportError:
            # under v1.5
            w_optimizers, lr_schedulers, self.frequencies, monitor = \
                self.trainer._configure_optimizers(self.model.configure_optimizers())  # type: ignore
            lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)  # type: ignore

        if any(sch["scheduler"].optimizer not in w_optimizers for sch in lr_schedulers):  # type: ignore
            raise Exception(
                "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
            )

        # variables used to handle optimizer frequency
        self.cur_optimizer_step = 0
        self.cur_optimizer_index = 0

        return arc_optimizers + w_optimizers, lr_schedulers

    def on_train_start(self):
        return self.model.on_train_start()

    def on_train_end(self):
        return self.model.on_train_end()

    def on_fit_start(self):
        # redirect the access to trainer/log to this module
        # but note that we might be missing other attributes,
        # which could potentially be a problem
        self.model.trainer = self.trainer  # type: ignore
        self.model.log = self.log
        return self.model.on_fit_start()

    def on_fit_end(self):
        return self.model.on_fit_end()

    def on_train_batch_start(self, batch, batch_idx, unused=0):
        return self.model.on_train_batch_start(batch, batch_idx, unused)

    def on_train_batch_end(self, outputs, batch, batch_idx, unused=0):
        return self.model.on_train_batch_end(outputs, batch, batch_idx, unused)

    # Deprecated hooks in pytorch-lightning
    def on_epoch_start(self):
        return self.model.on_epoch_start()

    def on_epoch_end(self):
        return self.model.on_epoch_end()

    def on_train_epoch_start(self):
        return self.model.on_train_epoch_start()

    def on_train_epoch_end(self):
        return self.model.on_train_epoch_end()

    def on_before_backward(self, loss):
        return self.model.on_before_backward(loss)

    def on_after_backward(self):
        return self.model.on_after_backward()

    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None):
        return self.model.configure_gradient_clipping(optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm)

    def configure_architecture_optimizers(self):
        """
        Hook kept for subclasses. A specific NAS method inheriting this base class should return its architecture optimizers here
        if architecture parameters are needed. Note that lr schedulers are not supported now for architecture_optimizers.

        Returns
        -------
        arc_optimizers : list[Optimizer], Optimizer
            Optimizers used by a specific NAS algorithm. Return None if no architecture optimizers are needed.
        """
        return None

    def call_lr_schedulers(self, batch_index):
        """
        Function that imitates the lightning trainer's behaviour of calling user's lr schedulers. Since auto_optimization is turned off
        by this class, you can use this function to make schedulers behave as if they were automatically handled by the lightning trainer.

        Parameters
        ----------
        batch_index : int
            batch index
        """
        def apply(lr_scheduler):
            # single scheduler is called every epoch
            if isinstance(lr_scheduler, _LRScheduler):
                if self.trainer.is_last_batch:
                    lr_scheduler.step()
            # lr_scheduler_config is called as configured
            elif isinstance(lr_scheduler, dict):
                interval = lr_scheduler['interval']
                frequency = lr_scheduler['frequency']
                if (interval == 'step' and batch_index % frequency == 0) or \
                        (interval == 'epoch' and self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % frequency == 0):
                    lr_scheduler['scheduler'].step()

        lr_schedulers = self.lr_schedulers()

        if isinstance(lr_schedulers, list):
            for lr_scheduler in lr_schedulers:
                apply(lr_scheduler)
        else:
            apply(lr_schedulers)

    def call_weight_optimizers(self, method: Literal['step', 'zero_grad']):
        """
        Function that imitates the lightning trainer's behavior of calling user's optimizers. Since auto_optimization is turned off by this
        class, you can use this function to make user optimizers behave as if they were automatically handled by the lightning trainer.

        Parameters
        ----------
        method : str
            Method to call. Only ``step`` and ``zero_grad`` are supported now.
        """
        def apply_method(optimizer, method):
            if method == 'step':
                optimizer.step()
            elif method == 'zero_grad':
                optimizer.zero_grad()

        optimizers = self.weight_optimizers()
        if optimizers is None:
            return

        assert isinstance(optimizers, list), 'Did you forget to set use_pl_optimizers to true?'

        if len(self.frequencies) > 0:
            self.cur_optimizer_step += 1
            if self.frequencies[self.cur_optimizer_index] == self.cur_optimizer_step:
                self.cur_optimizer_step = 0
                self.cur_optimizer_index = self.cur_optimizer_index + 1 \
                    if self.cur_optimizer_index + 1 < len(optimizers) \
                    else 0
            apply_method(optimizers[self.cur_optimizer_index], method)
        else:
            for optimizer in optimizers:
                apply_method(optimizer, method)

    def architecture_optimizers(self) -> list[Optimizer] | Optimizer | None:
        """
        Get architecture optimizers from all optimizers. Use this to get your architecture optimizers in :meth:`training_step`.

        Returns
        -------
        opts : list[Optimizer], Optimizer, None
            Architecture optimizers defined in :meth:`configure_architecture_optimizers`. This will be None if there are no
            architecture optimizers.
        """
        opts = self.optimizers()
        if isinstance(opts, list):
            # pylint: disable=unsubscriptable-object
            arc_opts = opts[:self.arc_optim_count]
            if len(arc_opts) == 1:
                return cast(Optimizer, arc_opts[0])
            return cast(List[Optimizer], arc_opts)
        # If there is only 1 optimizer and it is the architecture optimizer
        if self.arc_optim_count == 1:
            return cast(Union[List[Optimizer], Optimizer], opts)
        return None

    def weight_optimizers(self) -> list[Optimizer] | Optimizer | None:
        """
        Get user optimizers from all optimizers. Use this to get user optimizers in :meth:`training_step`.

        Returns
        -------
        opts : list[Optimizer], Optimizer, None
            Optimizers defined by the user's model. This will be None if there are no user optimizers.
        """
        # Since use_pl_optimizer is set to true (by default) here,
        # opts always returns a list.
        opts = self.optimizers()
        if isinstance(opts, list):
            # pylint: disable=unsubscriptable-object
            return cast(List[Optimizer], opts[self.arc_optim_count:])
        # FIXME: this case is actually not correctly handled
        # If there is only 1 optimizer and no architecture optimizer
        if self.arc_optim_count == 0:
            return cast(Union[List[Optimizer], Optimizer], opts)
        return None


from nni.nas.oneshot.pytorch.base_lightning import *
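
To make the hook protocol documented in ``_mutation_hooks_note`` concrete, here is a hedged sketch of a user-defined mutation hook. MyChoice and MySuperChoice are hypothetical stand-ins, not NNI classes; in real use the replacement module would derive BaseSuperNetModule, and the hook would be passed via mutation_hooks=[my_mutation_hook].

# Hypothetical mutation hook following the documented (module, name, memo, kwargs) protocol.
import torch.nn as nn


class MyChoice(nn.Module):
    """Hypothetical mutation primitive appearing in the user's model space."""


class MySuperChoice(nn.Module):
    """Hypothetical super-net replacement; in NNI this would derive BaseSuperNetModule."""

    def __init__(self, original: MyChoice):
        super().__init__()
        self.original = original


def my_mutation_hook(module, name, memo, mutate_kwargs):
    # Called by traverse_and_mutate_submodules on every submodule except the root.
    if isinstance(module, MyChoice):
        # Replace the module and suppress any later hooks for it.
        return MySuperChoice(module), True
    # Returning False (or None) keeps the module unchanged and lets later hooks run.
    return False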
nni/retiarii/oneshot/pytorch/dataloader.py  (view file @ a0fd0036)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

# pylint: disable=wildcard-import,unused-wildcard-import

from typing import Any

from pytorch_lightning.trainer.supporters import CombinedLoader, CombinedLoaderIterator

__all__ = ['ConcatLoader']


class ConcatLoader(CombinedLoader):
    """This loader is the same as CombinedLoader in PyTorch-Lightning, but concatenates the sub-loaders
    instead of loading them in parallel.

    Parameters
    ----------
    loaders
        For example, ::

            {
                "train": DataLoader(train_dataset),
                "val": DataLoader(val_dataset)
            }

        In this example, the loader will first produce the batches from "train", then "val".
    mode
        Only "min_size" is supported for now.
    """

    def __init__(self, loaders: dict[str, Any], mode: str = 'min_size'):
        # FIXME: max_cycle will make dataloaders cycle iterators,
        # causing extra problems.
        if mode != 'min_size':
            raise ValueError('Only min_size mode is supported now.')
        super().__init__(loaders, mode)

    def __iter__(self) -> Any:
        """Replace the super-class iterator with ours."""
        self._try_to_patch_pytorch_dataloader()
        iterator = ConcatLoaderIterator(self.loaders)
        # handle fault tolerant restart.
        self.on_restart(iterator)
        self._iterator = iterator
        return iterator

    @staticmethod
    def _try_to_patch_pytorch_dataloader():
        """Copied from CombinedLoader."""
        from torch.utils.data.dataloader import _BaseDataLoaderIter

        # prevent `NotImplementedError` from PyTorch:
        # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
        def __getstate__patch__(*_):
            return {}

        _BaseDataLoaderIter.__getstate__ = __getstate__patch__  # type: ignore

    def __len__(self) -> int:
        return int(sum(self._calc_num_batches(loader) for loader in self.loaders.values()))


class ConcatLoaderIterator(CombinedLoaderIterator):
    """Similar to CombinedLoaderIterator in Lightning, but in a concat manner."""

    def __next__(self) -> Any:
        """Fetches the next batch from multiple data loaders,
        by looking for the first iterator that isn't exhausted yet.
        """
        if not len(self.loader_iters) == len(self.loaders):
            raise RuntimeError('loader_iters must have the same length as loaders.')
        for i, (loader_name, iterator) in enumerate(self.loader_iters.items()):
            try:
                return (self.request_next_batch(iterator), loader_name)
            except StopIteration:
                if i + 1 == len(self.loader_iters):
                    raise


from nni.nas.oneshot.pytorch.dataloader import *
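
For reference, a small usage sketch of the concatenation behaviour described above: the loader yields all "train" batches first, then all "val" batches, each paired with its loader name. The datasets below are made up, and the import path follows the pre-promotion layout shown in this diff (a matching pytorch-lightning version is assumed).

# Illustrative only: ConcatLoader yields (batch, loader_name) pairs,
# exhausting "train" before moving on to "val".
import torch
from torch.utils.data import DataLoader, TensorDataset

from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

train_loader = DataLoader(TensorDataset(torch.arange(6).float()), batch_size=2)
val_loader = DataLoader(TensorDataset(torch.arange(4).float()), batch_size=2)

loader = ConcatLoader({'train': train_loader, 'val': val_loader})
for batch, name in loader:
    print(name, batch)  # "train" batches first, then "val" batches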
nni/retiarii/oneshot/pytorch/differentiable.py  (view file @ a0fd0036)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Experimental version of differentiable one-shot implementation."""

# pylint: disable=wildcard-import,unused-wildcard-import

from __future__ import annotations

import pytorch_lightning as pl
import torch
import torch.optim as optim

from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.differentiable import (
    DifferentiableMixedLayer, DifferentiableMixedInput,
    MixedOpDifferentiablePolicy, GumbelSoftmax,
    DifferentiableMixedCell, DifferentiableMixedRepeat
)
from .supermodule.proxyless import ProxylessMixedInput, ProxylessMixedLayer
from .supermodule.operation import NATIVE_MIXED_OPERATIONS, NATIVE_SUPPORTED_OP_NAMES


class DartsLightningModule(BaseOneShotLightningModule):
    _darts_note = """
    Continuous relaxation of the architecture representation, allowing efficient search of the architecture using gradient descent.
    `Reference <https://arxiv.org/abs/1806.09055>`__.

    The DARTS algorithm is one of the most fundamental one-shot algorithms.
    DARTS repeats iterations, where each iteration consists of 2 training phases.
    Phase 1 is the architecture step, in which model parameters are frozen and the architecture parameters are trained.
    Phase 2 is the model step, in which architecture parameters are frozen and model parameters are trained.

    The current implementation corresponds to DARTS (1st order) in the paper.
    Second order (unrolled 2nd-order derivatives) is not supported yet.

    .. versionadded:: 2.8

       Supports searching for ValueChoices on operations, with the technique described in
       `FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions <https://arxiv.org/abs/2004.05565>`__.
       One difference is that, in DARTS, we are using Softmax instead of GumbelSoftmax.

    The supported mutation primitives of DARTS are:

    * :class:`nni.retiarii.nn.pytorch.LayerChoice`.
    * :class:`nni.retiarii.nn.pytorch.InputChoice`.
    * :class:`nni.retiarii.nn.pytorch.ValueChoice` (only when used in {supported_ops}).
    * :class:`nni.retiarii.nn.pytorch.Repeat`.
    * :class:`nni.retiarii.nn.pytorch.Cell`.
    * :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.

    {{module_notes}}

    Parameters
    ----------
    {{module_params}}
    {base_params}
    arc_learning_rate : float
        Learning rate for architecture optimizer. Default: 3.0e-4
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
    )

    __doc__ = _darts_note.format(
        module_notes='The DARTS Module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
        module_params=BaseOneShotLightningModule._inner_module_note,
    )

    def default_mutation_hooks(self) -> list[MutationHook]:
        """Replace modules with differentiable versions."""
        hooks = [
            DifferentiableMixedLayer.mutate,
            DifferentiableMixedInput.mutate,
            DifferentiableMixedCell.mutate,
            DifferentiableMixedRepeat.mutate,
        ]
        hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
        hooks.append(no_default_hook)
        return hooks

    def mutate_kwargs(self):
        """Use differentiable strategy for mixed operations."""
        return {'mixed_op_sampling': MixedOpDifferentiablePolicy}

    def __init__(self, inner_module: pl.LightningModule,
                 mutation_hooks: list[MutationHook] | None = None,
                 arc_learning_rate: float = 3.0e-4):
        self.arc_learning_rate = arc_learning_rate
        super().__init__(inner_module, mutation_hooks=mutation_hooks)

    def training_step(self, batch, batch_idx):
        # grad manually
        arc_optim = self.architecture_optimizers()
        if not isinstance(arc_optim, optim.Optimizer):
            raise TypeError(f'Expect arc_optim to be a single Optimizer, but found: {arc_optim}')

        # DARTS strategy makes sure that ``train`` and ``val`` must be in the batch
        trn_batch = batch['train']
        val_batch = batch['val']

        # phase 1: architecture step
        # The _resample hook is kept for some darts-based NAS methods like proxyless.
        # See code of those methods for details.
        self.resample()
        arc_optim.zero_grad()
        arc_step_loss = self.model.training_step(val_batch, 2 * batch_idx)
        if isinstance(arc_step_loss, dict):
            arc_step_loss = arc_step_loss['loss']
        self.manual_backward(arc_step_loss)
        self.finalize_grad()
        arc_optim.step()

        # phase 2: model step
        self.resample()
        self.call_weight_optimizers('zero_grad')
        loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
        w_step_loss = loss_and_metrics['loss'] \
            if isinstance(loss_and_metrics, dict) else loss_and_metrics
        self.manual_backward(w_step_loss)
        self.call_weight_optimizers('step')

        self.call_lr_schedulers(batch_idx)

        return loss_and_metrics

    def finalize_grad(self):
        # Note: This hook is currently kept for Proxyless NAS.
        pass

    def configure_architecture_optimizers(self):
        # The alpha in DartsXXXChoices are the architecture parameters of DARTS. They share one optimizer.
        ctrl_params = []
        for m in self.nas_modules:
            ctrl_params += list(m.parameters(arch=True))  # type: ignore
        ctrl_optim = torch.optim.Adam(list(set(ctrl_params)), 3.e-4, betas=(0.5, 0.999), weight_decay=1.0E-3)
        return ctrl_optim


class ProxylessLightningModule(DartsLightningModule):
    _proxyless_note = """
    A low-memory-consuming optimized version of differentiable architecture search. See `reference <https://arxiv.org/abs/1812.00332>`__.

    This is a DARTS-based method that resamples the architecture to reduce memory consumption.
    Essentially, it samples one path on forward,
    and implements its own backward to update the architecture parameters based on only one path.

    The supported mutation primitives of Proxyless are:

    * :class:`nni.retiarii.nn.pytorch.LayerChoice`.
    * :class:`nni.retiarii.nn.pytorch.InputChoice`.

    {{module_notes}}

    Parameters
    ----------
    {{module_params}}
    {base_params}
    arc_learning_rate : float
        Learning rate for architecture optimizer. Default: 3.0e-4
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note
    )

    __doc__ = _proxyless_note.format(
        module_notes='This module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
        module_params=BaseOneShotLightningModule._inner_module_note,
    )

    def default_mutation_hooks(self) -> list[MutationHook]:
        """Replace modules with gumbel-differentiable versions."""
        hooks = [
            ProxylessMixedLayer.mutate,
            ProxylessMixedInput.mutate,
            no_default_hook,
        ]
        # FIXME: no support for mixed operation currently
        return hooks

    def finalize_grad(self):
        for m in self.nas_modules:
            m.finalize_grad()  # type: ignore


class GumbelDartsLightningModule(DartsLightningModule):
    _gumbel_darts_note = """
    Choose the best block by using Gumbel Softmax random sampling and differentiable training.
    See `FBNet <https://arxiv.org/abs/1812.03443>`__ and `SNAS <https://arxiv.org/abs/1812.09926>`__.

    This is a DARTS-based method that uses gumbel-softmax to simulate a one-hot distribution.
    Essentially, it tries to mimic the behavior of sampling one path on forward by gradually
    cooling down the temperature, aiming to bridge the gap between differentiable architecture weights and
    discretization of architectures.

    .. versionadded:: 2.8

       Supports searching for ValueChoices on operations, with the technique described in
       `FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions <https://arxiv.org/abs/2004.05565>`__.

    The supported mutation primitives of GumbelDARTS are:

    * :class:`nni.retiarii.nn.pytorch.LayerChoice`.
    * :class:`nni.retiarii.nn.pytorch.InputChoice`.
    * :class:`nni.retiarii.nn.pytorch.ValueChoice` (only when used in {supported_ops}).
    * :class:`nni.retiarii.nn.pytorch.Repeat`.
    * :class:`nni.retiarii.nn.pytorch.Cell`.
    * :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.

    {{module_notes}}

    Parameters
    ----------
    {{module_params}}
    {base_params}
    gumbel_temperature : float
        The initial temperature used in gumbel-softmax.
    use_temp_anneal : bool
        If true, a linear annealing will be applied to ``gumbel_temperature``.
        Otherwise, run at a fixed temperature. See `SNAS <https://arxiv.org/abs/1812.09926>`__ for details.
    min_temp : float
        The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` False.
    arc_learning_rate : float
        Learning rate for architecture optimizer. Default: 3.0e-4
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
    )

    def mutate_kwargs(self):
        """Use gumbel softmax."""
        return {
            'mixed_op_sampling': MixedOpDifferentiablePolicy,
            'softmax': GumbelSoftmax(),
        }

    def __init__(self, inner_module,
                 mutation_hooks: list[MutationHook] | None = None,
                 arc_learning_rate: float = 3.0e-4,
                 gumbel_temperature: float = 1.,
                 use_temp_anneal: bool = False,
                 min_temp: float = .33):
        super().__init__(inner_module, mutation_hooks, arc_learning_rate=arc_learning_rate)
        self.temp = gumbel_temperature
        self.init_temp = gumbel_temperature
        self.use_temp_anneal = use_temp_anneal
        self.min_temp = min_temp

    def on_train_epoch_end(self):
        if self.use_temp_anneal:
            self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
            self.temp = max(self.temp, self.min_temp)
        for module in self.nas_modules:
            if hasattr(module, '_softmax'):
                module._softmax.temp = self.temp  # type: ignore
        return self.model.on_train_epoch_end()


from nni.nas.oneshot.pytorch.differentiable import *
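
The two phases in DartsLightningModule.training_step implement the first-order DARTS approximation: architecture parameters are updated on a validation batch, then weights are updated on a training batch. Stripped of the Lightning plumbing, the alternation looks roughly like the sketch below; the model, loss function, loaders, and optimizers are placeholders, not NNI objects.

# First-order DARTS alternation, stripped of Lightning plumbing (illustrative only).
def darts_epoch(model, loss_fn, train_loader, val_loader, w_optim, alpha_optim):
    for (x_trn, y_trn), (x_val, y_val) in zip(train_loader, val_loader):
        # Phase 1 (architecture step): update architecture parameters alpha on the validation batch.
        alpha_optim.zero_grad()
        loss_fn(model(x_val), y_val).backward()
        alpha_optim.step()

        # Phase 2 (model step): update model weights w on the training batch.
        w_optim.zero_grad()
        loss_fn(model(x_trn), y_trn).backward()
        w_optim.step()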
nni/retiarii/oneshot/pytorch/enas.py  (view file @ a0fd0036)

...
@@ -3,14 +3,14 @@
import logging
import warnings
from typing import cast

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import SubsetRandomSampler, DataLoader

from nni.nas.oneshot.pytorch.enas import ReinforceController, ReinforceField
from ..interface import BaseOneShotTrainer
from .random import PathSamplingLayerChoice, PathSamplingInputChoice
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, to_device
...
@@ -18,148 +18,6 @@ from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
_logger = logging.getLogger(__name__)


class StackedLSTMCell(nn.Module):
    def __init__(self, layers, size, bias):
        super().__init__()
        self.lstm_num_layers = layers
        self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
                                           for _ in range(self.lstm_num_layers)])

    def forward(self, inputs, hidden):
        prev_h, prev_c = hidden
        next_h, next_c = [], []
        for i, m in enumerate(self.lstm_modules):
            curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
            next_c.append(curr_c)
            next_h.append(curr_h)
            # current implementation only supports batch size equals 1,
            # but the algorithm does not necessarily have this limitation
            inputs = curr_h[-1].view(1, -1)
        return next_h, next_c


class ReinforceField:
    """
    A field with ``name``, with ``total`` choices. ``choose_one`` is true if one and only one is meant to be
    selected. Otherwise, any number of choices can be chosen.
    """

    def __init__(self, name, total, choose_one):
        self.name = name
        self.total = total
        self.choose_one = choose_one

    def __repr__(self):
        return f'ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})'


class ReinforceController(nn.Module):
    """
    A controller that mutates the graph with RL.

    Parameters
    ----------
    fields : list of ReinforceField
        List of fields to choose.
    lstm_size : int
        Controller LSTM hidden units.
    lstm_num_layers : int
        Number of layers for stacked LSTM.
    tanh_constant : float
        Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
    skip_target : float
        Target probability that skipconnect (chosen by InputChoice) will appear.
        If the chosen number of inputs is away from the ``skip_connect``, there will be
        a sample skip penalty which is a KL divergence added.
    temperature : float
        Temperature constant that divides the logits.
    entropy_reduction : str
        Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
    """

    def __init__(self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
                 skip_target=0.4, temperature=None, entropy_reduction='sum'):
        super(ReinforceController, self).__init__()
        self.fields = fields
        self.lstm_size = lstm_size
        self.lstm_num_layers = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature
        self.skip_target = skip_target

        self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
        self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
        self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),  # pylint: disable=not-callable
                                         requires_grad=False)
        assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
        self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
        self.soft = nn.ModuleDict({
            field.name: nn.Linear(self.lstm_size, field.total, bias=False) for field in fields
        })
        self.embedding = nn.ModuleDict({
            field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
        })

    def resample(self):
        self._initialize()
        result = dict()
        for field in self.fields:
            result[field.name] = self._sample_single(field)
        return result

    def _initialize(self):
        self._inputs = self.g_emb.data
        self._c = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self._h = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self.sample_log_prob: torch.Tensor = cast(torch.Tensor, 0)
        self.sample_entropy: torch.Tensor = cast(torch.Tensor, 0)
        self.sample_skip_penalty: torch.Tensor = cast(torch.Tensor, 0)

    def _lstm_next_step(self):
        self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

    def _sample_single(self, field):
        self._lstm_next_step()
        logit = self.soft[field.name](self._h[-1])
        if self.temperature is not None:
            logit /= self.temperature
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        if field.choose_one:
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            log_prob = self.cross_entropy_loss(logit, sampled)
            self._inputs = self.embedding[field.name](sampled)
        else:
            logit = logit.view(-1, 1)
            logit = torch.cat([-logit, logit], 1)  # pylint: disable=invalid-unary-operand-type
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip_prob = torch.sigmoid(logit)
            kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(logit, sampled)
            sampled = sampled.nonzero().view(-1)
            if sampled.sum().item():
                self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
            else:
                self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)  # type: ignore

        sampled = sampled.detach().cpu().numpy().tolist()
        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)
        if len(sampled) == 1:
            sampled = sampled[0]
        return sampled


class EnasTrainer(BaseOneShotTrainer):
    """
    ENAS trainer.
...
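
To make the controller's sampling interface concrete, here is a small, hedged usage sketch of ReinforceController with two fields; the field names and sizes are made up, and the import path follows the pre-promotion layout shown in this diff.

# Illustrative usage of ReinforceController (field names/sizes are made up).
from nni.retiarii.oneshot.pytorch.enas import ReinforceController, ReinforceField

fields = [
    ReinforceField('conv_type', 3, True),   # choose exactly one of 3 operators
    ReinforceField('skip_from', 2, False),  # choose any subset of 2 inputs
]
controller = ReinforceController(fields, lstm_size=64)

sample = controller.resample()          # e.g. {'conv_type': 1, 'skip_from': [0, 1]}
log_prob = controller.sample_log_prob   # accumulated log-probability, used in the REINFORCE loss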
nni/retiarii/oneshot/pytorch/sampling.py  (view file @ a0fd0036)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Experimental version of sampling-based one-shot implementation."""

# pylint: disable=wildcard-import,unused-wildcard-import

from __future__ import annotations

import warnings
from typing import Any

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim

from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.operation import NATIVE_MIXED_OPERATIONS, NATIVE_SUPPORTED_OP_NAMES
from .supermodule.sampling import (
    PathSamplingInput, PathSamplingLayer, MixedOpPathSamplingPolicy,
    PathSamplingCell, PathSamplingRepeat
)
from .enas import ReinforceController, ReinforceField


class RandomSamplingLightningModule(BaseOneShotLightningModule):
    _random_note = """
    Train a super-net with uniform path sampling. See `reference <https://arxiv.org/abs/1904.00420>`__.

    In each epoch, model parameters are trained after a uniformly random sampling of each choice.
    Notably, the exported result is **also a random sample** of the search space.

    The supported mutation primitives of RandomOneShot are:

    * :class:`nni.retiarii.nn.pytorch.LayerChoice`.
    * :class:`nni.retiarii.nn.pytorch.InputChoice`.
    * :class:`nni.retiarii.nn.pytorch.ValueChoice` (only when used in {supported_ops}).
    * :class:`nni.retiarii.nn.pytorch.Repeat`.
    * :class:`nni.retiarii.nn.pytorch.Cell`.
    * :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.

    Parameters
    ----------
    {{module_params}}
    {base_params}
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
    )

    __doc__ = _random_note.format(
        module_params=BaseOneShotLightningModule._inner_module_note,
    )

    # turn on automatic optimization because nothing interesting is going on here.
    @property
    def automatic_optimization(self) -> bool:
        return True

    def default_mutation_hooks(self) -> list[MutationHook]:
        """Replace modules with path-sampling versions."""
        hooks = [
            PathSamplingLayer.mutate,
            PathSamplingInput.mutate,
            PathSamplingRepeat.mutate,
            PathSamplingCell.mutate,
        ]
        hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
        hooks.append(no_default_hook)
        return hooks

    def mutate_kwargs(self):
        """Use path sampling strategy for mixed-operations."""
        return {'mixed_op_sampling': MixedOpPathSamplingPolicy}

    def training_step(self, batch, batch_idx):
        self.resample()
        return self.model.training_step(batch, batch_idx)

    def export(self) -> dict[str, Any]:
        """
        Export of Random one-shot. It will return an arbitrary architecture.
        """
        warnings.warn('Direct export from RandomOneShot returns an arbitrary architecture. '
                      'Sampling the best architecture from this trained supernet is another search process. '
                      'Users need to do another search based on the checkpoint of the one-shot strategy.',
                      UserWarning)
        return super().export()


class EnasLightningModule(RandomSamplingLightningModule):
    _enas_note = """
    An RL controller learns to generate the best network on a super-net. See `ENAS paper <https://arxiv.org/abs/1802.03268>`__.

    There are 2 steps in an epoch.

    - Firstly, training model parameters.
    - Secondly, training the ENAS RL agent. The agent will produce a sample of model architecture to get the best reward.

    .. note::

       ENAS requires the evaluator to report metrics via ``self.log`` in its ``validation_step``.
       See explanation of ``reward_metric_name`` for details.

    The supported mutation primitives of ENAS are:

    * :class:`nni.retiarii.nn.pytorch.LayerChoice`.
    * :class:`nni.retiarii.nn.pytorch.InputChoice`.
    * :class:`nni.retiarii.nn.pytorch.ValueChoice` (only when used in {supported_ops}).
    * :class:`nni.retiarii.nn.pytorch.Repeat`.
    * :class:`nni.retiarii.nn.pytorch.Cell`.
    * :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.

    {{module_notes}}

    Parameters
    ----------
    {{module_params}}
    {base_params}
    ctrl_kwargs : dict
        Optional kwargs that will be passed to :class:`~nni.retiarii.oneshot.pytorch.enas.ReinforceController`.
    entropy_weight : float
        Weight of sample entropy loss in RL.
    skip_weight : float
        Weight of skip penalty loss. See :class:`~nni.retiarii.oneshot.pytorch.enas.ReinforceController` for details.
    baseline_decay : float
        Decay factor of reward baseline, which is used to normalize the reward in RL.
        At each step, the new reward baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    ctrl_steps_aggregate : int
        Number of steps for which the gradients will be accumulated,
        before updating the weights of the RL controller.
    ctrl_grad_clip : float
        Gradient clipping value of the controller.
    reward_metric_name : str or None
        The name of the metric which is treated as reward.
        This will not be effective when there's only one metric returned from the evaluator.
        If there are multiple, by default, it will find the metric with key name ``default``.
        If ``reward_metric_name`` is specified, it will find ``reward_metric_name``.
        Otherwise it raises an exception indicating multiple metrics are found.
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
    )

    __doc__ = _enas_note.format(
        module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`.',
        module_params=BaseOneShotLightningModule._inner_module_note,
    )

    @property
    def automatic_optimization(self) -> bool:
        return False

    def __init__(self, inner_module: pl.LightningModule, *,
                 ctrl_kwargs: dict[str, Any] | None = None,
                 entropy_weight: float = 1e-4,
                 skip_weight: float = .8,
                 baseline_decay: float = .999,
                 ctrl_steps_aggregate: float = 20,
                 ctrl_grad_clip: float = 0,
                 reward_metric_name: str | None = None,
                 mutation_hooks: list[MutationHook] | None = None):
        super().__init__(inner_module, mutation_hooks)

        # convert parameter spec to legacy ReinforceField
        # this part will be refactored
        self.nas_fields: list[ReinforceField] = []
        for name, param_spec in self.search_space_spec().items():
            if param_spec.chosen_size not in (1, None):
                raise ValueError('ENAS does not support n_chosen to be values other than 1 or None.')
            self.nas_fields.append(ReinforceField(name, param_spec.size, param_spec.chosen_size == 1))
        self.controller = ReinforceController(self.nas_fields, **(ctrl_kwargs or {}))

        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.baseline = 0.
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.ctrl_grad_clip = ctrl_grad_clip
        self.reward_metric_name = reward_metric_name

    def configure_architecture_optimizers(self):
        return optim.Adam(self.controller.parameters(), lr=3.5e-4)

    def training_step(self, batch_packed, batch_idx):
        batch, mode = batch_packed

        if mode == 'train':
            # train model params
            with torch.no_grad():
                self.resample()
            self.call_weight_optimizers('zero_grad')
            step_output = self.model.training_step(batch, batch_idx)
            w_step_loss = step_output['loss'] \
                if isinstance(step_output, dict) else step_output
            self.manual_backward(w_step_loss)
            self.call_weight_optimizers('step')
        else:
            # train ENAS agent
            arc_opt = self.architecture_optimizers()
            if not isinstance(arc_opt, optim.Optimizer):
                raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
            arc_opt.zero_grad()
            self.resample()

            step_output = self.model.validation_step(batch, batch_idx)

            # use the default metric of self.model as reward function
            if len(self.trainer.callback_metrics) == 1:
                _, metric = next(iter(self.trainer.callback_metrics.items()))
            else:
                metric_name = self.reward_metric_name or 'default'
                if metric_name not in self.trainer.callback_metrics:
                    raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
                                   f'found multiple (or zero) metrics without default: {list(self.trainer.callback_metrics.keys())}. '
                                   f'Try to use self.log to report metrics with the specified key ``{metric_name}`` in validation_step, '
                                   'and remember to set on_step=True.')
                metric = self.trainer.callback_metrics[metric_name]
            reward: float = metric.item()

            if self.entropy_weight:
                reward = reward + self.entropy_weight * self.controller.sample_entropy.item()  # type: ignore
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
            rnn_step_loss = self.controller.sample_log_prob * (reward - self.baseline)
            if self.skip_weight:
                rnn_step_loss = rnn_step_loss + self.skip_weight * self.controller.sample_skip_penalty

            rnn_step_loss = rnn_step_loss / self.ctrl_steps_aggregate
            self.manual_backward(rnn_step_loss)

            if (batch_idx + 1) % self.ctrl_steps_aggregate == 0:
                if self.ctrl_grad_clip > 0:
                    nn.utils.clip_grad_norm_(self.controller.parameters(), self.ctrl_grad_clip)
                arc_opt.step()
                arc_opt.zero_grad()

        return step_output

    def resample(self):
        """Resample the architecture with the ENAS controller."""
        sample = self.controller.resample()
        result = self._interpret_controller_sampling_result(sample)
        for module in self.nas_modules:
            module.resample(memo=result)
        return result

    def export(self):
        """Run one more inference of the ENAS controller."""
        self.controller.eval()
        with torch.no_grad():
            return self._interpret_controller_sampling_result(self.controller.resample())

    def _interpret_controller_sampling_result(self, sample: dict[str, int]) -> dict[str, Any]:
        """Convert ``{label: index}`` to ``{label: name}``."""
        space_spec = self.search_space_spec()
        for key in list(sample.keys()):
            sample[key] = space_spec[key].values[sample[key]]
        return sample


from nni.nas.oneshot.pytorch.sampling import *
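
The reward-baseline bookkeeping in EnasLightningModule.training_step is a plain exponential moving average, exactly as described in the ``baseline_decay`` docstring above. The snippet below reproduces just that arithmetic with made-up rewards, to show how the advantage that scales the REINFORCE log-probability is computed.

# Exponential-moving-average baseline as used in the ENAS step above (made-up rewards).
baseline, baseline_decay = 0.0, 0.999
for reward in [0.62, 0.64, 0.61, 0.70]:
    baseline = baseline * baseline_decay + reward * (1 - baseline_decay)
    advantage = reward - baseline  # scales the controller's log-prob in the REINFORCE loss
    print(f'reward={reward:.2f}  baseline={baseline:.5f}  advantage={advantage:.5f}')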
nni/retiarii/oneshot/pytorch/strategy.py
View file @
a0fd0036
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# pylint: disable=wildcard-import,unused-wildcard-import

"""Strategy integration of one-shot.

This file is put here simply because it relies on "pytorch".
For consistency, please consider importing strategies from ``nni.retiarii.strategy``.
For example, ``nni.retiarii.strategy.DartsStrategy`` (this requires pytorch to be installed of course).

When adding/modifying a new strategy in this file, don't forget to link it in strategy/oneshot.py.
"""

from __future__ import annotations

import warnings
from typing import Any, Type

import torch.nn as nn

from nni.retiarii.graph import Model
from nni.retiarii.strategy.base import BaseStrategy
from nni.retiarii.evaluator.pytorch.lightning import Lightning, LightningModule

from .base_lightning import BaseOneShotLightningModule
from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
from .sampling import EnasLightningModule, RandomSamplingLightningModule


class OneShotStrategy(BaseStrategy):
    """Wrap a one-shot lightning module as a one-shot strategy."""

    def __init__(self, oneshot_module: Type[BaseOneShotLightningModule], **kwargs):
        self.oneshot_module = oneshot_module
        self.oneshot_kwargs = kwargs

        self.model: BaseOneShotLightningModule | None = None

    def preprocess_dataloader(self, train_dataloaders: Any, val_dataloaders: Any) -> tuple[Any, Any]:
        """
        One-shot strategy typically requires fusing train and validation dataloader in an ad-hoc way.
        As one-shot strategy doesn't try to open the blackbox of a batch,
        theoretically, these dataloaders can be
        `any dataloader types supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.

        Returns
        -------
        A tuple of preprocessed train dataloaders and validation dataloaders.
        """
        return train_dataloaders, val_dataloaders

    def run(self, base_model: Model, applied_mutators):
        # one-shot strategy doesn't use ``applied_mutators``
        # but gets the "mutators" on its own

        _reason = 'The reason might be that you have used the wrong execution engine. Try to set engine to `oneshot` and try again.'

        if not isinstance(base_model.python_object, nn.Module):
            raise TypeError('Model is not a nn.Module. ' + _reason)
        py_model: nn.Module = base_model.python_object

        if applied_mutators:
            raise ValueError('Mutator is not empty. ' + _reason)

        if not isinstance(base_model.evaluator, Lightning):
            raise TypeError('Evaluator needs to be a lightning evaluator to make one-shot strategy work.')

        evaluator_module: LightningModule = base_model.evaluator.module
        evaluator_module.running_mode = 'oneshot'
        evaluator_module.set_model(py_model)

        self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
        evaluator: Lightning = base_model.evaluator
        if evaluator.train_dataloaders is None or evaluator.val_dataloaders is None:
            raise TypeError('Training and validation dataloader are both required to set in evaluator for one-shot strategy.')
        train_loader, val_loader = self.preprocess_dataloader(evaluator.train_dataloaders, evaluator.val_dataloaders)
        evaluator.trainer.fit(self.model, train_loader, val_loader)

    def export_top_models(self, top_k: int = 1) -> list[Any]:
        """The behavior of export top models in strategy depends on the implementation of inner one-shot module."""
        if self.model is None:
            raise RuntimeError('One-shot strategy needs to be run before export.')
        if top_k != 1:
            warnings.warn('One-shot strategy currently only supports exporting top-1 model.', RuntimeWarning)
        return [self.model.export()]


class DARTS(OneShotStrategy):
    __doc__ = DartsLightningModule._darts_note.format(
        module_notes='',
        module_params=''
    )

    def __init__(self, **kwargs):
        super().__init__(DartsLightningModule, **kwargs)

    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
        # By returning a dict, we make a CombinedLoader (in Lightning)
        return {
            'train': train_dataloaders,
            'val': val_dataloaders
        }, None


class Proxyless(OneShotStrategy):
    __doc__ = ProxylessLightningModule._proxyless_note.format(
        module_notes='',
        module_params=''
    )

    def __init__(self, **kwargs):
        super().__init__(ProxylessLightningModule, **kwargs)

    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
        return {
            'train': train_dataloaders,
            'val': val_dataloaders
        }, None


class GumbelDARTS(OneShotStrategy):
    __doc__ = GumbelDartsLightningModule._gumbel_darts_note.format(
        module_notes='',
        module_params=''
    )

    def __init__(self, **kwargs):
        super().__init__(GumbelDartsLightningModule, **kwargs)

    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
        return {
            'train': train_dataloaders,
            'val': val_dataloaders
        }, None


class ENAS(OneShotStrategy):
    __doc__ = EnasLightningModule._enas_note.format(
        module_notes='',
        module_params=''
    )

    def __init__(self, **kwargs):
        super().__init__(EnasLightningModule, **kwargs)

    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
        # Import locally to avoid import error on legacy PL version
        from .dataloader import ConcatLoader
        return ConcatLoader({
            'train': train_dataloaders,
            'val': val_dataloaders
        }), None


class RandomOneShot(OneShotStrategy):
    __doc__ = RandomSamplingLightningModule._random_note.format(
        module_notes='',
        module_params=''
    )

    def __init__(self, **kwargs):
        super().__init__(RandomSamplingLightningModule, **kwargs)
from nni.nas.oneshot.pytorch.strategy import *
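The classes above show the intended extension pattern: wrap a one-shot lightning module and customize how the train/val dataloaders are fused. Here is a minimal sketch of a user-defined strategy built that way; the import paths follow the shim above (post-refactor module names) and `MyDartsVariant` is purely an illustrative name.

from nni.nas.oneshot.pytorch.strategy import OneShotStrategy
from nni.nas.oneshot.pytorch.differentiable import DartsLightningModule

class MyDartsVariant(OneShotStrategy):
    def __init__(self, **kwargs):
        # Reuse the DARTS one-shot module; kwargs are forwarded to it.
        super().__init__(DartsLightningModule, **kwargs)

    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
        # Same convention as DARTS above: returning a dict makes Lightning
        # build a CombinedLoader over the two dataloaders.
        return {'train': train_dataloaders, 'val': val_dataloaders}, None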
nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# pylint: disable=wildcard-import,unused-wildcard-import

"""This file handles "slice" commonly used in mixed-operation.

The ``slice_type`` we support here, is "slice" or "list of slice".
The reason is that sometimes (e.g., in multi-head attention),
the tensor slice could be from multiple parts. This type is extensible.
We can support arbitrary masks in future if we need them.

To slice a tensor, we need ``multidim_slice``,
which is simply a tuple consisting of ``slice_type``.

Usually in python programs, the variable put into slice's start, stop and step
should be integers (or NoneType).
But in our case, it could also be a dict from integer to float,
representing a distribution of integer. When that happens,
we convert a "slice with some weighted values", to a "weighted slice".
To this end, we track the computation with ``MaybeWeighted``,
and replay the computation with each possible value.
Meanwhile, we record their weights.
Note that ``MaybeWeighted`` is also extensible.
We can support more types of objects on slice in future.

The fixed/weighted slice is fed into ``_slice_weight``,
which interprets the slice and applies it to a tensor.
"""

from __future__ import annotations

import operator
from typing import Callable, Iterator, TypeVar, Any, Optional, Tuple, Union, List, Dict, Generic, cast

import numpy as np
import torch

__all__ = [
    'slice_type',
    'multidim_slice',
    'scalar_or_scalar_dict',
    'int_or_int_dict',
    'zeros_like',
    'Slicable',
    'MaybeWeighted',
]

T = TypeVar('T')

slice_type = Union[slice, List[slice]]
multidim_slice = Tuple[slice_type, ...]

scalar_or_scalar_dict = Union[T, Dict[T, float]]
int_or_int_dict = scalar_or_scalar_dict[int]

_value_fn_type = Optional[Callable[[int_or_int_dict], int]]


def zeros_like(arr: T) -> T:
    if isinstance(arr, np.ndarray):
        return np.zeros_like(arr)
    elif isinstance(arr, torch.Tensor):
        return torch.zeros_like(arr)
    else:
        raise TypeError(f'Unsupported type for {arr}: {type(arr)}')


def _eliminate_list_slice(shape: tuple, slice_: multidim_slice) -> multidim_slice:
    # get rid of list of slice
    result = []
    for i in range(len(slice_)):
        if isinstance(slice_[i], list):
            # convert list of slices to mask
            mask = np.zeros(shape[i], dtype=np.bool)  # type: ignore
            for sl in cast(List[slice], slice_[i]):
                mask[sl] = 1
            result.append(mask)
        else:
            result.append(slice_[i])

    return tuple(result)


def _slice_weight(weight: T, slice_: multidim_slice | list[tuple[multidim_slice, float]]) -> T:
    # slice_ can be a tuple of slice, e.g., ([3:6], [2:4])
    # or tuple of slice -> float, e.g. {([3:6],): 0.6, ([2:4],): 0.3}

    if isinstance(slice_, list):
        # for weighted case, we get the corresponding masks. e.g.,
        # {([3:6],): 0.6, ([2:4],): 0.3} => [0, 0, 0.3, 0.9, 0.6, 0.6] (if the whole length is 6)
        # this mask is broadcasted and multiplied onto the weight

        masks = []

        # the accepted argument is list of tuple here
        # because slice can't be key of dict
        for sl, wt in slice_:
            # create a mask with weight w
            with torch.no_grad():
                mask = zeros_like(weight)
                mask[_eliminate_list_slice(weight.shape, sl)] = 1  # type: ignore

            # track gradients here
            masks.append(mask * wt)  # type: ignore

        masks = sum(masks)

        return masks * weight  # type: ignore

    else:
        # for unweighted case, we slice it directly.

        def _do_slice(arr, slice_):
            return arr[_eliminate_list_slice(arr.shape, slice_)]  # type: ignore

        # sometimes, we don't need slice.
        # this saves an op on computational graph, which will hopefully make training faster

        # Use a dummy array to check this. Otherwise it would be too complex.
        dummy_arr = np.zeros(weight.shape, dtype=bool)  # type: ignore
        no_effect = cast(Any, _do_slice(dummy_arr, slice_)).shape == dummy_arr.shape

        if no_effect:
            return weight

        return _do_slice(weight, slice_)


class Slicable(Generic[T]):
    """Wraps the weight so that it can be sliced with a ``multidim_slice``.
    The value within the slice can be instances of :class:`MaybeWeighted`.

    Examples
    --------
    >>> weight = conv2d.weight
    >>> Slicable(weight)[:MaybeWeighted({32: 0.4, 64: 0.6})]
    Tensor of shape (64, 64, 3, 3)
    """

    def __init__(self, weight: T):
        if not isinstance(weight, np.ndarray) and not torch.is_tensor(weight):
            raise TypeError(f'Unsupported weight type: {type(weight)}')
        self.weight = weight

    def __getitem__(self, index: slice_type | multidim_slice | Any) -> T:
        if not isinstance(index, tuple):
            index = (index, )
        index = cast(multidim_slice, index)

        # Get the dict value in index's leafs
        # There can be at most one dict
        leaf_dict: dict[int, float] | None = None
        for maybe_weighted in _iterate_over_multidim_slice(index):
            for d in maybe_weighted.leaf_values():
                if isinstance(d, dict):
                    if leaf_dict is None:
                        leaf_dict = d
                    elif leaf_dict is not d:
                        raise ValueError('There can be at most one distinct dict in leaf values.')

        if leaf_dict is None:
            # in case of simple types with no dict
            res_index = _evaluate_multidim_slice(index)
        else:
            # there is a dict, iterate over dict
            res_index = []
            for val, wt in leaf_dict.items():
                res_index_item = _evaluate_multidim_slice(index, lambda _: val)
                res_index.append((res_index_item, wt))

        return _slice_weight(self.weight, res_index)


class MaybeWeighted:
    """Wrap a value (int or dict with int keys), so that the computation on it can be replayed.

    It builds a binary tree. If ``value`` is not None, it's a leaf node.
    Otherwise, it has left sub-tree and right sub-tree and an operation.

    Only support basic arithmetic operations: ``+``, ``-``, ``*``, ``//``.
    """

    def __init__(self, value: int_or_int_dict | None = None, *,
                 lhs: 'MaybeWeighted' | int | None = None,
                 rhs: 'MaybeWeighted' | int | None = None,
                 operation: Callable[[int_or_int_dict, int_or_int_dict], int_or_int_dict] | None = None):
        if operation is None:
            if not isinstance(value, (int, dict)):
                raise TypeError(f'Unsupported value type: {type(value)}')
        self.value = value
        self.lhs = lhs
        self.rhs = rhs
        self.operation = operation

    def leaf_values(self) -> Iterator[int_or_int_dict]:
        """Iterate over values on leaf nodes."""
        if self.value is not None:
            yield self.value
        else:
            if isinstance(self.lhs, MaybeWeighted):
                yield from self.lhs.leaf_values()
            if isinstance(self.rhs, MaybeWeighted):
                yield from self.rhs.leaf_values()

    def evaluate(self, value_fn: _value_fn_type = None) -> int_or_int_dict:
        """Evaluate the value on root node, after replacing every value on leaf node with ``value_fn``.
        If ``value_fn`` is none, no replacement will happen and the raw value will be used.
        """
        if self.value is not None:
            if value_fn is not None:
                return value_fn(self.value)
            return self.value
        else:
            if isinstance(self.lhs, MaybeWeighted):
                eval_lhs = self.lhs.evaluate(value_fn)
            else:
                eval_lhs = cast(int, self.lhs)
            if isinstance(self.rhs, MaybeWeighted):
                eval_rhs = self.rhs.evaluate(value_fn)
            else:
                eval_rhs = cast(int, self.rhs)
            assert self.operation is not None
            return self.operation(eval_lhs, eval_rhs)

    def __repr__(self):
        if self.value is not None:
            return f'{self.__class__.__name__}({self.value})'
        return f'{self.__class__.__name__}(lhs={self.lhs}, rhs={self.rhs}, op={self.operation})'

    def __add__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.add)

    def __radd__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.add)

    def __sub__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.sub)

    def __rsub__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.sub)

    def __mul__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.mul)

    def __rmul__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.mul)

    def __floordiv__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.floordiv)

    def __rfloordiv__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.floordiv)


def _iterate_over_slice_type(s: slice_type):
    if isinstance(s, list):
        for se in s:
            yield from _iterate_over_slice_type(se)
    else:
        # s must be a "slice" now
        if isinstance(s.start, MaybeWeighted):
            yield s.start
        if isinstance(s.stop, MaybeWeighted):
            yield s.stop
        if isinstance(s.step, MaybeWeighted):
            yield s.step


def _iterate_over_multidim_slice(ms: multidim_slice):
    """Get :class:`MaybeWeighted` instances in ``ms``."""
    for s in ms:
        if s is not None and s is not Ellipsis:
            yield from _iterate_over_slice_type(s)


def _evaluate_slice_type(s: slice_type, value_fn: _value_fn_type = None):
    if isinstance(s, list):
        return [_evaluate_slice_type(se, value_fn) for se in s]
    else:
        return slice(
            s.start.evaluate(value_fn) if isinstance(s.start, MaybeWeighted) else s.start,
            s.stop.evaluate(value_fn) if isinstance(s.stop, MaybeWeighted) else s.stop,
            s.step.evaluate(value_fn) if isinstance(s.step, MaybeWeighted) else s.step
        )


def _evaluate_multidim_slice(ms: multidim_slice, value_fn: _value_fn_type = None):
    """Wraps :meth:`MaybeWeighted.evaluate` to evaluate the whole ``multidim_slice``."""
    res = []
    for s in ms:
        if s is not None and s is not Ellipsis:
            res.append(_evaluate_slice_type(s, value_fn))
        else:
            res.append(s)
    return tuple(res)
from nni.nas.oneshot.pytorch.supermodule._operation_utils import *
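A quick demonstration of the weighted-slice mechanism above, using numpy. The import path follows the shim above; the concrete numbers are only illustrative.

import numpy as np
from nni.nas.oneshot.pytorch.supermodule._operation_utils import Slicable, MaybeWeighted

weight = np.ones(4)
# A distribution over the slice endpoint: stop=2 with probability 0.5, stop=3 with probability 0.5.
sliced = Slicable(weight)[:MaybeWeighted({2: 0.5, 3: 0.5})]
# Each candidate slice is turned into a 0/1 mask, the masks are weighted and summed,
# and the summed mask is multiplied onto the weight:
#   0.5 * [1, 1, 0, 0] + 0.5 * [1, 1, 1, 0] = [1, 1, 0.5, 0]
print(sliced)  # [1.  1.  0.5 0. ]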
nni/retiarii/oneshot/pytorch/supermodule/_singlepathnas.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# pylint: skip-file
# type: ignore
# pylint: disable=wildcard-import,unused-wildcard-import

"""This file is an incomplete implementation of `Single-path NAS <https://arxiv.org/abs/1904.02877>`__.

These are merely some components of the algorithm. The complete support is an undergoing work item.
Keep this file here so that it can be "blamed".
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from nni.retiarii.nn.pytorch import ValueChoice


class DifferentiableSuperConv2d(nn.Conv2d):
    """
    Only ``kernel_size``, ``in_channels`` and ``out_channels`` are supported. Kernel size candidates should be larger or smaller
    than each other on both dimensions. See examples below:

    the following example is not allowed:
        >>> ValueChoice(candidates = [(5, 3), (3, 5)])
            □ ■ ■ ■ □   □ □ □ □ □
            □ ■ ■ ■ □   ■ ■ ■ ■ ■   # candidates are not bigger or smaller on both dimension
            □ ■ ■ ■ □   ■ ■ ■ ■ ■
            □ ■ ■ ■ □   ■ ■ ■ ■ ■
            □ ■ ■ ■ □   □ □ □ □ □

    the following 3 examples are valid:
        >>> ValueChoice(candidates = [5, 3, 1])
            ■ ■ ■ ■ ■   □ □ □ □ □   □ □ □ □ □
            ■ ■ ■ ■ ■   □ ■ ■ ■ □   □ □ □ □ □
            ■ ■ ■ ■ ■   □ ■ ■ ■ □   □ □ ■ □ □
            ■ ■ ■ ■ ■   □ ■ ■ ■ □   □ □ □ □ □
            ■ ■ ■ ■ ■   □ □ □ □ □   □ □ □ □ □
        >>> ValueChoice(candidates = [(5, 7), (3, 5), (1, 3)])
            ■ ■ ■ ■ ■ ■ ■   □ □ □ □ □ □ □   □ □ □ □ □ □ □
            ■ ■ ■ ■ ■ ■ ■   □ ■ ■ ■ ■ ■ □   □ □ □ □ □ □ □
            ■ ■ ■ ■ ■ ■ ■   □ ■ ■ ■ ■ ■ □   □ □ ■ ■ ■ □ □
            ■ ■ ■ ■ ■ ■ ■   □ ■ ■ ■ ■ ■ □   □ □ □ □ □ □ □
            ■ ■ ■ ■ ■ ■ ■   □ □ □ □ □ □ □   □ □ □ □ □ □ □
        >>> # when the difference between any two candidates is not even, the left upper will be picked:
        >>> ValueChoice(candidates = [(5, 5), (4, 4), (3, 3)])
            ■ ■ ■ ■ ■   ■ ■ ■ ■ □   □ □ □ □ □
            ■ ■ ■ ■ ■   ■ ■ ■ ■ □   □ ■ ■ ■ □
            ■ ■ ■ ■ ■   ■ ■ ■ ■ □   □ ■ ■ ■ □
            ■ ■ ■ ■ ■   ■ ■ ■ ■ □   □ ■ ■ ■ □
            ■ ■ ■ ■ ■   □ □ □ □ □   □ □ □ □ □
    """

    def __init__(self, module, name):
        self.label = name
        args = module.trace_kwargs

        # compulsory params
        if isinstance(args['in_channels'], ValueChoice):
            args['in_channels'] = max(args['in_channels'].candidates)

        self.out_channel_candidates = None
        if isinstance(args['out_channels'], ValueChoice):
            self.out_channel_candidates = sorted(args['out_channels'].candidates, reverse=True)
            args['out_channels'] = self.out_channel_candidates[0]

        # kernel_size may be an int or tuple, we turn it into a tuple for simplicity
        self.kernel_size_candidates = None
        if isinstance(args['kernel_size'], ValueChoice):
            # unify kernel size as tuple
            candidates = args['kernel_size'].candidates
            if not isinstance(candidates[0], tuple):
                candidates = [(k, k) for k in candidates]

            # sort kernel size in descending order
            self.kernel_size_candidates = sorted(candidates, key=lambda t: t[0], reverse=True)
            for i in range(0, len(self.kernel_size_candidates) - 1):
                bigger = self.kernel_size_candidates[i]
                smaller = self.kernel_size_candidates[i + 1]
                assert bigger[1] > smaller[1] or (bigger[1] == smaller[1] and bigger[0] > smaller[0]), f'Kernel_size candidates ' \
                    f'should be larger or smaller than each other on both dimensions, but found {bigger} and {smaller}.'
            args['kernel_size'] = self.kernel_size_candidates[0]

        super().__init__(**args)
        self.generate_architecture_params()

    def forward(self, input):
        # Note that there is no need to handle ``in_channels`` here since it is already handled by the ``out_channels`` in the
        # previous module. If we multiplied alpha with respect to ``in_channels`` here again, the alpha would indeed be considered
        # twice, which is not what we expect.
        weight = self.weight

        def sum_weight(input_weight, masks, thresholds, indicator):
            """
            This is to get the weighted sum of weight.

            Parameters
            ----------
            input_weight : Tensor
                the weight to be weighted summed
            masks : list[Tensor]
                weight masks.
            thresholds : list[float]
                thresholds, should have a length of ``len(masks) - 1``
            indicator : Callable[[Tensor, float], float]
                take a tensor and a threshold as input, and output the weight

            Returns
            -------
            weight : Tensor
                weighted sum of ``input_weight``. this is of the same shape as ``input_weight``
            """
            # Note that ``masks`` and ``thresholds`` have different lengths. Their alignment is shown below:
            # self.xxx_candidates = [   c_0  ,   c_1  , ... ,   c_n-2  ,  c_n-1  ]   # descending order
            # self.xxx_mask       = [ mask_0 , mask_1 , ... , mask_n-2 , mask_n-1]
            # self.t_xxx          = [   t_0  ,   t_2  , ... ,   t_n-2  ]
            # So we zip the first n-1 items, and multiply masks[-1] in the end.
            weight = torch.zeros_like(input_weight)
            for mask, t in zip(masks[:-1], thresholds):
                cur_part = input_weight * mask
                alpha = indicator(cur_part, t)
                weight = (weight + cur_part) * alpha
            # we do not consider skip-op here for out_channel/expansion candidates, which means at least the smallest channel
            # candidate is included
            weight += input_weight * masks[-1]
            return weight

        if self.kernel_size_candidates is not None:
            weight = sum_weight(weight, self.kernel_masks, self.t_kernel, self.Lasso_sigmoid)

        if self.out_channel_candidates is not None:
            weight = sum_weight(weight, self.channel_masks, self.t_expansion, self.Lasso_sigmoid)

        output = self._conv_forward(input, weight, self.bias)
        return output

    def parameters(self):
        for _, p in self.named_parameters():
            yield p

    def named_parameters(self):
        for name, p in super().named_parameters():
            if name == 'alpha':
                continue
            yield name, p

    def export(self):
        """
        result = {
            'kernel_size': i,
            'out_channels': j
        }
        which means the best candidate for an argument is the i-th one if candidates are sorted in descending order
        """
        result = {}
        eps = 1e-5
        with torch.no_grad():
            if self.kernel_size_candidates is not None:
                weight = torch.zeros_like(self.weight)
                # ascending order
                for i in range(len(self.kernel_size_candidates) - 2, -1, -1):
                    mask = self.kernel_masks[i]
                    t = self.t_kernel[i]
                    cur_part = self.weight * mask
                    alpha = self.Lasso_sigmoid(cur_part, t)
                    if alpha <= eps:  # takes the smaller one
                        result['kernel_size'] = self.kernel_size_candidates[i + 1]
                        break
                    weight = (weight + cur_part) * alpha

                if 'kernel_size' not in result:
                    result['kernel_size'] = self.kernel_size_candidates[0]
            else:
                weight = self.weight

            if self.out_channel_candidates is not None:
                for i in range(len(self.out_channel_candidates) - 2, -1, -1):
                    mask = self.channel_masks[i]
                    t = self.t_expansion[i]
                    alpha = self.Lasso_sigmoid(weight * mask, t)
                    if alpha <= eps:
                        result['out_channels'] = self.out_channel_candidates[i + 1]

                if 'out_channels' not in result:
                    result['out_channels'] = self.out_channel_candidates[0]

        return result

    @staticmethod
    def Lasso_sigmoid(matrix, t):
        """
        A trick that can make use of both the value of bool(lasso > t) and the gradient of sigmoid(lasso - t)

        Parameters
        ----------
        matrix : Tensor
            the matrix to calculate lasso norm
        t : float
            the threshold
        """
        lasso = torch.norm(matrix) - t
        indicator = (lasso > 0).float()  # torch.sign(lasso)
        with torch.no_grad():
            # indicator = indicator / 2 + .5  # realign indicator from (-1, 1) to (0, 1)
            indicator -= F.sigmoid(lasso)
        indicator += F.sigmoid(lasso)
        return indicator

    def generate_architecture_params(self):
        self.alpha = {}
        if self.kernel_size_candidates is not None:
            # kernel size arch params
            self.t_kernel = nn.Parameter(torch.rand(len(self.kernel_size_candidates) - 1))
            self.alpha['kernel_size'] = self.t_kernel
            # kernel size mask
            self.kernel_masks = []
            for i in range(0, len(self.kernel_size_candidates) - 1):
                big_size = self.kernel_size_candidates[i]
                small_size = self.kernel_size_candidates[i + 1]
                mask = torch.zeros_like(self.weight)
                # if self.weight.shape = (out, in, 7, 7), big_size = (5, 5) and small_size = (3, 3),
                # the mask will be a ring: ones inside big_size, zeros inside small_size, e.g.
                #   0 0 0 0 0 0 0
                #   0 1 1 1 1 1 0
                #   0 1 0 0 0 1 0
                #   0 1 0 0 0 1 0
                #   0 1 0 0 0 1 0
                #   0 1 1 1 1 1 0
                #   0 0 0 0 0 0 0
                mask[:, :, :big_size[0], :big_size[1]] = 1
                mask[:, :, :small_size[0], :small_size[1]] = 0
                self.kernel_masks.append(mask)
            mask = torch.zeros_like(self.weight)
            mask[:, :, :self.kernel_size_candidates[-1][0], :self.kernel_size_candidates[-1][1]] = 1
            self.kernel_masks.append(mask)

        if self.out_channel_candidates is not None:
            # out_channel (or expansion) arch params. we do not consider skip-op here, so we
            # only generate ``len(self.out_channel_candidates) - 1`` thresholds
            self.t_expansion = nn.Parameter(torch.rand(len(self.out_channel_candidates) - 1))
            self.alpha['out_channels'] = self.t_expansion
            self.channel_masks = []
            for i in range(0, len(self.out_channel_candidates) - 1):
                big_channel, small_channel = self.out_channel_candidates[i], self.out_channel_candidates[i + 1]
                mask = torch.zeros_like(self.weight)
                mask[:big_channel] = 1
                mask[:small_channel] = 0
                # if self.weight.shape = (32, in, W, H), big_channel = 16 and small_channel = 8, mask will look like:
                # 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
                self.channel_masks.append(mask)
            mask = torch.zeros_like(self.weight)
            mask[:self.out_channel_candidates[-1]] = 1
            self.channel_masks.append(mask)


class DifferentiableBatchNorm2d(nn.BatchNorm2d):
    def __init__(self, module, name):
        self.label = name
        args = module.trace_kwargs
        if isinstance(args['num_features'], ValueChoice):
            args['num_features'] = max(args['num_features'].candidates)
        super().__init__(**args)

        # no architecture parameter is needed for BatchNorm2d Layers
        self.alpha = nn.Parameter(torch.tensor([]))

    def export(self):
        """
        No need to export ``BatchNorm2d``. Refer to the ``Conv2d`` layer that has the ``ValueChoice`` as ``out_channels``.
        """
        return -1
from nni.nas.oneshot.pytorch.supermodule._singlepathnas import *
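The ``Lasso_sigmoid`` helper above relies on a straight-through trick: the forward value is the hard indicator ``(lasso > t)``, while the gradient is borrowed from ``sigmoid(lasso - t)``. A self-contained sketch of the same trick, with purely illustrative names:

import torch

def straight_through_indicator(matrix: torch.Tensor, t: float) -> torch.Tensor:
    lasso = torch.norm(matrix) - t
    indicator = (lasso > 0).float()   # hard 0/1 value, carries no gradient
    soft = torch.sigmoid(lasso)       # differentiable surrogate
    # value == indicator, gradient == d(sigmoid)/d(matrix)
    return indicator - soft.detach() + soft

x = torch.randn(3, 3, requires_grad=True)
y = straight_through_indicator(x, t=1.0)
y.backward()
print(x.grad is not None)  # True: gradients flow through the sigmoid surrogate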
nni/retiarii/oneshot/pytorch/supermodule/_valuechoice_utils.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# pylint: disable=wildcard-import,unused-wildcard-import

"""Utilities to process the value choice compositions,
in the way that is most convenient to one-shot algorithms."""

from __future__ import annotations

import itertools
from typing import Any, TypeVar, List, cast, Mapping, Sequence, Optional, Iterable

import numpy as np
import torch

from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch.api import ChoiceOf, ValueChoiceX

Choice = Any

T = TypeVar('T')

__all__ = [
    'dedup_inner_choices',
    'evaluate_value_choice_with_dict',
    'traverse_all_options',
    'weighted_sum',
    'evaluate_constant',
]


def dedup_inner_choices(value_choices: list[ValueChoiceX]) -> dict[str, ParameterSpec]:
    """Find all leaf nodes in ``value_choices``,
    save them in the format of ``{label: parameter_spec}``.
    """
    result = {}
    for value_choice in value_choices:
        for choice in value_choice.inner_choices():
            param_spec = ParameterSpec(choice.label, 'choice', choice.candidates, (choice.label, ), True, size=len(choice.candidates))
            if choice.label in result:
                if param_spec != result[choice.label]:
                    raise ValueError('Value choice conflict: same label with different candidates: '
                                     f'{param_spec} vs. {result[choice.label]}')
            else:
                result[choice.label] = param_spec
    return result


def evaluate_value_choice_with_dict(value_choice: ChoiceOf[T], chosen: dict[str, Choice]) -> T:
    """To evaluate a composition of value-choice with a dict,
    with format of ``{label: chosen_value}``.

    The implementation is two-pass. We first get a list of values,
    then feed the values into ``value_choice.evaluate``.
    This can be potentially optimized in terms of speed.

    Examples
    --------
    >>> chosen = {"exp_ratio": 3}
    >>> evaluate_value_choice_with_dict(value_choice_in, chosen)
    48
    >>> evaluate_value_choice_with_dict(value_choice_out, chosen)
    96
    """
    choice_inner_values = []
    for choice in value_choice.inner_choices():
        if choice.label not in chosen:
            raise KeyError(f'{value_choice} depends on a value with key {choice.label}, but not found in {chosen}')
        choice_inner_values.append(chosen[choice.label])
    return value_choice.evaluate(choice_inner_values)


def traverse_all_options(
    value_choice: ChoiceOf[T],
    weights: dict[str, list[float]] | dict[str, np.ndarray] | dict[str, torch.Tensor] | None = None
) -> list[tuple[T, float]] | list[T]:
    """Traverse all possible computation outcome of a value choice.
    If ``weights`` is not None, it will also compute the probability of each possible outcome.

    Parameters
    ----------
    value_choice : ValueChoiceX
        The value choice to traverse.
    weights : Optional[dict[str, list[float]]], default = None
        If there's a prior on leaf nodes, and we intend to know the (joint) prior on results,
        weights can be provided. The key is label, value are list of float indicating probability.
        Normally, they should sum up to 1, but we will not check them in this function.

    Returns
    -------
    list[Union[tuple[Any, float], Any]]
        Results will be sorted and duplicates will be eliminated.
        If weights is provided, the return value will be a list of tuple, with option and its weight.
        Otherwise, it will be a list of options.
    """
    # get a dict of {label: list of tuple of choice and weight}
    leafs: dict[str, list[tuple[T, float]]] = {}
    for label, param_spec in dedup_inner_choices([value_choice]).items():
        if weights is not None:
            if label not in weights:
                raise KeyError(f'{value_choice} depends on a weight with key {label}, but not found in {weights}')
            if len(weights[label]) != param_spec.size:
                raise KeyError(f'Expect weights with {label} to be of length {param_spec.size}, but {len(weights[label])} found')
            leafs[label] = list(zip(param_spec.values, cast(List[float], weights[label])))
        else:
            # create a dummy weight of zero, in case that weights are not provided.
            leafs[label] = list(zip(param_spec.values, itertools.repeat(0., param_spec.size)))

    # result is a dict from an option to its weight
    result: dict[T, float | None] = {}
    labels, values = list(leafs.keys()), list(leafs.values())

    if not labels:
        raise ValueError(f'There expects at least one leaf value choice in {value_choice}, but nothing found')

    # get all combinations
    for prod_value in itertools.product(*values):
        # For example,
        # prod_value = ((3, 0.1), ("cat", 0.3), ({"in": 5}, 0.5))
        # the first dim is chosen value, second dim is probability
        # chosen = {"ks": 3, "animal": "cat", "linear_args": {"in": 5}}
        # chosen_weight = np.prod([0.1, 0.3, 0.5])
        chosen = {label: value[0] for label, value in zip(labels, prod_value)}

        eval_res = evaluate_value_choice_with_dict(value_choice, chosen)

        if weights is None:
            result[eval_res] = None
        else:
            # we can't use reduce or inplace product here,
            # because weight can sometimes be tensors
            chosen_weight = prod_value[0][1]
            for value in prod_value[1:]:
                if chosen_weight is None:
                    chosen_weight = value[1]
                else:
                    chosen_weight = chosen_weight * value[1]

            if eval_res in result:
                result[eval_res] = result[eval_res] + chosen_weight
            else:
                result[eval_res] = chosen_weight

    if weights is None:
        return sorted(result.keys())  # type: ignore
    else:
        return sorted(result.items())  # type: ignore


def evaluate_constant(expr: Any) -> Any:
    """Evaluate a value choice expression to a constant. Raise ValueError if it's not a constant."""
    all_options = traverse_all_options(expr)
    if len(all_options) > 1:
        raise ValueError(f'{expr} is not evaluated to a constant. All possible values are: {all_options}')
    res = all_options[0]
    return res


def weighted_sum(items: list[T], weights: Sequence[float | None] = cast(Sequence[Optional[float]], None)) -> T:
    """Return a weighted sum of items.

    Items can be list of tensors, numpy arrays, or nested lists / dicts.

    If ``weights`` is None, this is simply an unweighted sum.
    """
    if weights is None:
        weights = [None] * len(items)

    assert len(items) == len(weights) > 0
    elem = items[0]
    unsupported_msg = f'Unsupported element type in weighted sum: {type(elem)}. Value is: {elem}'

    if isinstance(elem, str):
        # Need to check this first. Otherwise it goes into sequence and causes infinite recursion.
        raise TypeError(unsupported_msg)

    try:
        if isinstance(elem, (torch.Tensor, np.ndarray, float, int, np.number)):
            if weights[0] is None:
                res = elem
            else:
                res = elem * weights[0]
            for it, weight in zip(items[1:], weights[1:]):
                if type(it) != type(elem):
                    raise TypeError(f'Expect type {type(elem)} but found {type(it)}. Can not be summed')

                if weight is None:
                    res = res + it  # type: ignore
                else:
                    res = res + it * weight  # type: ignore
            return cast(T, res)

        if isinstance(elem, Mapping):
            for item in items:
                if not isinstance(item, Mapping):
                    raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
                if set(item) != set(elem):
                    raise KeyError(f'Expect keys {list(elem)} but found {list(item)}')
            return cast(T, {
                key: weighted_sum(cast(List[dict], [cast(Mapping, d)[key] for d in items]), weights) for key in elem
            })
        if isinstance(elem, Sequence):
            for item in items:
                if not isinstance(item, Sequence):
                    raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
                if len(item) != len(elem):
                    raise ValueError(f'Expect length {len(item)} but found {len(elem)}')
            transposed = cast(Iterable[list], zip(*items))  # type: ignore
            return cast(T, [weighted_sum(column, weights) for column in transposed])
    except (TypeError, ValueError, RuntimeError, KeyError):
        raise ValueError(
            'Error when summing items. Value format / shape does not match. See full traceback for details.' +
            ''.join([
                f'\n{idx}: {_summarize_elem_format(it)}' for idx, it in enumerate(items)
            ])
        )

    # Dealing with all unexpected types.
    raise TypeError(unsupported_msg)


def _summarize_elem_format(elem: Any) -> Any:
    # Get a summary of one elem
    # Helps generate human-readable error messages

    class _repr_object:
        # empty object is only repr
        def __init__(self, representation):
            self.representation = representation

        def __repr__(self):
            return self.representation

    if isinstance(elem, torch.Tensor):
        return _repr_object('torch.Tensor(' + ', '.join(map(str, elem.shape)) + ')')
    if isinstance(elem, np.ndarray):
        return _repr_object('np.array(' + ', '.join(map(str, elem.shape)) + ')')
    if isinstance(elem, Mapping):
        return {key: _summarize_elem_format(value) for key, value in elem.items()}
    if isinstance(elem, Sequence):
        return [_summarize_elem_format(value) for value in elem]

    # fallback to original, for cases like float, int, ...
    return elem
from nni.nas.oneshot.pytorch.supermodule._valuechoice_utils import *
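A quick illustration of ``weighted_sum`` above on plain numbers and nested structures. The import path follows the shim above.

from nni.nas.oneshot.pytorch.supermodule._valuechoice_utils import weighted_sum

print(weighted_sum([1.0, 3.0], [0.25, 0.75]))                # 2.5
print(weighted_sum([{'a': 1.0}, {'a': 3.0}], [0.25, 0.75]))  # {'a': 2.5}, summed per key
print(weighted_sum([[1.0, 0.0], [0.0, 1.0]], [0.5, 0.5]))    # [0.5, 0.5], summed per position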
nni/retiarii/oneshot/pytorch/supermodule/base.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# pylint: disable=wildcard-import,unused-wildcard-import

from __future__ import annotations

from typing import Any

import torch.nn as nn

from nni.common.hpo_utils import ParameterSpec

__all__ = ['BaseSuperNetModule']


class BaseSuperNetModule(nn.Module):
    """
    Mutated module in super-net.
    Usually, the feed-forward of the module itself is undefined.
    It has to be resampled with ``resample()`` so that a specific path is selected.
    (Sometimes, this is not required. For example, differentiable super-net.)

    A super-net module usually corresponds to one sample, but there are two exceptions:

    * A module can have multiple parameter specs. For example, a convolution-2d can sample kernel size and channels at the same time.
    * Multiple modules can share one parameter spec. For example, multiple layer choices with the same label.

    For value choice compositions, the parameter specs are bound to the underlying (original) value choices,
    rather than their compositions.
    """

    def resample(self, memo: dict[str, Any]) -> dict[str, Any]:
        """
        Resample the super-net module.

        Parameters
        ----------
        memo : dict[str, Any]
            Used to ensure the consistency of samples with the same label.

        Returns
        -------
        dict
            Sampled result. If nothing new is sampled, it should return an empty dict.
        """
        raise NotImplementedError()

    def export(self, memo: dict[str, Any]) -> dict[str, Any]:
        """
        Export the final architecture within this module.
        It should have the same keys as ``search_space_spec()``.

        Parameters
        ----------
        memo : dict[str, Any]
            Use memo to avoid the same label being exported multiple times.
        """
        raise NotImplementedError()

    def search_space_spec(self) -> dict[str, ParameterSpec]:
        """
        Space specification (sample points).
        Mapping from spec name to ParameterSpec. The names in choices should be in the same format of export.

        For example: ::

            {"layer1": ParameterSpec(values=["conv", "pool"])}
        """
        raise NotImplementedError()

    @classmethod
    def mutate(cls, module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> \
            'BaseSuperNetModule' | bool | tuple['BaseSuperNetModule', bool]:
        """This is a mutation hook that creates a :class:`BaseSuperNetModule`.
        The method should be implemented in each specific super-net module,
        because they usually have specific rules about what kind of modules to operate on.

        Parameters
        ----------
        module : nn.Module
            The module to be mutated (replaced).
        name : str
            Name of this module. With full prefix. For example, ``module1.block1.conv``.
        memo : dict
            Memo to enable sharing parameters among mutated modules. It should be read and written by
            mutate functions themselves.
        mutate_kwargs : dict
            Algo-related hyper-parameters, and some auxiliary information.

        Returns
        -------
        Union[BaseSuperNetModule, bool, tuple[BaseSuperNetModule, bool]]
            The mutation result, along with an optional boolean flag indicating whether to suppress follow-up mutation hooks.
            See :class:`BaseOneShotLightningModule <nni.retiarii.oneshot.pytorch.base_lightning.BaseOneShotLightningModule>` for details.
        """
        raise NotImplementedError()
from nni.nas.oneshot.pytorch.supermodule.base import *
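To make the interface above concrete, here is a minimal toy subclass that samples between two fixed scaling factors. It is purely illustrative (the name ``ToyPathSampling`` and the candidates are made up); real implementations live in the sampling and differentiable modules that follow.

import random

from nni.common.hpo_utils import ParameterSpec
from nni.nas.oneshot.pytorch.supermodule.base import BaseSuperNetModule

class ToyPathSampling(BaseSuperNetModule):
    def __init__(self, label: str):
        super().__init__()
        self.label = label
        self._sampled = 'small'

    def search_space_spec(self):
        # Same ParameterSpec convention as used throughout this module.
        return {self.label: ParameterSpec(self.label, 'choice', ['small', 'large'], (self.label,), True, size=2)}

    def resample(self, memo):
        if self.label in memo:
            self._sampled = memo[self.label]
            return {}                      # nothing new was sampled
        self._sampled = random.choice(['small', 'large'])
        return {self.label: self._sampled}

    def export(self, memo):
        return {} if self.label in memo else {self.label: self._sampled}

    def forward(self, x):
        return x * (0.5 if self._sampled == 'small' else 2.0)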
nni/retiarii/oneshot/pytorch/supermodule/differentiable.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# pylint: disable=wildcard-import,unused-wildcard-import

from __future__ import annotations

import functools
import logging
import warnings
from typing import Any, Dict, Sequence, List, Tuple, cast

import torch
import torch.nn as nn
import torch.nn.functional as F

from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ChoiceOf, Repeat
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.retiarii.nn.pytorch.cell import preprocess_cell_inputs

from .base import BaseSuperNetModule
from .operation import MixedOperation, MixedOperationSamplingPolicy
from .sampling import PathSamplingCell
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, weighted_sum

_logger = logging.getLogger(__name__)

__all__ = [
    'DifferentiableMixedLayer',
    'DifferentiableMixedInput',
    'DifferentiableMixedRepeat',
    'DifferentiableMixedCell',
    'MixedOpDifferentiablePolicy',
]


class GumbelSoftmax(nn.Softmax):
    """Wrapper of ``F.gumbel_softmax``. dim = -1 by default."""

    dim: int

    def __init__(self, dim: int = -1) -> None:
        super().__init__(dim)
        self.tau = 1
        self.hard = False

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return F.gumbel_softmax(inputs, tau=self.tau, hard=self.hard, dim=self.dim)


class DifferentiableMixedLayer(BaseSuperNetModule):
    """
    Mixed layer, in which fprop is decided by a weighted sum of several layers.
    Proposed in `DARTS: Differentiable Architecture Search <https://arxiv.org/abs/1806.09055>`__.

    The weight ``alpha`` is usually learnable, and optimized on validation dataset.

    Differentiable sampling layer requires all operators returning the same shape for one input,
    as all outputs will be weighted summed to get the final output.

    Parameters
    ----------
    paths : list[tuple[str, nn.Module]]
        Layers to choose from. Each is a tuple of name, and its module.
    alpha : Tensor
        Tensor that stores the "learnable" weights.
    softmax : nn.Module
        Customizable softmax function. Usually ``nn.Softmax(-1)``.
    label : str
        Name of the choice.

    Attributes
    ----------
    op_names : str
        Operator names.
    label : str
        Name of the choice.
    """

    _arch_parameter_names: list[str] = ['_arch_alpha']

    def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__()
        self.op_names = []
        if len(alpha) != len(paths):
            raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({len(paths)}).')
        for name, module in paths:
            self.add_module(name, module)
            self.op_names.append(name)
        assert self.op_names, 'There has to be at least one op to choose from.'
        self.label = label
        self._arch_alpha = alpha
        self._softmax = softmax

    def resample(self, memo):
        """Do nothing. Differentiable layer doesn't need resample."""
        return {}

    def export(self, memo):
        """Choose the operator with the maximum logit."""
        if self.label in memo:
            return {}  # nothing new to export
        return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}

    def search_space_spec(self):
        return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
                                          True, size=len(self.op_names))}

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if isinstance(module, LayerChoice):
            size = len(module)
            if module.label in memo:
                alpha = memo[module.label]
                if len(alpha) != size:
                    raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
            else:
                alpha = nn.Parameter(torch.randn(size) * 1E-3)  # this can be reinitialized later
            softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
            return cls(list(module.named_children()), alpha, softmax, module.label)

    def reduction(self, items: list[Any], weights: list[float]) -> Any:
        """Override this for customized reduction."""
        # Use weighted_sum to handle complex cases where sequential output is not a single tensor
        return weighted_sum(items, weights)

    def forward(self, *args, **kwargs):
        """The forward of mixed layer accepts same arguments as its sub-layer."""
        all_op_results = [getattr(self, op)(*args, **kwargs) for op in self.op_names]
        return self.reduction(all_op_results, self._softmax(self._arch_alpha))

    def parameters(self, *args, **kwargs):
        """Parameters excluding architecture parameters."""
        for _, p in self.named_parameters(*args, **kwargs):
            yield p

    def named_parameters(self, *args, **kwargs):
        """Named parameters excluding architecture parameters."""
        arch = kwargs.pop('arch', False)
        for name, p in super().named_parameters(*args, **kwargs):
            if any(name == par_name for par_name in self._arch_parameter_names):
                if arch:
                    yield name, p
            else:
                if not arch:
                    yield name, p


class DifferentiableMixedInput(BaseSuperNetModule):
    """
    Mixed input. Forward returns a weighted sum of candidates.
    Implementation is very similar to :class:`DifferentiableMixedLayer`.

    Parameters
    ----------
    n_candidates : int
        Expected number of input candidates.
    n_chosen : int
        Expected number of inputs finally chosen.
    alpha : Tensor
        Tensor that stores the "learnable" weights.
    softmax : nn.Module
        Customizable softmax function. Usually ``nn.Softmax(-1)``.
    label : str
        Name of the choice.

    Attributes
    ----------
    label : str
        Name of the choice.
    """

    _arch_parameter_names: list[str] = ['_arch_alpha']

    def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__()
        self.n_candidates = n_candidates
        if len(alpha) != n_candidates:
            raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({n_candidates}).')
        if n_chosen is None:
            warnings.warn('Differentiable architecture search does not support choosing multiple inputs. Assuming one.',
                          RuntimeWarning)
            self.n_chosen = 1
        self.n_chosen = n_chosen
        self.label = label
        self._softmax = softmax

        self._arch_alpha = alpha

    def resample(self, memo):
        """Do nothing. Differentiable layer doesn't need resample."""
        return {}

    def export(self, memo):
        """Choose the operator with the top ``n_chosen`` logits."""
        if self.label in memo:
            return {}  # nothing new to export
        chosen = sorted(torch.argsort(-self._arch_alpha).cpu().numpy().tolist()[:self.n_chosen])
        if len(chosen) == 1:
            chosen = chosen[0]
        return {self.label: chosen}

    def search_space_spec(self):
        return {
            self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
                                      (self.label, ), True, size=self.n_candidates, chosen_size=self.n_chosen)
        }

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if isinstance(module, InputChoice):
            if module.reduction not in ['sum', 'mean']:
                raise ValueError('Only input choice of sum/mean reduction is supported.')
            size = module.n_candidates
            if module.label in memo:
                alpha = memo[module.label]
                if len(alpha) != size:
                    raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
            else:
                alpha = nn.Parameter(torch.randn(size) * 1E-3)  # this can be reinitialized later
            softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
            return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label)

    def reduction(self, items: list[Any], weights: list[float]) -> Any:
        """Override this for customized reduction."""
        # Use weighted_sum to handle complex cases where sequential output is not a single tensor
        return weighted_sum(items, weights)

    def forward(self, inputs):
        """Forward takes a list of input candidates."""
        return self.reduction(inputs, self._softmax(self._arch_alpha))

    def parameters(self, *args, **kwargs):
        """Parameters excluding architecture parameters."""
        for _, p in self.named_parameters(*args, **kwargs):
            yield p

    def named_parameters(self, *args, **kwargs):
        """Named parameters excluding architecture parameters."""
        arch = kwargs.pop('arch', False)
        for name, p in super().named_parameters(*args, **kwargs):
            if any(name == par_name for par_name in self._arch_parameter_names):
                if arch:
                    yield name, p
            else:
                if not arch:
                    yield name, p


class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
    """Implements the differentiable sampling in mixed operation.

    One mixed operation can have multiple value choices in its arguments.
    Thus the ``_arch_alpha`` here is a parameter dict, and ``named_parameters``
    filters out multiple parameters with ``_arch_alpha`` as its prefix.

    When this class is asked for ``forward_argument``, it returns a distribution,
    i.e., a dict from int to float based on its weights.

    All the parameters (``_arch_alpha``, ``parameters()``, ``_softmax``) are
    saved as attributes of ``operation``, rather than ``self``,
    because this class itself is not a ``nn.Module``, and saved parameters here
    won't be optimized.
    """

    _arch_parameter_names: list[str] = ['_arch_alpha']

    def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
        # Sampling arguments. This should have the same keys with `operation.mutable_arguments`
        operation._arch_alpha = nn.ParameterDict()
        for name, spec in operation.search_space_spec().items():
            if name in memo:
                alpha = memo[name]
                if len(alpha) != spec.size:
                    raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
            else:
                alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
            operation._arch_alpha[name] = alpha

        operation.parameters = functools.partial(self.parameters, module=operation)  # bind self
        operation.named_parameters = functools.partial(self.named_parameters, module=operation)

        operation._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))

    @staticmethod
    def parameters(module, *args, **kwargs):
        for _, p in module.named_parameters(*args, **kwargs):
            yield p

    @staticmethod
    def named_parameters(module, *args, **kwargs):
        arch = kwargs.pop('arch', False)
        for name, p in super(module.__class__, module).named_parameters(*args, **kwargs):  # pylint: disable=bad-super-call
            if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
                if arch:
                    yield name, p
            else:
                if not arch:
                    yield name, p

    def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
        """Differentiable. Do nothing in resample."""
        return {}

    def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
        """Export is argmax for each leaf value choice."""
        result = {}
        for name, spec in operation.search_space_spec().items():
            if name in memo:
                continue
            chosen_index = int(torch.argmax(cast(dict, operation._arch_alpha)[name]).item())
            result[name] = spec.values[chosen_index]
        return result

    def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any:
        if name in operation.mutable_arguments:
            weights: dict[str, torch.Tensor] = {
                label: cast(nn.Module, operation._softmax)(alpha)
                for label, alpha in cast(dict, operation._arch_alpha).items()
            }
            return dict(traverse_all_options(operation.mutable_arguments[name], weights=weights))
        return operation.init_arguments[name]


class DifferentiableMixedRepeat(BaseSuperNetModule):
    """
    Implementation of Repeat in a differentiable supernet.
    Result is a weighted sum of possible prefixes, sliced by possible depths.

    If the output is not a single tensor, it will be summed at every independent dimension.
    See :func:`weighted_sum` for details.
    """

    _arch_parameter_names: list[str] = ['_arch_alpha']

    def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int], softmax: nn.Module, memo: dict[str, Any]):
        super().__init__()
        self.blocks = blocks
        self.depth = depth
        self._softmax = softmax
        self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
        self._arch_alpha = nn.ParameterDict()

        for name, spec in self._space_spec.items():
            if name in memo:
                alpha = memo[name]
                if len(alpha) != spec.size:
                    raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
            else:
                alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
            self._arch_alpha[name] = alpha

    def resample(self, memo):
        """Do nothing."""
        return {}

    def export(self, memo):
        """Choose argmax for each leaf value choice."""
        result = {}
        for name, spec in self._space_spec.items():
            if name in memo:
                continue
            chosen_index = int(torch.argmax(self._arch_alpha[name]).item())
            result[name] = spec.values[chosen_index]
        return result

    def search_space_spec(self):
        return self._space_spec

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
            # Only interesting when depth is mutable
            softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
            return cls(cast(List[nn.Module], module.blocks), module.depth_choice, softmax, memo)

    def parameters(self, *args, **kwargs):
        for _, p in self.named_parameters(*args, **kwargs):
            yield p

    def named_parameters(self, *args, **kwargs):
        arch = kwargs.pop('arch', False)
        for name, p in super().named_parameters(*args, **kwargs):
            if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
                if arch:
                    yield name, p
            else:
                if not arch:
                    yield name, p

    def reduction(self, items: list[Any], weights: list[float], depths: list[int]) -> Any:
        """Override this for customized reduction."""
        # Use weighted_sum to handle complex cases where sequential output is not a single tensor
        return weighted_sum(items, weights)

    def forward(self, x):
        weights: dict[str, torch.Tensor] = {
            label: self._softmax(alpha) for label, alpha in self._arch_alpha.items()
        }
        depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth, weights=weights)))

        res: list[torch.Tensor] = []
        weight_list: list[float] = []
        depths: list[int] = []
        for i, block in enumerate(self.blocks, start=1):  # start=1 because depths are 1, 2, 3, 4...
            x = block(x)
            if i in depth_weights:
                weight_list.append(depth_weights[i])
                res.append(x)
                depths.append(i)

        return self.reduction(res, weight_list, depths)


class DifferentiableMixedCell(PathSamplingCell):
    """Implementation of Cell under differentiable context.

    An architecture parameter is created on each edge of the fully-connected graph.
    """

    # TODO: It inherits :class:`PathSamplingCell` to reduce some duplicated code.
    # Possibly need another refactor here.

    def __init__(
        self, op_factory, num_nodes, num_ops_per_node,
        num_predecessors, preprocessor, postprocessor, concat_dim,
        memo, mutate_kwargs, label
    ):
        super().__init__(
            op_factory, num_nodes, num_ops_per_node,
            num_predecessors, preprocessor, postprocessor,
            concat_dim, memo, mutate_kwargs, label
        )
        self._arch_alpha = nn.ParameterDict()
        for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
            for j in range(i):
                edge_label = f'{label}/{i}_{j}'
                op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j]
                if edge_label in memo:
                    alpha = memo[edge_label]
                    if len(alpha) != len(op):
                        raise ValueError(
                            f'Architecture parameter size of same label {edge_label} conflict: '
                            f'{len(alpha)} vs. {len(op)}'
                        )
                else:
                    alpha = nn.Parameter(torch.randn(len(op)) * 1E-3)
                self._arch_alpha[edge_label] = alpha

        self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))

    def resample(self, memo):
        """Differentiable doesn't need to resample."""
        return {}

    def export(self, memo):
        """Tricky export.

        Reference: https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model_search.py#L135
        We don't avoid selecting operations like ``none`` here, because it looks like a different search space.
        """
        exported = {}
        for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
            # Tuple of (weight, input_index, op_name)
            all_weights: list[tuple[float, int, str]] = []
            for j in range(i):
                for k, name in enumerate(self.op_names):
                    all_weights.append((
                        float(self._arch_alpha[f'{self.label}/{i}_{j}'][k].item()),
                        j, name,
                    ))

            all_weights.sort(reverse=True)
            # We first prefer inputs from different input_index.
            # If we have got no other choices, we start to accept duplicates.
            # Therefore we gather first occurrences of distinct input_index to the front.
            first_occurrence_index: list[int] = [
                all_weights.index(                                  # The index of
                    next(filter(lambda t: t[1] == j, all_weights))  # First occurrence of j
                )
                for j in range(i)                                   # For j < i
            ]
            first_occurrence_index.sort()                           # Keep them ordered too.

            all_weights = [all_weights[k] for k in first_occurrence_index] + \
                [w for j, w in enumerate(all_weights) if j not in first_occurrence_index]

            _logger.info('Sorted weights in differentiable cell export (node %d): %s', i, all_weights)

            for k in range(self.num_ops_per_node):
                # all_weights could be too short in case ``num_ops_per_node`` is too large.
                _, j, op_name = all_weights[k % len(all_weights)]
                exported[f'{self.label}/op_{i}_{k}'] = op_name
                exported[f'{self.label}/input_{i}_{k}'] = j

        return exported

    def forward(self, *inputs: list[torch.Tensor] | torch.Tensor) -> tuple[torch.Tensor, ...] | torch.Tensor:
        processed_inputs: list[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
        states: list[torch.Tensor] = self.preprocessor(processed_inputs)
        for i, ops in enumerate(cast(Sequence[Sequence[Dict[str, nn.Module]]], self.ops), start=self.num_predecessors):
            current_state = []

            for j in range(i):  # for every previous tensors
                op_results = torch.stack([op(states[j]) for op in ops[j].values()])
                alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
                edge_sum = torch.sum(
                    op_results * self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']).view(*alpha_shape), 0
                )
                current_state.append(edge_sum)

            states.append(sum(current_state))  # type: ignore

        # Always merge all
        this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
        return self.postprocessor(this_cell, processed_inputs)

    def parameters(self, *args, **kwargs):
        for _, p in self.named_parameters(*args, **kwargs):
            yield p

    def named_parameters(self, *args, **kwargs):
        arch = kwargs.pop('arch', False)
        for name, p in super().named_parameters(*args, **kwargs):
            if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
                if arch:
                    yield name, p
            else:
                if not arch:
                    yield name, p
from nni.nas.oneshot.pytorch.supermodule.differentiable import *
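A self-contained sketch of the weighted mixture that ``DifferentiableMixedLayer`` computes in its forward pass: softmax over the architecture logits, then a weighted sum of the candidate outputs. All names here are illustrative and independent of NNI.

import torch
import torch.nn as nn

candidates = nn.ModuleDict({'linear': nn.Linear(8, 8), 'identity': nn.Identity()})
alpha = nn.Parameter(torch.randn(len(candidates)) * 1e-3)   # architecture logits
softmax = nn.Softmax(-1)

x = torch.randn(4, 8)
weights = softmax(alpha)                                    # one weight per candidate path
outputs = [m(x) for m in candidates.values()]
mixed = sum(w * out for w, out in zip(weights, outputs))    # weighted sum, same shape as each output
print(mixed.shape)  # torch.Size([4, 8])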
nni/retiarii/oneshot/pytorch/supermodule/operation.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Operations that support weight sharing at a fine-grained level,
which is commonly known as super-kernel (as in channel search), or weight entanglement.
"""
# pylint: disable=wildcard-import,unused-wildcard-import
from
__future__
import
annotations
import
inspect
import
itertools
import
warnings
from
typing
import
Any
,
Type
,
TypeVar
,
cast
,
Union
,
Tuple
,
List
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch
import
Tensor
import
nni.retiarii.nn.pytorch
as
retiarii_nn
from
nni.common.hpo_utils
import
ParameterSpec
from
nni.common.serializer
import
is_traceable
from
nni.retiarii.nn.pytorch.api
import
ValueChoiceX
from
.base
import
BaseSuperNetModule
from
._valuechoice_utils
import
traverse_all_options
,
dedup_inner_choices
,
evaluate_constant
from
._operation_utils
import
Slicable
as
_S
,
MaybeWeighted
as
_W
,
int_or_int_dict
,
scalar_or_scalar_dict
T
=
TypeVar
(
'T'
)
__all__
=
[
'MixedOperationSamplingPolicy'
,
'MixedOperation'
,
'MixedLinear'
,
'MixedConv2d'
,
'MixedBatchNorm2d'
,
'MixedLayerNorm'
,
'MixedMultiHeadAttention'
,
'NATIVE_MIXED_OPERATIONS'
,
]
_diff_not_compatible_error
=
'To be compatible with differentiable one-shot strategy, {} in {} must not be ValueChoice.'
class MixedOperationSamplingPolicy:
    """
    Algo-related part for mixed Operation.

    :class:`MixedOperation` delegates its resample and export to this policy (or its subclass),
    so that one Operation can be easily combined with different kinds of sampling.

    One SamplingStrategy corresponds to one mixed operation.
    """

    def __init__(self, operation: 'MixedOperation', memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
        """At init, the sampling policy can prepare basic parameters,
        and store them in operation if they need back propagation.

        This init is called in :meth:`BaseSuperNetModule.mutate`, after the mixed operation is created.
        So similar to :meth:`BaseSuperNetModule.mutate`,
        memo should also be managed (read and written) by the policy itself.
        """
        pass

    def resample(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
        """The handler of :meth:`MixedOperation.resample`."""
        raise NotImplementedError()

    def export(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
        """The handler of :meth:`MixedOperation.export`."""
        raise NotImplementedError()

    def forward_argument(self, operation: 'MixedOperation', name: str) -> Any:
        """Computing the argument with ``name`` used in operation's forward.
        Usually a value, or a distribution of value.
        """
        raise NotImplementedError()
class MixedOperation(BaseSuperNetModule):
    """This is the base class for all mixed operations.
    It's what you should inherit to support a new operation with ValueChoice.

    It contains commonly used utilities that will ease the effort to write customized mixed operations,
    i.e., operations with ValueChoice in their arguments.
    To customize, please write your own mixed operation, and add the hook into ``mutation_hooks`` parameter when using the strategy.

    By design, for a mixed operation to work in a specific algorithm, at least two classes are needed.

    1. One class needs to inherit this class, to control operation-related behavior,
       such as how to initialize the operation such that the sampled operation can be its sub-operation.
    2. The other one needs to inherit :class:`MixedOperationSamplingPolicy`,
       which controls algo-related behavior, such as sampling.

    The two classes are linked with the ``sampling_policy`` attribute in :class:`MixedOperation`,
    whose type is set via ``mixed_op_sampling`` in ``mutate_kwargs`` when
    :meth:`MixedOperation.mutate` is called.

    With this design, one mixed operation (e.g., MixedConv2d) can work in multiple algorithms
    (e.g., both DARTS and ENAS), saving the engineering effort to rewrite all operations for
    each specific algo.

    This class should also define a ``bound_type``, to control the matching type in mutate,
    and an ``argument_list``, to control which arguments can be dynamically used in ``forward``.
    This list will also be used in mutate for sanity check.
    """

    bound_type: Type[nn.Module]         # defined in subclass
    argument_list: list[str]            # defined in subclass

    sampling_policy: MixedOperationSamplingPolicy

    def super_init_argument(self, name: str, value_choice: ValueChoiceX) -> Any:
        """Get the initialization argument when constructing super-kernel, i.e., calling ``super().__init__()``.
        This is often related to specific operator, rather than algo.

        For example::

            def super_init_argument(self, name, value_choice):
                return max(value_choice.candidates)
        """
        raise NotImplementedError()

    def __post_init__(self) -> None:
        """Can be used to validate, or to do extra processing after calling ``__init__``."""
        pass

    def forward_with_args(self, *args, **kwargs):
        """To control real fprop. The accepted arguments are ``argument_list``,
        appended by forward arguments in the ``bound_type``."""
        raise NotImplementedError()

    def __init__(self, module_kwargs: dict[str, Any]) -> None:
        # Concerned arguments
        self.mutable_arguments: dict[str, ValueChoiceX] = {}

        # Useful when retrieving arguments without ValueChoice
        self.init_arguments: dict[str, Any] = {**module_kwargs}
        self._fill_missing_init_arguments()  # get init default

        super_init_kwargs = {}

        for key, value in module_kwargs.items():
            if isinstance(value, ValueChoiceX):
                if key not in self.argument_list:
                    raise TypeError(f'Unsupported value choice on argument of {self.bound_type}: {key}')
                super_init_kwargs[key] = self.super_init_argument(key, value)
                self.mutable_arguments[key] = value
            else:
                super_init_kwargs[key] = value

        # get all inner leaf value choices
        self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices(list(self.mutable_arguments.values()))

        super().__init__(**super_init_kwargs)

        self.__post_init__()

    def resample(self, memo):
        """Delegates to :meth:`MixedOperationSamplingPolicy.resample`."""
        return self.sampling_policy.resample(self, memo)

    def export(self, memo):
        """Delegates to :meth:`MixedOperationSamplingPolicy.export`."""
        return self.sampling_policy.export(self, memo)

    def search_space_spec(self):
        return self._space_spec

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        """Find value choice in module's arguments and replace the whole module"""
        has_valuechoice = False
        if isinstance(module, cls.bound_type) and is_traceable(module):
            for arg in itertools.chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()):
                if isinstance(arg, ValueChoiceX):
                    has_valuechoice = True

        if has_valuechoice:
            if module.trace_args:
                raise ValueError('ValueChoice on class arguments cannot appear together with ``trace_args``. '
                                 'Please enable ``kw_only`` on nni.trace.')

            # save type and kwargs
            mixed_op = cls(cast(dict, module.trace_kwargs))

            if 'mixed_op_sampling' not in mutate_kwargs:
                raise ValueError("Need a sampling policy for mixed op, but it's not found in `mutate_kwargs`.")
            policy_cls: Type[MixedOperationSamplingPolicy] = mutate_kwargs['mixed_op_sampling']
            # initialize policy class
            # this is put in mutate because we need to access memo
            mixed_op.sampling_policy = policy_cls(mixed_op, memo, mutate_kwargs)

            return mixed_op

    def forward_argument(self, name: str) -> Any:
        """Get the argument used in forward.
        This is often related to algo. We redirect this to sampling policy.
        """
        return self.sampling_policy.forward_argument(self, name)

    def forward(self, *args, **kwargs):
        """First get sampled arguments, then forward with the sampled arguments (by calling ``forward_with_args``)."""
        sampled_args = [self.forward_argument(name) for name in self.argument_list]
        return self.forward_with_args(*sampled_args, *args, **kwargs)

    def _fill_missing_init_arguments(self) -> None:
        """Set the unspecified init arguments in ``self.init_arguments``.
        For example, in the case of Conv2d, when user didn't specify argument ``stride``,
        this method adds ``stride = 1`` in ``self.init_arguments``.

        This is implemented by inspecting the init signature of ``bound_type``.
        Arguments in complex cases like ``__new__`` or in super-class are not supported.
        """

        def unwrap(cls):
            if not hasattr(cls, '__wrapped__'):
                return cls
            return unwrap(cls.__wrapped__)

        for param in inspect.signature(unwrap(self.bound_type).__init__).parameters.values():
            if param.default is not param.empty and param.name not in self.init_arguments:
                self.init_arguments[param.name] = param.default
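# Illustrative sketch, not part of this diff: what a new operation-side class could
# look like, following the two-class design described in the docstring above.
# ``MixedConv1d`` is hypothetical; it only handles channel arguments and reuses the
# helpers already defined in this file (_S, _W, traverse_all_options).
class MixedConv1d(MixedOperation, nn.Conv1d):
    bound_type = retiarii_nn.Conv1d                    # module type matched by mutate()
    argument_list = ['in_channels', 'out_channels']

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        # construct the largest possible kernel, then slice a prefix at forward time
        return max(traverse_all_options(value_choice))

    def forward_with_args(self, in_channels, out_channels, inputs):
        weight = _S(self.weight)[:_W(out_channels)]
        weight = _S(weight)[:, :_W(in_channels)]
        bias = _S(self.bias)[:_W(out_channels)] if self.bias is not None else None
        return F.conv1d(inputs, weight, bias, self.stride, self.padding, self.dilation, self.groups)
# To take effect, such a class would be paired with a sampling policy and registered
# through the strategy's ``mutation_hooks``, as the docstring above explains.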
class MixedLinear(MixedOperation, nn.Linear):
    """Mixed linear operation.

    Supported arguments are:

    - ``in_features``
    - ``out_features``

    Prefix of weight and bias will be sliced.
    """

    bound_type = retiarii_nn.Linear
    argument_list = ['in_features', 'out_features']

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        return max(traverse_all_options(value_choice))

    def forward_with_args(self,
                          in_features: int_or_int_dict,
                          out_features: int_or_int_dict,
                          inputs: torch.Tensor) -> torch.Tensor:

        in_features_ = _W(in_features)
        out_features_ = _W(out_features)

        weight = _S(self.weight)[:out_features_]
        weight = _S(weight)[:, :in_features_]
        if self.bias is None:
            bias = self.bias
        else:
            bias = _S(self.bias)[:out_features_]

        return F.linear(inputs, weight, bias)
_int_or_tuple = Union[int, Tuple[int, int]]


class MixedConv2d(MixedOperation, nn.Conv2d):
    """Mixed conv2d op.

    Supported arguments are:

    - ``in_channels``
    - ``out_channels``
    - ``groups``
    - ``stride`` (only supported in path sampling)
    - ``kernel_size``
    - ``padding``
    - ``dilation`` (only supported in path sampling)

    ``padding`` will be the "max" padding in differentiable mode.

    Mutable ``groups`` is NOT supported in most cases of differentiable mode.
    However, we do support one special case when the group number is proportional to ``in_channels`` and ``out_channels``.
    This is often the case of depth-wise convolutions.

    For channels, prefix will be sliced.
    For kernels, we take the small kernel from the center and round it to floor (left top). For example ::

        max_kernel = 5*5, sampled_kernel = 3*3, then we take [1: 4]
        max_kernel = 5*5, sampled_kernel = 2*2, then we take [1: 3]
        □ □ □ □ □   □ □ □ □ □
        □ ■ ■ ■ □   □ ■ ■ □ □
        □ ■ ■ ■ □   □ ■ ■ □ □
        □ ■ ■ ■ □   □ □ □ □ □
        □ □ □ □ □   □ □ □ □ □
    """

    bound_type = retiarii_nn.Conv2d
    argument_list = [
        'in_channels', 'out_channels', 'kernel_size', 'stride', 'padding', 'dilation', 'groups'
    ]

    @staticmethod
    def _to_tuple(value: scalar_or_scalar_dict[Any]) -> tuple[Any, Any]:
        if not isinstance(value, tuple):
            return (value, value)
        return value

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        if name not in ['in_channels', 'out_channels', 'groups', 'stride', 'kernel_size', 'padding', 'dilation']:
            raise NotImplementedError(f'Unsupported value choice on argument: {name}')

        if name in ['kernel_size', 'padding']:
            all_sizes = set(traverse_all_options(value_choice))
            if any(isinstance(sz, tuple) for sz in all_sizes):
                # maximum kernel should be calculated on every dimension
                return (
                    max(self._to_tuple(sz)[0] for sz in all_sizes),
                    max(self._to_tuple(sz)[1] for sz in all_sizes)
                )
            else:
                return max(all_sizes)

        elif name == 'groups':
            if 'in_channels' in self.mutable_arguments:
                # If the ratio is constant, we don't need to try the maximum groups.
                try:
                    constant = evaluate_constant(self.mutable_arguments['in_channels'] / value_choice)
                    return max(cast(List[float], traverse_all_options(value_choice))) // int(constant)
                except ValueError:
                    warnings.warn(
                        'Both input channels and groups are ValueChoice in a convolution, and their relative ratio is not a constant. '
                        'This can be problematic for most one-shot algorithms. Please check whether this is your intention.',
                        RuntimeWarning
                    )

            # minimum groups, maximum kernel
            return min(traverse_all_options(value_choice))

        else:
            return max(traverse_all_options(value_choice))

    def forward_with_args(self,
                          in_channels: int_or_int_dict,
                          out_channels: int_or_int_dict,
                          kernel_size: scalar_or_scalar_dict[_int_or_tuple],
                          stride: _int_or_tuple,
                          padding: scalar_or_scalar_dict[_int_or_tuple],
                          dilation: int,
                          groups: int_or_int_dict,
                          inputs: torch.Tensor) -> torch.Tensor:

        if any(isinstance(arg, dict) for arg in [stride, dilation]):
            raise ValueError(_diff_not_compatible_error.format('stride, dilation', 'Conv2d'))

        in_channels_ = _W(in_channels)
        out_channels_ = _W(out_channels)

        # slice prefix
        # For groups > 1, we use groups to slice input weights
        weight = _S(self.weight)[:out_channels_]

        if not isinstance(groups, dict):
            weight = _S(weight)[:, :in_channels_ // groups]
        else:
            assert 'groups' in self.mutable_arguments
            err_message = 'For differentiable one-shot strategy, when groups is a ValueChoice, ' \
                'in_channels and out_channels should also be a ValueChoice. ' \
                'Also, the ratios of in_channels divided by groups, and out_channels divided by groups ' \
                'should be constants.'
            if 'in_channels' not in self.mutable_arguments or 'out_channels' not in self.mutable_arguments:
                raise ValueError(err_message)
            try:
                in_channels_per_group = evaluate_constant(self.mutable_arguments['in_channels'] / self.mutable_arguments['groups'])
            except ValueError:
                raise ValueError(err_message)
            if in_channels_per_group != int(in_channels_per_group):
                raise ValueError(f'Input channels per group is found to be a non-integer: {in_channels_per_group}')
            if inputs.size(1) % in_channels_per_group != 0:
                raise RuntimeError(
                    f'Input channels must be divisible by in_channels_per_group, but the input shape is {inputs.size()}, '
                    f'while in_channels_per_group = {in_channels_per_group}'
                )

            # Compute sliced weights and groups (as an integer)
            weight = _S(weight)[:, :int(in_channels_per_group)]
            groups = inputs.size(1) // int(in_channels_per_group)

        # slice center
        if isinstance(kernel_size, dict):
            # If kernel size is a dict, ignore choices in padding.
            if isinstance(self.padding, str):
                raise ValueError(f'Use "{self.padding}" in padding is not supported.')
            padding = self.padding  # max padding, must be a tuple

        kernel_a, kernel_b = self._to_tuple(kernel_size)
        kernel_a_, kernel_b_ = _W(kernel_a), _W(kernel_b)
        max_kernel_a, max_kernel_b = self.kernel_size  # self.kernel_size must be a tuple
        kernel_a_left, kernel_b_top = (max_kernel_a - kernel_a_) // 2, (max_kernel_b - kernel_b_) // 2
        weight = _S(weight)[:, :, kernel_a_left:kernel_a_left + kernel_a_, kernel_b_top:kernel_b_top + kernel_b_]

        bias = _S(self.bias)[:out_channels_] if self.bias is not None else None

        # The rest parameters only need to be converted to tuple
        stride_ = self._to_tuple(stride)
        dilation_ = self._to_tuple(dilation)

        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(inputs, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, bias, stride_, (0, 0), dilation_, groups)
        return F.conv2d(inputs, weight, bias, stride_, cast('int | tuple', padding), dilation_, groups)
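# Standalone sketch of the center-slicing arithmetic described in the MixedConv2d
# docstring above (illustrative only, not executed anywhere in this module).
import torch

max_k, k = 5, 3
weight = torch.arange(max_k * max_k, dtype=torch.float32).view(1, 1, max_k, max_k)
left = (max_k - k) // 2                          # = 1, i.e. rows/cols 1..3 are kept
sub = weight[:, :, left:left + k, left:left + k]
assert sub.shape[-2:] == (k, k)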
class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
    """
    Mixed BatchNorm2d operation.

    Supported arguments are:

    - ``num_features``
    - ``eps`` (only supported in path sampling)
    - ``momentum`` (only supported in path sampling)

    For path-sampling, prefix of ``weight``, ``bias``, ``running_mean`` and ``running_var``
    are sliced. For weighted cases, the maximum ``num_features`` is used directly.

    Momentum is required to be float.
    PyTorch BatchNorm supports a case where momentum can be none, which is not supported here.
    """

    bound_type = retiarii_nn.BatchNorm2d
    argument_list = ['num_features', 'eps', 'momentum']

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        return max(traverse_all_options(value_choice))

    def forward_with_args(self,
                          num_features: int_or_int_dict,
                          eps: float,
                          momentum: float,
                          inputs: torch.Tensor) -> torch.Tensor:

        if any(isinstance(arg, dict) for arg in [eps, momentum]):
            raise ValueError(_diff_not_compatible_error.format('eps and momentum', 'BatchNorm2d'))

        if isinstance(num_features, dict):
            num_features = self.num_features

        weight, bias = self.weight, self.bias
        running_mean, running_var = self.running_mean, self.running_var

        if num_features < self.num_features:
            weight = weight[:num_features]
            bias = bias[:num_features]
            if running_mean is not None:
                running_mean = running_mean[:num_features]
            if running_var is not None:
                running_var = running_var[:num_features]

        if self.training:
            bn_training = True
        else:
            bn_training = (running_mean is None) and (running_var is None)

        return F.batch_norm(
            inputs,
            # If buffers are not to be tracked, ensure that they won't be updated
            running_mean if not self.training or self.track_running_stats else None,
            running_var if not self.training or self.track_running_stats else None,
            weight,
            bias,
            bn_training,
            momentum,  # originally exponential_average_factor in pytorch code
            eps,
        )
class MixedLayerNorm(MixedOperation, nn.LayerNorm):
    """
    Mixed LayerNorm operation.

    Supported arguments are:

    - ``normalized_shape``
    - ``eps`` (only supported in path sampling)

    For path-sampling, prefix of ``weight`` and ``bias`` are sliced.
    For weighted cases, the maximum ``normalized_shape`` is used directly.

    eps is required to be float.
    """

    bound_type = retiarii_nn.LayerNorm
    argument_list = ['normalized_shape', 'eps']

    @staticmethod
    def _to_tuple(value: scalar_or_scalar_dict[Any]) -> tuple[Any, Any]:
        if not isinstance(value, tuple):
            return (value, value)
        return value

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        if name not in ['normalized_shape']:
            raise NotImplementedError(f'Unsupported value choice on argument: {name}')
        all_sizes = set(traverse_all_options(value_choice))
        if any(isinstance(sz, (tuple, list)) for sz in all_sizes):
            # transpose
            all_sizes = list(zip(*all_sizes))
            # maximum dim should be calculated on every dimension
            return (max(self._to_tuple(sz)) for sz in all_sizes)
        else:
            return max(all_sizes)

    def forward_with_args(self,
                          normalized_shape,
                          eps: float,
                          inputs: torch.Tensor) -> torch.Tensor:

        if any(isinstance(arg, dict) for arg in [eps]):
            raise ValueError(_diff_not_compatible_error.format('eps', 'LayerNorm'))

        if isinstance(normalized_shape, dict):
            normalized_shape = self.normalized_shape

        # make it as tuple
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape, )
        if isinstance(self.normalized_shape, int):
            normalized_shape = (self.normalized_shape, )

        # slice all the normalized shape
        indices = [slice(0, min(i, j)) for i, j in zip(normalized_shape, self.normalized_shape)]

        # remove _S(*)
        weight = self.weight[indices] if self.weight is not None else None
        bias = self.bias[indices] if self.bias is not None else None

        return F.layer_norm(inputs, normalized_shape, weight, bias, eps)
class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
    """
    Mixed multi-head attention.

    Supported arguments are:

    - ``embed_dim``
    - ``num_heads`` (only supported in path sampling)
    - ``kdim``
    - ``vdim``
    - ``dropout`` (only supported in path sampling)

    At init, it constructs the largest possible Q, K, V dimension.
    At forward, it slices the prefix to weight matrices according to the sampled value.
    For ``in_proj_bias`` and ``in_proj_weight``, three parts will be sliced and concatenated together:
    ``[0, embed_dim)``, ``[max_embed_dim, max_embed_dim + embed_dim)``,
    ``[max_embed_dim * 2, max_embed_dim * 2 + embed_dim)``.

    Warnings
    ----------
    All candidates of ``embed_dim`` should be divisible by all candidates of ``num_heads``.
    """

    bound_type = retiarii_nn.MultiheadAttention
    argument_list = ['embed_dim', 'num_heads', 'kdim', 'vdim', 'dropout']

    def __post_init__(self):
        # sometimes super-class believes qkv have the same embed_dim.
        # but actually they do not, because we can have dynamic (mutable) kdim/vdim.
        _qkv_same_embed_dim = True
        for dimension in ['kdim', 'vdim']:
            if self.init_arguments[dimension] is None:
                # must follow embed_dim in this case
                continue
            if getattr(self, dimension) == self.embed_dim and \
                    (dimension in self.mutable_arguments or 'embed_dim' in self.mutable_arguments):
                _qkv_same_embed_dim = False

        if self._qkv_same_embed_dim and not _qkv_same_embed_dim:
            self._qkv_same_embed_dim = _qkv_same_embed_dim

            # adding back missing parameters
            # factory_kwargs could be empty for legacy pytorch versions
            factory_kwargs = {}
            if 'device' in self.init_arguments:
                factory_kwargs['device'] = self.init_arguments['device']
            if 'dtype' in self.init_arguments:
                factory_kwargs['dtype'] = self.init_arguments['dtype']
            self.q_proj_weight = nn.Parameter(torch.empty((self.embed_dim, self.embed_dim), **factory_kwargs))
            self.k_proj_weight = nn.Parameter(torch.empty((self.embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = nn.Parameter(torch.empty((self.embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)

            # reset parameters
            nn.init.xavier_uniform_(self.q_proj_weight)
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        return max(traverse_all_options(value_choice))

    def _to_proj_slice(self, embed_dim: _W) -> list[slice]:
        # slice three parts, corresponding to q, k, v respectively
        return [
            slice(embed_dim),
            slice(self.embed_dim, self.embed_dim + embed_dim),
            slice(self.embed_dim * 2, self.embed_dim * 2 + embed_dim)
        ]

    def forward_with_args(
        self,
        embed_dim: int_or_int_dict, num_heads: int,
        kdim: int_or_int_dict | None, vdim: int_or_int_dict | None,
        dropout: float,
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
        key_padding_mask: torch.Tensor | None = None,
        need_weights: bool = True, attn_mask: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:

        if any(isinstance(arg, dict) for arg in [num_heads, dropout]):
            raise ValueError(_diff_not_compatible_error.format('num_heads and dropout', 'MultiHeadAttention'))

        # by default, kdim, vdim can be none
        if kdim is None:
            kdim = embed_dim
        if vdim is None:
            vdim = embed_dim

        qkv_same_embed_dim = kdim == embed_dim and vdim == embed_dim

        if getattr(self, 'batch_first', False):
            # for backward compatibility: v1.7 doesn't have batch_first
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if isinstance(embed_dim, dict):
            used_embed_dim = self.embed_dim
        else:
            used_embed_dim = embed_dim

        embed_dim_ = _W(embed_dim)

        # in projection weights & biases has q, k, v weights concatenated together
        in_proj_bias: Tensor | None = None
        in_proj_weight: Tensor | None = None
        if self.in_proj_bias is not None:
            in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim_)]
        if self.in_proj_weight is not None:
            in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim_), :embed_dim_]

        bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim_] if self.bias_k is not None else None
        bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim_] if self.bias_v is not None else None
        out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim_, :embed_dim_]
        out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim_] if self.out_proj.bias is not None else None

        if not qkv_same_embed_dim:
            q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim_, :embed_dim_]
            k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim_]
            k_proj = _S(k_proj)[:, :_W(kdim)]
            v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim_]
            v_proj = _S(v_proj)[:, :_W(vdim)]

            # The rest part is basically same as pytorch
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, used_embed_dim, num_heads,
                cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
                bias_k, bias_v, self.add_zero_attn,
                dropout, out_proj_weight, cast(Tensor, out_proj_bias),
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=q_proj, k_proj_weight=k_proj, v_proj_weight=v_proj)
        else:
            # Cast tensor here because of a bug in pytorch stub
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, used_embed_dim, num_heads,
                cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
                bias_k, bias_v, self.add_zero_attn,
                dropout, out_proj_weight, cast(Tensor, out_proj_bias),
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)

        if getattr(self, 'batch_first', False):  # backward compatibility
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
NATIVE_MIXED_OPERATIONS: list[Type[MixedOperation]] = [
    MixedLinear,
    MixedConv2d,
    MixedBatchNorm2d,
    MixedLayerNorm,
    MixedMultiHeadAttention,
]

# For the supported operations to be properly rendered in documentation
NATIVE_SUPPORTED_OP_NAMES: list[str] = [op.bound_type.__name__ for op in NATIVE_MIXED_OPERATIONS]
from nni.nas.oneshot.pytorch.supermodule.operation import *
nni/retiarii/oneshot/pytorch/supermodule/proxyless.py View file @ a0fd0036

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Implementation of ProxylessNAS: a hybrid approach between differentiable and sampling.

The support remains limited. Known limitations include:

- No support for multiple arguments in forward.
- No support for mixed-operation (value choice).
- The code contains duplicates. Needs refactor.
"""

# pylint: disable=wildcard-import,unused-wildcard-import

from __future__ import annotations

from typing import cast

import torch
import torch.nn as nn

from .differentiable import DifferentiableMixedLayer, DifferentiableMixedInput

__all__ = ['ProxylessMixedLayer', 'ProxylessMixedInput']
class _ArchGradientFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, binary_gates, run_func, backward_func):
        ctx.run_func = run_func
        ctx.backward_func = backward_func

        detached_x = x.detach()
        detached_x.requires_grad = x.requires_grad
        with torch.enable_grad():
            output = run_func(detached_x)
        ctx.save_for_backward(detached_x, output)
        return output.data

    @staticmethod
    def backward(ctx, grad_output):
        detached_x, output = ctx.saved_tensors

        grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
        # compute gradients w.r.t. binary_gates
        binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)

        return grad_x[0], binary_grads, None, None
class ProxylessMixedLayer(DifferentiableMixedLayer):
    """Proxyless version of differentiable mixed layer.
    It resamples a single path every time, rather than going through the softmax.
    """

    _arch_parameter_names = ['_arch_alpha', '_binary_gates']

    def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__(paths, alpha, softmax, label)
        self._binary_gates = nn.Parameter(torch.randn(len(paths)) * 1E-3)

        # like sampling-based methods, it has a ``_sampled``.
        self._sampled: str | None = None
        self._sample_idx: int | None = None

    def forward(self, *args, **kwargs):
        def run_function(ops, active_id, **kwargs):
            def forward(_x):
                return ops[active_id](_x, **kwargs)
            return forward

        def backward_function(ops, active_id, binary_gates, **kwargs):
            def backward(_x, _output, grad_output):
                binary_grads = torch.zeros_like(binary_gates.data)
                with torch.no_grad():
                    for k in range(len(ops)):
                        if k != active_id:
                            out_k = ops[k](_x.data, **kwargs)
                        else:
                            out_k = _output.data
                        grad_k = torch.sum(out_k * grad_output)
                        binary_grads[k] = grad_k
                return binary_grads
            return backward

        assert len(args) == 1, 'ProxylessMixedLayer only supports exactly one input argument.'
        x = args[0]

        assert self._sampled is not None, 'Need to call resample() before running fprop.'
        list_ops = [getattr(self, op) for op in self.op_names]

        return _ArchGradientFunction.apply(
            x, self._binary_gates, run_function(list_ops, self._sample_idx, **kwargs),
            backward_function(list_ops, self._sample_idx, self._binary_gates, **kwargs)
        )

    def resample(self, memo):
        """Sample one path based on alpha if label is not found in memo."""
        if self.label in memo:
            self._sampled = memo[self.label]
            self._sample_idx = self.op_names.index(self._sampled)
        else:
            probs = self._softmax(self._arch_alpha)
            self._sample_idx = int(torch.multinomial(probs, 1)[0].item())
            self._sampled = self.op_names[self._sample_idx]

        # set binary gates
        with torch.no_grad():
            self._binary_gates.zero_()
            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
            self._binary_gates.data[self._sample_idx] = 1.0

        return {self.label: self._sampled}

    def export(self, memo):
        """Choose the argmax if label isn't found in memo."""
        if self.label in memo:
            return {}  # nothing new to export
        return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}

    def finalize_grad(self):
        binary_grads = self._binary_gates.grad
        assert binary_grads is not None
        with torch.no_grad():
            if self._arch_alpha.grad is None:
                self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
            probs = self._softmax(self._arch_alpha)
            for i in range(len(self._arch_alpha)):
                for j in range(len(self._arch_alpha)):
                    self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
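# Rough sketch of one architecture-update step driven by the layer above
# (illustrative only; in practice the one-shot strategy wires up this loop, and
# ``layer``, ``batch``, ``loss_fn`` and ``arch_optimizer`` are assumed to exist).
layer.resample({})            # sample one path and set the binary gates
loss = loss_fn(layer(batch))  # forward goes through the sampled path only
loss.backward()               # _ArchGradientFunction fills _binary_gates.grad
layer.finalize_grad()         # convert gate gradients into gradients on _arch_alpha
arch_optimizer.step()
arch_optimizer.zero_grad()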
class ProxylessMixedInput(DifferentiableMixedInput):
    """Proxyless version of differentiable input choice.
    See :class:`ProxylessLayerChoice` for implementation details.
    """

    _arch_parameter_names = ['_arch_alpha', '_binary_gates']

    def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__(n_candidates, n_chosen, alpha, softmax, label)
        self._binary_gates = nn.Parameter(torch.randn(n_candidates) * 1E-3)
        self._sampled: int | None = None

    def forward(self, inputs):
        def run_function(active_sample):
            return lambda x: x[active_sample]

        def backward_function(binary_gates):
            def backward(_x, _output, grad_output):
                binary_grads = torch.zeros_like(binary_gates.data)
                with torch.no_grad():
                    for k in range(self.n_candidates):
                        out_k = _x[k].data
                        grad_k = torch.sum(out_k * grad_output)
                        binary_grads[k] = grad_k
                return binary_grads
            return backward

        inputs = torch.stack(inputs, 0)
        assert self._sampled is not None, 'Need to call resample() before running fprop.'

        return _ArchGradientFunction.apply(
            inputs, self._binary_gates, run_function(self._sampled),
            backward_function(self._binary_gates)
        )

    def resample(self, memo):
        """Sample one path based on alpha if label is not found in memo."""
        if self.label in memo:
            self._sampled = memo[self.label]
        else:
            probs = self._softmax(self._arch_alpha)
            sample = torch.multinomial(probs, 1)[0].item()
            self._sampled = int(sample)

        # set binary gates
        with torch.no_grad():
            self._binary_gates.zero_()
            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
            self._binary_gates.data[cast(int, self._sampled)] = 1.0

        return {self.label: self._sampled}

    def export(self, memo):
        """Choose the argmax if label isn't found in memo."""
        if self.label in memo:
            return {}  # nothing new to export
        return {self.label: torch.argmax(self._arch_alpha).item()}

    def finalize_grad(self):
        binary_grads = self._binary_gates.grad
        assert binary_grads is not None
        with torch.no_grad():
            if self._arch_alpha.grad is None:
                self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
            probs = self._softmax(self._arch_alpha)
            for i in range(self.n_candidates):
                for j in range(self.n_candidates):
                    self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
from nni.nas.oneshot.pytorch.supermodule.proxyless import *
nni/retiarii/oneshot/pytorch/supermodule/sampling.py View file @ a0fd0036

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

# pylint: disable=wildcard-import,unused-wildcard-import

import copy
import random
from typing import Any, List, Dict, Sequence, cast

import torch
import torch.nn as nn

from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, Repeat, ChoiceOf, Cell
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.retiarii.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs

from .base import BaseSuperNetModule
from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices, weighted_sum
from .operation import MixedOperationSamplingPolicy, MixedOperation

__all__ = [
    'PathSamplingLayer', 'PathSamplingInput',
    'PathSamplingRepeat', 'PathSamplingCell',
    'MixedOpPathSamplingPolicy'
]
class PathSamplingLayer(BaseSuperNetModule):
    """
    Mixed layer, in which fprop is decided by exactly one inner layer or sum of multiple (sampled) layers.

    If multiple modules are selected, the result will be summed and returned.

    Attributes
    ----------
    _sampled : int or list of str
        Sampled module indices.
    label : str
        Name of the choice.
    """

    def __init__(self, paths: list[tuple[str, nn.Module]], label: str):
        super().__init__()
        self.op_names = []
        for name, module in paths:
            self.add_module(name, module)
            self.op_names.append(name)
        assert self.op_names, 'There has to be at least one op to choose from.'
        self._sampled: list[str] | str | None = None  # sampled can be either a list of indices or an index
        self.label = label

    def resample(self, memo):
        """Random choose one path if label is not found in memo."""
        if self.label in memo:
            self._sampled = memo[self.label]
        else:
            self._sampled = random.choice(self.op_names)
        return {self.label: self._sampled}

    def export(self, memo):
        """Random choose one name if label isn't found in memo."""
        if self.label in memo:
            return {}  # nothing new to export
        return {self.label: random.choice(self.op_names)}

    def search_space_spec(self):
        return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
                                          True, size=len(self.op_names))}

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if isinstance(module, LayerChoice):
            return cls(list(module.named_children()), module.label)

    def reduction(self, items: list[Any], sampled: list[Any]):
        """Override this to implement customized reduction."""
        return weighted_sum(items)

    def forward(self, *args, **kwargs):
        if self._sampled is None:
            raise RuntimeError('At least one path needs to be sampled before fprop.')
        sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled

        # str(samp) is needed here because samp can sometimes be integers, but attr are always str
        res = [getattr(self, str(samp))(*args, **kwargs) for samp in sampled]
        return self.reduction(res, sampled)
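# Illustrative sketch (candidate ops are made up): the kind of LayerChoice that
# PathSamplingLayer.mutate() above replaces when building the supernet.
import nni.retiarii.nn.pytorch as retiarii_nn

layer_choice = retiarii_nn.LayerChoice({
    'conv3x3': retiarii_nn.Conv2d(16, 16, 3, padding=1),
    'conv5x5': retiarii_nn.Conv2d(16, 16, 5, padding=2),
}, label='block0')
# After mutation, resample() returns e.g. {'block0': 'conv5x5'} and forward()
# runs only that branch.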
class PathSamplingInput(BaseSuperNetModule):
    """
    Mixed input. Take a list of tensor as input, select some of them and return the sum.

    Attributes
    ----------
    _sampled : int or list of int
        Sampled input indices.
    """

    def __init__(self, n_candidates: int, n_chosen: int, reduction_type: str, label: str):
        super().__init__()
        self.n_candidates = n_candidates
        self.n_chosen = n_chosen
        self.reduction_type = reduction_type
        self._sampled: list[int] | int | None = None
        self.label = label

    def _random_choose_n(self):
        sampling = list(range(self.n_candidates))
        random.shuffle(sampling)
        sampling = sorted(sampling[:self.n_chosen])
        if len(sampling) == 1:
            return sampling[0]
        else:
            return sampling

    def resample(self, memo):
        """Random choose one path / multiple paths if label is not found in memo.
        If one path is selected, only one integer will be in ``self._sampled``.
        If multiple paths are selected, a list will be in ``self._sampled``.
        """
        if self.label in memo:
            self._sampled = memo[self.label]
        else:
            self._sampled = self._random_choose_n()
        return {self.label: self._sampled}

    def export(self, memo):
        """Random choose one name if label isn't found in memo."""
        if self.label in memo:
            return {}  # nothing new to export
        return {self.label: self._random_choose_n()}

    def search_space_spec(self):
        return {
            self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
                                      (self.label, ), True, size=self.n_candidates, chosen_size=self.n_chosen)
        }

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if isinstance(module, InputChoice):
            if module.reduction not in ['sum', 'mean', 'concat']:
                raise ValueError('Only input choice of sum/mean/concat reduction is supported.')
            if module.n_chosen is None:
                raise ValueError('n_chosen is None is not supported yet.')
            return cls(module.n_candidates, module.n_chosen, module.reduction, module.label)

    def reduction(self, items: list[Any], sampled: list[Any]) -> Any:
        """Override this to implement customized reduction."""
        if len(items) == 1:
            return items[0]
        else:
            if self.reduction_type == 'sum':
                return sum(items)
            elif self.reduction_type == 'mean':
                return sum(items) / len(items)
            elif self.reduction_type == 'concat':
                return torch.cat(items, 1)
            raise ValueError(f'Unsupported reduction type: {self.reduction_type}')

    def forward(self, input_tensors):
        if self._sampled is None:
            raise RuntimeError('At least one path needs to be sampled before fprop.')
        if len(input_tensors) != self.n_candidates:
            raise ValueError(f'Expect {self.n_candidates} input tensors, found {len(input_tensors)}.')
        sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
        res = [input_tensors[samp] for samp in sampled]
        return self.reduction(res, sampled)
class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
    """Implements the path sampling in mixed operation.

    One mixed operation can have multiple value choices in its arguments.
    Each value choice can be further decomposed into "leaf value choices".
    We sample the leaf nodes, and compose them into the values on arguments.
    """

    def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
        # Sampling arguments. This should have the same keys with `operation.mutable_arguments`
        self._sampled: dict[str, Any] | None = None

    def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
        """Random sample for each leaf value choice."""
        result = {}
        space_spec = operation.search_space_spec()
        for label in space_spec:
            if label in memo:
                result[label] = memo[label]
            else:
                result[label] = random.choice(space_spec[label].values)

        # compose to kwargs
        # example: result = {"exp_ratio": 3}, self._sampled = {"in_channels": 48, "out_channels": 96}
        self._sampled = {}
        for key, value in operation.mutable_arguments.items():
            self._sampled[key] = evaluate_value_choice_with_dict(value, result)

        return result

    def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
        """Export is also random for each leaf value choice."""
        result = {}
        space_spec = operation.search_space_spec()
        for label in space_spec:
            if label not in memo:
                result[label] = random.choice(space_spec[label].values)
        return result

    def forward_argument(self, operation: MixedOperation, name: str) -> Any:
        # NOTE: we don't support sampling a list here.
        if self._sampled is None:
            raise ValueError('Need to call resample() before running forward')
        if name in operation.mutable_arguments:
            return self._sampled[name]
        return operation.init_arguments[name]
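# Hypothetical sketch of the leaf-choice composition mentioned in the comment above.
# One leaf choice ("exp_ratio") drives two derived arguments of a Conv2d.
import nni.retiarii.nn.pytorch as retiarii_nn

exp_ratio = retiarii_nn.ValueChoice([2, 3, 4], label='exp_ratio')
conv = retiarii_nn.Conv2d(in_channels=exp_ratio * 16, out_channels=exp_ratio * 32, kernel_size=3)
# If resample() draws {'exp_ratio': 3}, the policy evaluates the two expressions and
# feeds {'in_channels': 48, 'out_channels': 96} to forward, matching the example above.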
class PathSamplingRepeat(BaseSuperNetModule):
    """
    Implementation of Repeat in a path-sampling supernet.
    Samples one / some of the prefixes of the repeated blocks.

    Attributes
    ----------
    _sampled : int or list of int
        Sampled depth.
    """

    def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int]):
        super().__init__()
        self.blocks = blocks
        self.depth = depth
        self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
        self._sampled: list[int] | int | None = None

    def resample(self, memo):
        """Since depth is based on ValueChoice, we only need to randomly sample every leaf value choices."""
        result = {}
        for label in self._space_spec:
            if label in memo:
                result[label] = memo[label]
            else:
                result[label] = random.choice(self._space_spec[label].values)

        self._sampled = evaluate_value_choice_with_dict(self.depth, result)

        return result

    def export(self, memo):
        """Random choose one if every choice not in memo."""
        result = {}
        for label in self._space_spec:
            if label not in memo:
                result[label] = random.choice(self._space_spec[label].values)
        return result

    def search_space_spec(self):
        return self._space_spec

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
            # Only interesting when depth is mutable
            return cls(cast(List[nn.Module], module.blocks), module.depth_choice)

    def reduction(self, items: list[Any], sampled: list[Any]):
        """Override this to implement customized reduction."""
        return weighted_sum(items)

    def forward(self, x):
        if self._sampled is None:
            raise RuntimeError('At least one depth needs to be sampled before fprop.')
        sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled

        res = []
        for cur_depth, block in enumerate(self.blocks, start=1):
            x = block(x)
            if cur_depth in sampled:
                res.append(x)
            if not any(d > cur_depth for d in sampled):
                break
        return self.reduction(res, sampled)
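# Illustrative sketch (block contents are made up): a Repeat with a mutable depth.
# Once mutated into PathSamplingRepeat, each resample() draws a depth (say 3) and
# forward() only runs the first three blocks, exactly as in the loop above.
import nni.retiarii.nn.pytorch as retiarii_nn

repeat = retiarii_nn.Repeat([retiarii_nn.Linear(8, 8) for _ in range(4)], depth=(1, 4))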
class PathSamplingCell(BaseSuperNetModule):
    """The implementation of super-net cell follows `DARTS <https://github.com/quark0/darts>`__.

    When ``factory_used`` is true, it reconstructs the cell for every possible combination of operation and input index,
    because for different input index, the cell factory could instantiate different operations (e.g., with different stride).
    On export, we first have the best (operation, input) pairs, then select the best ``num_ops_per_node``.

    ``loose_end`` is not supported yet, because it will cause more problems (e.g., shape mismatch).
    We assume ``loose_end`` to be ``all`` regardless of its configuration.

    A supernet cell can't slim its own weight to fit into a sub network, which is also a known issue.
    """

    def __init__(
        self,
        op_factory: list[CellOpFactory] | dict[str, CellOpFactory],
        num_nodes: int,
        num_ops_per_node: int,
        num_predecessors: int,
        preprocessor: Any,
        postprocessor: Any,
        concat_dim: int,
        memo: dict,            # although not used here, useful in subclass
        mutate_kwargs: dict,   # same as memo
        label: str,
    ):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_ops_per_node = num_ops_per_node
        self.num_predecessors = num_predecessors
        self.preprocessor = preprocessor
        self.ops = nn.ModuleList()
        self.postprocessor = postprocessor
        self.concat_dim = concat_dim
        self.op_names: list[str] = cast(List[str], None)
        self.output_node_indices = list(range(self.num_predecessors, self.num_nodes + self.num_predecessors))

        # Create a fully-connected graph.
        # Each edge is a ModuleDict with op candidates.
        # Can not reuse LayerChoice here, because the spec, resample, export all need to be customized.
        # InputChoice is implicit in this graph.
        for i in self.output_node_indices:
            self.ops.append(nn.ModuleList())
            for k in range(i + self.num_predecessors):
                # Second argument in (i, **0**, k) is always 0.
                # One-shot strategy can't handle the cases where op spec is dependent on `op_index`.
                ops, _ = create_cell_op_candidates(op_factory, i, 0, k)
                self.op_names = list(ops.keys())
                cast(nn.ModuleList, self.ops[-1]).append(nn.ModuleDict(ops))

        self.label = label

        self._sampled: dict[str, str | int] = {}

    def search_space_spec(self) -> dict[str, ParameterSpec]:
        # TODO: Recreating the space here.
        # The spec should be moved to definition of Cell itself.
        space_spec = {}
        for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
            for k in range(self.num_ops_per_node):
                op_label = f'{self.label}/op_{i}_{k}'
                input_label = f'{self.label}/input_{i}_{k}'
                space_spec[op_label] = ParameterSpec(op_label, 'choice', self.op_names, (op_label,), True, size=len(self.op_names))
                space_spec[input_label] = ParameterSpec(input_label, 'choice', list(range(i)), (input_label, ), True, size=i)
        return space_spec

    def resample(self, memo):
        """Random choose one path if label is not found in memo."""
        self._sampled = {}
        new_sampled = {}
        for label, param_spec in self.search_space_spec().items():
            if label in memo:
                assert not isinstance(memo[label], list), 'Multi-path sampling is currently unsupported on cell.'
                self._sampled[label] = memo[label]
            else:
                self._sampled[label] = new_sampled[label] = random.choice(param_spec.values)
        return new_sampled

    def export(self, memo):
        """Randomly choose one to export."""
        return self.resample(memo)

    def forward(self, *inputs: list[torch.Tensor] | torch.Tensor) -> tuple[torch.Tensor, ...] | torch.Tensor:
        processed_inputs: List[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
        states: List[torch.Tensor] = self.preprocessor(processed_inputs)
        for i, ops in enumerate(cast(Sequence[Sequence[Dict[str, nn.Module]]], self.ops), start=self.num_predecessors):
            current_state = []

            for k in range(self.num_ops_per_node):
                # Select op list based on the input chosen
                input_index = self._sampled[f'{self.label}/input_{i}_{k}']
                op_candidates = ops[cast(int, input_index)]
                # Select op from op list based on the op chosen
                op_index = self._sampled[f'{self.label}/op_{i}_{k}']
                op = op_candidates[cast(str, op_index)]
                current_state.append(op(states[cast(int, input_index)]))

            states.append(sum(current_state))  # type: ignore

        # Always merge all
        this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
        return self.postprocessor(this_cell, processed_inputs)

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if isinstance(module, Cell):
            op_factory = None  # not all the cells need to be replaced
            if module.op_candidates_factory is not None:
                op_factory = module.op_candidates_factory
                assert isinstance(op_factory, list) or isinstance(op_factory, dict), \
                    'Only support op_factory of type list or dict.'
            elif module.merge_op == 'loose_end':
                op_candidates_lc = module.ops[-1][-1]  # type: ignore
                assert isinstance(op_candidates_lc, LayerChoice)
                op_factory = {  # create a factory
                    name: lambda _, __, ___: copy.deepcopy(op_candidates_lc[name])
                    for name in op_candidates_lc.names
                }
            if op_factory is not None:
                return cls(
                    op_factory, module.num_nodes, module.num_ops_per_node, module.num_predecessors,
                    module.preprocessor, module.postprocessor, module.concat_dim,
                    memo, mutate_kwargs, module.label
                )
from nni.nas.oneshot.pytorch.supermodule.sampling import *
nni/retiarii/oneshot/pytorch/utils.py View file @ a0fd0036

...
@@ -12,7 +12,6 @@ import torch
 from torch.utils.data import DataLoader, Dataset

 import nni.retiarii.nn.pytorch as nn
-from nni.nas.pytorch.mutables import InputChoice, LayerChoice

 _logger = logging.getLogger(__name__)
...
@@ -163,7 +162,7 @@ def replace_layer_choice(root_module, init_fn, modules=None):
     list[tuple[str, nn.Module]]
         A list from layer choice keys (names) and replaced modules.
     """
-    return _replace_module_with_type(root_module, init_fn, (LayerChoice, nn.LayerChoice), modules)
+    return _replace_module_with_type(root_module, init_fn, nn.LayerChoice, modules)


 def replace_input_choice(root_module, init_fn, modules=None):
...
@@ -184,7 +183,7 @@ def replace_input_choice(root_module, init_fn, modules=None):
     list[tuple[str, nn.Module]]
         A list from layer choice keys (names) and replaced modules.
     """
-    return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules)
+    return _replace_module_with_type(root_module, init_fn, nn.InputChoice, modules)


 class InterleavedTrainValDataLoader(DataLoader):
...
nni/retiarii/operation.py View file @ a0fd0036

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import (Any, Dict, List, Optional, cast)

# pylint: disable=wildcard-import,unused-wildcard-import

from . import debug_configs

__all__ = ['Operation', 'Cell']


def _convert_name(name: str) -> str:
    """
    Convert the names using separator '.' to valid variable name in code
    """
    return name.replace('.', '__')
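# Trivial illustration of the helper above: dotted node names become valid Python
# identifiers in generated code.
assert _convert_name('stem.conv.0') == 'stem__conv__0'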
class Operation:
    """
    Calculation logic of a graph node.

    The constructor is private. Use `Operation.new()` to create operation object.

    `Operation` is a naive record.
    Do not "mutate" its attributes or store information related to a specific node.
    All complex logic should be implemented in `Node` class.

    Attributes
    ----------
    type
        Operation type name (e.g. Conv2D).
        If it starts with underscore, the "operation" is a special one (e.g. subgraph, input/output).
    parameters
        Arbitrary key-value parameters (e.g. kernel_size).
    """

    io_names: List[str] = []

    def __init__(self, type_name: str, parameters: Dict[str, Any] = {}, _internal: bool = False,
                 attributes: Dict[str, Any] = {}):
        assert _internal, '`Operation()` is private, use `Operation.new()` instead'
        self.type: str = type_name
        self.parameters: Dict[str, Any] = parameters
        self.attributes: Dict[str, Any] = attributes

    def to_init_code(self, field: str) -> str:
        raise NotImplementedError()

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        raise NotImplementedError()

    def _to_class_name(self) -> str:
        raise NotImplementedError()

    def __bool__(self) -> bool:
        return True

    @staticmethod
    def new(type_name: str,
            parameters: Dict[str, Any] = cast(Dict[str, Any], None),
            cell_name: str = cast(str, None),
            attributes: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Operation':
        parameters = parameters or {}
        attributes = attributes or {}
        if type_name == '_cell':
            # NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
            return Cell(cell_name, parameters)
        else:
            if debug_configs.framework.lower() in ('torch', 'pytorch'):
                from .operation_def import torch_op_def  # pylint: disable=unused-import
                cls = PyTorchOperation._find_subclass(type_name)
            elif debug_configs.framework.lower() in ('tf', 'tensorflow'):
                from .operation_def import tf_op_def  # pylint: disable=unused-import
                cls = TensorFlowOperation._find_subclass(type_name)
            else:
                raise ValueError(f'Unsupported framework: {debug_configs.framework}')
            return cls(type_name, parameters, _internal=True, attributes=attributes)

    @classmethod
    def _find_subclass(cls, subclass_name):
        for subclass in cls.__subclasses__():
            if subclass.__name__ == subclass_name:
                return subclass
        return cls

    def __repr__(self):
        type_name = type(self).__name__
        args = [f'{key}={repr(value)}' for key, value in self.parameters.items()]
        if type_name != self.type:
            args = [f'type="{self.type}"'] + args
        return f'{type_name}({", ".join(args)})'

    def __eq__(self, other):
        return type(other) is type(self) and other.type == self.type and other.parameters == self.parameters
class PyTorchOperation(Operation):
    @classmethod
    def _find_subclass(cls, subclass_name):
        if cls.to_class_name(subclass_name) is not None:
            subclass_name = 'ModuleOperator'
        if cls.is_functional(subclass_name):
            subclass_name = 'FunctionalOperator'
        for subclass in cls.__subclasses__():
            if hasattr(subclass, '_ori_type_name') and \
                    subclass_name in cast(Any, subclass)._ori_type_name:
                return subclass
        for subclass in cls.__subclasses__():
            if hasattr(subclass, '_artificial_op_name') and \
                    subclass_name in cast(Any, subclass)._artificial_op_name:
                return subclass
        return cls

    @classmethod
    def to_class_name(cls, type_name) -> Optional[str]:
        if type_name.startswith('__torch__.'):
            return type_name[len('__torch__.'):]
        elif type_name.startswith('__mutated__.'):
            return type_name[len('__mutated__.'):]
        else:
            return None

    @classmethod
    def is_functional(cls, type_name) -> bool:
        return type_name.startswith('Function.')

    def _to_class_name(self) -> Optional[str]:
        if self.type.startswith('__torch__.'):
            return self.type[len('__torch__.'):]
        elif self.type.startswith('__mutated__.'):
            return self.type[len('__mutated__.'):]
        else:
            return None

    def get_import_pkg(self) -> Optional[str]:
        if self.type.startswith('__torch__.'):
            return self.type[len('__torch__.'):].split('.')[0]
        elif self.type.startswith('__mutated__.'):
            return self.type[len('__mutated__.'):].split('.')[0]
        else:
            return None

    def to_init_code(self, field: str) -> Optional[str]:
        if self._to_class_name() is not None:
            assert 'positional_args' not in self.parameters
            kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
            return f'self.{field} = {self._to_class_name()}({kw_params})'
        return None

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        """
        Parameters
        ----------
        field : str
            the name of member submodule
        output : str
            the output name (lvalue) of this line of code
        inputs : List[str]
            variables used in this line of code
        inputs_value : List[Any]
            some variables are actually constant, their real values are recorded in ``inputs_value``.
            if not constant, we simply put None at the corresponding index

        Returns
        -------
        str
            generated code line
        """
        if self.type == 'aten::slice':
            raise RuntimeError('not supposed to have aten::slice operation')
        else:
            raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')


class TensorFlowOperation(Operation):
    def _to_class_name(self) -> str:
        return 'K.layers.' + self.type

class Cell(PyTorchOperation):
    """
    TODO: this is pytorch cell

    An operation reference to a subgraph.

    Example code:
    ```
    def __init__(...):
        ...
        self.cell = CustomCell(...)
        self.relu = K.layers.ReLU()
        ...

    def forward(...):
        ...
        x = self.cell(x)
        ...
    ```

    In the above example, node `self.cell`'s operation is `Cell(cell_name='CustomCell')`.
    For comparison, `self.relu`'s operation is `Operation(type='ReLU')`.

    TODO: parameters of subgraph (see `Node` class)

    Attributes
    ----------
    type
        Always "_cell".
    parameters
        A dict with only one item; the key is "cell" and the value is the cell's name.
    framework
        No real usage. Exists for compatibility with the base class.
    """

    def __init__(self, cell_name: str,
                 parameters: Dict[str, Any] = cast(Dict[str, Any], None),
                 attributes: Dict[str, Any] = cast(Dict[str, Any], None)):
        self.type = '_cell'
        self.cell_name = cell_name
        self.parameters = parameters or {}
        self.attributes = attributes or {}

    def _to_class_name(self):
        # TODO: ugly, think about how to refactor this part
        return _convert_name(self.cell_name)

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = self.{field}({", ".join(inputs)})'
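
# --- Editor's note: illustrative sketch, not part of the original commit. ---
# A `Cell` only references a subgraph by name; its generated forward line simply calls the
# corresponding submodule. The cell name and variable names below are made up for illustration.
demo_cell = Cell('CustomCell')
assert demo_cell.to_forward_code('cell', 'x1', ['x0'], [None]) == 'x1 = self.cell(x0)'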

class _IOPseudoOperation(Operation):
    """
    This is the pseudo operation used by I/O nodes.
    The benefit is that users no longer need to verify `Node.operation is not None`,
    especially in static type checking.
    """

    def __init__(self, type_name: str, io_names: List[str] = cast(List[str], None)):
        assert type_name.startswith('_')
        super(_IOPseudoOperation, self).__init__(type_name, {}, True)
        self.io_names = io_names

    def to_init_code(self, field: str) -> str:
        raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')

    def __bool__(self) -> bool:
        return False
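
# --- Editor's note: illustrative sketch, not part of the original commit. ---
# Because `__bool__` returns False, callers can use a plain truthiness test instead of an
# explicit `is not None` comparison. The helper and the `node` argument below are hypothetical.
def _emit_forward_line(node, output, inputs):
    if node.operation:  # False for the I/O pseudo operation, so I/O nodes are skipped
        return node.operation.to_forward_code(node.name, output, inputs, [None] * len(inputs))
    return None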

from nni.nas.execution.common.graph_op import *
nni/retiarii/operation_def/tf_op_def.py View file @ a0fd0036
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from ..operation import TensorFlowOperation

# pylint: disable=wildcard-import,unused-wildcard-import


class Conv2D(TensorFlowOperation):
    def __init__(self, type_name, parameters, _internal, attributes=None):
        if 'padding' not in parameters:
            parameters['padding'] = 'same'
        super().__init__(type_name, parameters, _internal)

from nni.nas.execution.tensorflow.op_def import *
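
# --- Editor's note: illustrative sketch for the Conv2D wrapper above, not part of the original commit. ---
# If 'padding' is absent, the constructor injects the Keras-style default 'same' before delegating
# to the base Operation constructor. The parameter values below are made up, and the call assumes
# the base constructor accepts (type_name, parameters, _internal) as the subclasses here suggest.
params = {'filters': 32, 'kernel_size': 3}
conv_op = Conv2D('Conv2D', params, True)
assert params['padding'] == 'same'  # default injected because it was not supplied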
nni/retiarii/operation_def/torch_op_def.py View file @ a0fd0036
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

# pylint: disable=wildcard-import,unused-wildcard-import

from typing import (Any, Dict, List)

import torch
import torch.nn.functional as nn_functional

from ..operation import PyTorchOperation

mem_format = [
    'torch.contiguous_format',  # 0
    'torch.preserve_format',    # 1
    'torch.channels_last',      # 2
]

# this snippet is copied from torch/onnx/symbolic_helper.py,
# the original definition is in c10/core/ScalarType.h
# This indicates each scalar type's corresponding PyTorch type string.
scalar_type_to_pytorch_type = [
    'torch.uint8',       # 0
    'torch.int8',        # 1
    'torch.short',       # 2
    'torch.int',         # 3
    'torch.int64',       # 4
    'torch.half',        # 5
    'torch.float',       # 6
    'torch.double',      # 7
    'torch.complex32',   # 8
    'torch.complex64',   # 9
    'torch.complex128',  # 10
    'torch.bool',        # 11
]

class NoOpIdentity(PyTorchOperation):
    """
    this operator type is added by us
    """
    _ori_type_name = ['noop_identity']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = {", ".join(inputs)}'


class ModuleOperator(PyTorchOperation):
    _ori_type_name = ['ModuleOperator', 'shared']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = self.{field}({", ".join(inputs)})'


class FunctionalOperator(PyTorchOperation):
    _ori_type_name = ['FunctionalOperator']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        func_name = self.type[len('Function.'):]
        if not hasattr(nn_functional, func_name):
            raise RuntimeError('For now, we only support calling independent functions from `torch.nn.functional`, '
                               f'{func_name} is not in it.')
        return f'{output} = F.{func_name}({", ".join(inputs)})'

class PrimConstant(PyTorchOperation):
    _ori_type_name = ['prim::Constant']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        # TODO: refactor this part, maybe we can remove the code gen of prim::Constant
        # TODO: deal with all the types
        if self.parameters['type'] in ['None', 'NoneType']:
            return f'{output} = None'
        elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'):  # 'Long()' ???
            return f'{output} = {self.parameters["value"]}'
        elif self.parameters['type'] == 'str':
            str_val = self.parameters["value"]
            return f'{output} = "{str_val}"'
        elif self.parameters['type'] == 'Device':
            value = self.parameters['value']
            return f'{output} = torch.device("{value}")'
        elif self.parameters['type'] in ('dict', 'list', 'tuple'):
            # TODO: prim::TupleIndex is not supported yet
            return f'{output} = {repr(self.parameters["value"])}'
        else:
            raise RuntimeError(f'unsupported type of prim::Constant: {self.parameters["type"]}')
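
# --- Editor's note: illustrative sketch, not part of the original commit. ---
# PrimConstant turns a recorded constant into a literal in the generated forward code. The
# parameter dicts below are hypothetical, and constructing the operation directly assumes the
# base Operation constructor accepts (type_name, parameters, _internal=True).
str_const = PrimConstant('prim::Constant', {'type': 'str', 'value': 'relu'}, True)
assert str_const.to_forward_code('', 'c1', [], []) == 'c1 = "relu"'
dev_const = PrimConstant('prim::Constant', {'type': 'Device', 'value': 'cpu'}, True)
assert dev_const.to_forward_code('', 'c2', [], []) == 'c2 = torch.device("cpu")'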

class PrimListConstruct(PyTorchOperation):
    _ori_type_name = ['prim::ListConstruct']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = [{", ".join(inputs)}]'


class PrimListUnpack(PyTorchOperation):
    _ori_type_name = ['prim::ListUnpack']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = {inputs[0]}'


class PrimTupleConstruct(PyTorchOperation):
    _ori_type_name = ['prim::TupleConstruct']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = ({", ".join(inputs)})'


class PrimTupleUnpack(PyTorchOperation):
    _ori_type_name = ['prim::TupleUnpack']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        # have single output here, because the following code uses index to access the unpacked values
        assert len(inputs) == 1
        return f'{output} = {inputs[0]}'


class PrimGetAttr(PyTorchOperation):
    _ori_type_name = ['prim::GetAttr']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        if self.parameters['value'] is not None:
            return f"{output} = {self.parameters['value']}"
        else:
            return f"{output} = {self.parameters['input']}.{self.parameters['name']}"


class PrimUncheckedCast(PyTorchOperation):
    _ori_type_name = ['prim::unchecked_cast']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = {inputs[0]}'


class SimpleMember(PyTorchOperation):
    _ori_type_name = ['prim::is_cuda', 'prim::data']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        member_name = self.type.split('::')[-1]
        return f'{output} = {inputs[0]}.{member_name}'


class AtenContiguous(PyTorchOperation):
    _ori_type_name = ['aten::contiguous']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        # defined in pytorch/c10/core/MemoryFormat.h
        assert inputs_value is not None and inputs_value[1] in [0, 1, 2]
        return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'


class AtenGetitem(PyTorchOperation):
    _ori_type_name = ['aten::__getitem__']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        assert len(inputs) == 2
        return f'{output} = {inputs[0]}[{inputs[1]}]'

class AtenAppend(PyTorchOperation):
    _ori_type_name = ['aten::append']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        assert len(inputs) == 2
        return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
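
# --- Editor's note: illustrative sketch, not part of the original commit. ---
# `list.append` returns None, so the generated line evaluates the append for its side effect and
# rebinds the output name to the (mutated) list itself. The variable names are made up, and the
# constructor call assumes the base Operation signature (type_name, parameters, _internal=True).
append_op = AtenAppend('aten::append', {}, True)
assert append_op.to_forward_code('', 'out', ['xs', 'item'], [None, None]) == '_, out = xs.append(item), xs'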

class MergedSlice(PyTorchOperation):
    _ori_type_name = ['MergedSlice']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        if (len(inputs) - 1) % 4 == 0:
            slices = []
            dim = int((len(inputs) - 1) / 4)
            for i in range(dim):
                slices.append(f'{inputs[i*4+2]}:{inputs[i*4+3]}:{inputs[i*4+4]}')
            slice_str = ','.join(slices)
            return f'{output} = {inputs[0]}[{slice_str}]'
        elif len(inputs) == 4:
            # this case is for simple list
            return f'{output} = {inputs[0]}[{inputs[1]}:{inputs[2]}:{inputs[3]}]'
        else:
            raise RuntimeError('Unsupported slice pattern')
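
# --- Editor's note: illustrative sketch, not part of the original commit. ---
# In the multi-dimensional branch, inputs are grouped in fours after the sliced tensor; the
# entries at indices 1, 5, ... (presumably the dimension indices) are not used, and each
# remaining (start, end, step) triple becomes one slice expression. The names are made up, and
# direct construction assumes the base Operation signature (type_name, parameters, _internal=True).
slice_op = MergedSlice('MergedSlice', {}, True)
generated = slice_op.to_forward_code('', 'y', ['x', '0', '0', '10', '1', '1', '2', '8', '2'], [None] * 9)
assert generated == 'y = x[0:10:1,2:8:2]'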

# the following Aten classes mean that these aten ops are not in torch.Tensor


class AtenBool(PyTorchOperation):
    _ori_type_name = ['aten::Bool']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = bool({inputs[0]})'


class AtenNot(PyTorchOperation):
    _ori_type_name = ['aten::__not__']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = not {inputs[0]}'


class AtenCat(PyTorchOperation):
    _ori_type_name = ['aten::cat']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        assert len(inputs) == 2
        return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'

# ====================================


class AtenTensors(PyTorchOperation):
    _ori_type_name = ['aten::full', 'aten::full_like', 'aten::empty_like',
                      'aten::ones_like', 'aten::zeros_like', 'aten::rand',
                      'aten::randn', 'aten::scalar_tensor', 'aten::new_full',
                      'aten::new_empty', 'aten::new_zeros', 'aten::arange',
                      'aten::tensor', 'aten::ones', 'aten::zeros', 'aten::as_tensor']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        schemas = torch._C._jit_get_schemas_for_operator(self.type)
        # match number of inputs
        overloaded_defs = [len(s.arguments) for s in schemas]
        matched = overloaded_defs.index(len(inputs))
        args_list = []
        for idx, arg in enumerate(schemas[matched].arguments):
            if arg.name == 'dtype':
                arg_str = f'dtype={scalar_type_to_pytorch_type[inputs_value[idx]]}' if inputs_value[idx] is not None else ''
            elif arg.name == 'layout':
                if inputs_value[idx] is not None:
                    arg_str = f'layout=torch.strided'
                    print('Warning: only support `torch.strided` for now!!!')
                else:
                    arg_str = ''
            elif arg.name == 'device':
                arg_str = f'device=torch.device({inputs[idx]})' if inputs_value[idx] is not None else ''
            elif arg.name == 'memory_format':
                arg_str = f'memory_format={mem_format[inputs_value[idx]]}' if inputs_value[idx] is not None else ''
            elif arg.name == 'pin_memory':
                # TODO: deal with this argument
                continue
            elif arg.name == 'requires_grad':
                arg_str = f'requires_grad={inputs[idx]}' if inputs_value[idx] else ''
            elif str(arg.type).startswith('Optional['):
                arg_str = f'{arg.name}={inputs[idx]}'
            else:
                arg_str = f'{inputs[idx]}'
            if arg_str != '':
                args_list.append(arg_str)
        op_name = self.type.split('::')[-1]
        if hasattr(torch, op_name):
            return f'{output} = torch.{op_name}({", ".join(args_list)})'
        else:
            return f'{output} = {inputs[0]}.{op_name}({", ".join(args_list[1:])})'


# ====================================


class AtenFloordiv(PyTorchOperation):
    _ori_type_name = ['aten::floordiv']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = {inputs[0]} // {inputs[1]}'


class AtenMul(PyTorchOperation):
    _ori_type_name = ['aten::mul']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = {inputs[0]} * {inputs[1]}'


class AtenLen(PyTorchOperation):
    _ori_type_name = ['aten::len']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = len({inputs[0]})'


class AtenIntImplicit(PyTorchOperation):
    _ori_type_name = ['aten::IntImplicit', 'aten::Float', 'aten::Int', 'aten::ScalarImplicit']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        if self.type.endswith('Implicit'):
            return f'{output} = {inputs[0]}'
        elif self.type == 'aten::Int':
            return f'{output} = int({inputs[0]})'
        elif self.type == 'aten::Float':
            return f'{output} = float({inputs[0]})'
        raise TypeError(f'Unexpected type: {self.type}')


class AtenIndex(PyTorchOperation):
    _ori_type_name = ['aten::index']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = {inputs[0]}[{inputs[1]}]'

ManuallyChooseDef = {
    'aten::flatten': [('start_dim', 'int', '0'), ('end_dim', 'int', '-1')],
    'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')],
    # in v1.9 dtype is supported as input argument for view, but torch script does not support it
    'aten::view': [('size', 'List[int]', 'None')],
    # NOTE: dim supports different types: List[int], List[str], Optional[List[int]], now we only support the first two, refactor needed
    # torch.std(input, dim, unbiased, keepdim=False, *, out=None) -> Tensor
    # torch.std(input, unbiased) -> Tensor
    'aten::std': [('dim', 'List[int]', 'None'), ('unbiased', 'bool', 'True'), ('keepdim', 'bool', 'False')]
}

TensorOpExceptions = {
    'aten::sub': lambda output, inputs: f'{output} = {inputs[0]} - {inputs[1]}',  # example: x.size(1) - 3
    'aten::add': lambda output, inputs: f'{output} = {inputs[0]} + {inputs[1]}'   # example: input.shape[0] + 5
}
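
# --- Editor's note: illustrative sketch, not part of the original commit. ---
# These exception entries are used by TensorOps below when no overload matches; they emit a
# plain infix expression instead. The argument strings here are made up for illustration.
assert TensorOpExceptions['aten::add']('out', ['x.shape[0]', '5']) == 'out = x.shape[0] + 5'
assert TensorOpExceptions['aten::sub']('out', ['x.size(1)', '3']) == 'out = x.size(1) - 3'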

TorchOpExclude = ['aten::Size', 'aten::as_tensor', 'aten::device',
                  'aten::manual_seed', 'aten::quantized_gru', 'aten::quantized_lstm',
                  'aten::save', 'aten::tensor', 'aten::wait']


def _hidden(name):
    return name.startswith('_') and not name.startswith('__')


def _emit_args(args):
    # filter out the `out` argument here
    return [(arg.name, str(arg.type), str(arg.default_value)) for arg in args]  # if arg.name != 'out'


def _get_tensor_ops():
    def is_tensor_method(schema):
        if len(schema.arguments) == 0:
            return False
        self = schema.arguments[0]
        if self.name != 'self':
            return False
        if not self.type.isSubtypeOf(torch._C.TensorType.get()):
            return False
        return True

    op_args = {}
    # discover methods
    for elem in dir(torch.Tensor):
        if not _hidden(elem):
            schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem)
            for schema in schemas:
                if is_tensor_method(schema):
                    op_name = 'aten::' + elem
                    args = _emit_args(schema.arguments[1:])
                    if op_name in op_args:
                        op_args[op_name].append(args)
                    else:
                        op_args[op_name] = [args]
    return op_args.keys(), op_args


def _get_torch_ops():
    torch_op_args = {}
    for mod in torch.jit._builtins._modules_containing_builtins:  # type: ignore
        name = mod.__name__
        if name == 'torch._C._nn':
            continue
        # only process 'torch.XXX'
        for elem in dir(mod):
            builtin = torch.jit._builtins._find_builtin(getattr(mod, elem))  # type: ignore
            if builtin is not None:
                schemas = torch._C._jit_get_schemas_for_operator(builtin)
                for schema in schemas:
                    # remove _tan but not __and__
                    if not _hidden(elem):
                        op_name = 'aten::' + elem
                        if len(schema.arguments) > 0 and schema.arguments[0].name == 'self':
                            continue
                        args = _emit_args(schema.arguments)
                        if op_name in torch_op_args:
                            torch_op_args[op_name].append(args)
                        else:
                            torch_op_args[op_name] = [args]
    return torch_op_args.keys(), torch_op_args


def _get_torch_ops_exclude_tensor_ops():
    tensor_op_names, _ = _get_tensor_ops()
    torch_op_names, torch_ops = _get_torch_ops()

    torch_exclude_ops = {}
    for name in torch_op_names:
        if name not in tensor_op_names:
            if name not in TorchOpExclude:
                # exclude the ops that are not in
                # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
                torch_exclude_ops[name] = torch_ops[name]

    return torch_exclude_ops.keys(), torch_exclude_ops

class TensorOps(PyTorchOperation):
    """
    corresponding to _get_tensor_ops in torch.jit.supported_ops
    """
    _ori_type_name, _op_args = _get_tensor_ops()

    comparison_ops = {'aten::eq': '==', 'aten::ne': '!=', 'aten::le': '<=',
                      'aten::ge': '>=', 'aten::lt': '<', 'aten::gt': '>'}

    @staticmethod
    def _get_matched_args(_type, inputs):
        def has_same_arg_name(matched):
            concated_names = []
            for i, each in enumerate(matched):
                name = ','.join([arg[0] for arg in each])
                concated_names.append(name)
            for i in range(len(concated_names) - 1):
                if concated_names[i] != concated_names[i + 1]:
                    return False
            return True

        overloaded_defs = TensorOps._op_args[_type]
        matched = []
        for each in overloaded_defs:
            # plus 1 because we skip the first argument when generating tensor op def
            if len(each) + 1 == len(inputs):
                matched.append(each)
        if len(matched) == 1:
            return matched[0]
        elif len(matched) > 1:
            # TODO: match with arg's type. manually choose for now
            if has_same_arg_name(matched):
                # return any one is okay
                return matched[0]
            elif _type in ManuallyChooseDef:
                return ManuallyChooseDef[_type]
            else:
                raise RuntimeError(f'tensor op type {_type} has more than one matched: {matched}')
        else:
            if _type in TensorOpExceptions:
                return None
            raise RuntimeError(f'tensor op type {_type} has no matched')

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        # TODO: deal with conditional ops
        if self.type in TensorOps.comparison_ops:
            return f'{output} = ({inputs[0]} {TensorOps.comparison_ops[self.type]} {inputs[1]})'
        matched_args = TensorOps._get_matched_args(self.type, inputs)
        if matched_args is None:
            return TensorOpExceptions[self.type](output, inputs)
        op_name = self.type.split('::')[-1]
        args_str = ', '.join([f'{name}={inputs[i + 1]}' for i, (name, t, default) in enumerate(matched_args)])
        return f'{output} = {inputs[0]}.{op_name}({args_str})'
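
# --- Editor's note: illustrative sketch, not part of the original commit. ---
# Comparison ops short-circuit the schema matching above and are emitted as parenthesized infix
# expressions. The variable names are made up, and direct construction assumes the base
# Operation constructor accepts (type_name, parameters, _internal=True).
eq_op = TensorOps('aten::eq', {}, True)
assert eq_op.to_forward_code('', 'flag', ['x', 'y'], [None, None]) == 'flag = (x == y)'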

class TorchOps(PyTorchOperation):
    """
    corresponding to _get_nn_functional_ops in torch.jit.supported_ops
    """
    _ori_type_name, _op_args = _get_torch_ops_exclude_tensor_ops()
    # add 'aten::pixel_shuffle'
    _op_args['aten::pixel_shuffle'] = [[('input', 'Tensor', 'None'), ('upscale_factor', 'Optional[int]', 'None')]]
    _ori_type_name = _op_args.keys()

    @staticmethod
    def _get_matched_args(_type, inputs):
        def has_same_arg_name(matched):
            concated_names = []
            for i, each in enumerate(matched):
                name = ','.join([arg[0] for arg in each])
                concated_names.append(name)
            for i in range(len(concated_names) - 1):
                if concated_names[i] != concated_names[i + 1]:
                    return False
            return True

        overloaded_defs = TorchOps._op_args[_type]
        matched = []
        for each in overloaded_defs:
            if len(each) == len(inputs):
                matched.append(each)
        if len(matched) == 1:
            return matched[0]
        elif len(matched) > 1:
            # TODO: match with arg's type. manually choose for now
            if has_same_arg_name(matched):
                # return any one is okay
                return matched[0]
            else:
                raise RuntimeError(f'torch op type {_type} has more than one matched: {matched}')
        else:
            raise RuntimeError(f'torch op type {_type} has no matched')

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        matched_args = TorchOps._get_matched_args(self.type, inputs)
        op_name = self.type.split('::')[-1]
        args_str = ', '.join([f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
                              for i, (name, t, default) in enumerate(matched_args)])
        return f'{output} = torch.{op_name}({args_str})'

class AtenAvgpool2d(PyTorchOperation):
    # NOTE: it is not included in the above aten ops for an unknown reason
    _ori_type_name = ['aten::avg_pool2d']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = F.avg_pool2d({", ".join(inputs)})'

class ToDevice(PyTorchOperation):
    _artificial_op_name = "ToDevice"

    def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False,
                 attributes: Dict[str, Any] = {}):
        self.type = "ToDevice"
        self.device = parameters['device']
        self.overridden_device_repr = None
        self.src = parameters['src']
        self.dst = parameters['dst']

    def override_device_repr(self, device_repr):
        # CUDA GPUDevice may remap GPU physical ID to CUDA ID. The device repr is different from GPUDevice.device_repr()
        # override_device_repr will be called in pytorch.graph_to_pytorch_model to replace device_repr with the correct
        # CUDA ID, e.g., when a job uses Physical GPU-1,2, its CUDA ID should be "cuda:0" and "cuda:1".
        # self.device.device_repr() would return "cuda:1" and "cuda:2", but override_device_repr should be "cuda:0" and
        # "cuda:1"
        self.overridden_device_repr = device_repr

    def __repr__(self):
        if self.overridden_device_repr is None:
            return f'to("{self.device.device_repr()}")'
        else:
            return f'to("{self.overridden_device_repr}")'

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        if self.overridden_device_repr is None:
            forward_code = f'{output} = {inputs[0]}.to("{self.device.device_repr()}")'
        else:
            forward_code = f'{output} = {inputs[0]}.to("{self.overridden_device_repr}")'
        return forward_code
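
# --- Editor's note: illustrative sketch, not part of the original commit. ---
# Once override_device_repr() is called, the generated `.to(...)` string uses the remapped CUDA ID
# instead of the device's own repr. `_FakeDevice` is a hypothetical stand-in for the real
# GPUDevice, providing only the device_repr() method that ToDevice relies on.
class _FakeDevice:
    def device_repr(self):
        return 'cuda:1'

to_op = ToDevice('ToDevice', {'device': _FakeDevice(), 'src': None, 'dst': None})
assert to_op.to_forward_code('', 'y', ['x'], [None]) == 'y = x.to("cuda:1")'
to_op.override_device_repr('cuda:0')  # e.g., physical GPU-1 is visible to this job as cuda:0
assert to_op.to_forward_code('', 'y', ['x'], [None]) == 'y = x.to("cuda:0")'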

class AtenDet(PyTorchOperation):
    # for torch 1.9
    # NOTE: it is not included in the above aten ops, maybe because torch.det is an alias for torch.linalg.det
    _ori_type_name = ['aten::linalg_det']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = torch.det({inputs[0]})'


from nni.nas.execution.pytorch.op_def import *