Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
8c8db374
Unverified
Commit
8c8db374
authored
Feb 03, 2021
by
Yuge Zhang
Committed by
GitHub
Feb 03, 2021
Browse files
[Retiarii] Improve high-level API interface and add implementation of ValueChoice (#3349)
parent
be3a6966
Changes
18
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
641 additions
and
178 deletions
+641
-178
docs/en_US/NAS/retiarii/ApiReference.rst
docs/en_US/NAS/retiarii/ApiReference.rst
+8
-2
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+23
-13
nni/retiarii/execution/logical_optimizer/logical_plan.py
nni/retiarii/execution/logical_optimizer/logical_plan.py
+1
-1
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
+1
-1
nni/retiarii/experiment.py
nni/retiarii/experiment.py
+6
-26
nni/retiarii/mutator.py
nni/retiarii/mutator.py
+0
-33
nni/retiarii/nn/pytorch/__init__.py
nni/retiarii/nn/pytorch/__init__.py
+1
-0
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+284
-0
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+97
-0
nni/retiarii/nn/pytorch/nn.py
nni/retiarii/nn/pytorch/nn.py
+1
-95
nni/retiarii/trainer/__init__.py
nni/retiarii/trainer/__init__.py
+0
-1
test/retiarii_test/darts/test.py
test/retiarii_test/darts/test.py
+1
-1
test/retiarii_test/mnasnet/test.py
test/retiarii_test/mnasnet/test.py
+1
-1
test/retiarii_test/mnist/test.py
test/retiarii_test/mnist/test.py
+1
-1
test/ut/retiarii/test_cgo_engine.py
test/ut/retiarii/test_cgo_engine.py
+1
-1
test/ut/retiarii/test_dedup_input.py
test/ut/retiarii/test_dedup_input.py
+1
-1
test/ut/retiarii/test_engine.py
test/ut/retiarii/test_engine.py
+1
-1
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+213
-0
No files found.
docs/en_US/NAS/retiarii/ApiReference.rst
View file @
8c8db374
...
...
@@ -12,6 +12,12 @@ Inline Mutation APIs
.. autoclass:: nni.retiarii.nn.pytorch.InputChoice
:members:
.. autoclass:: nni.retiarii.nn.pytorch.ValueChoice
:members:
.. autoclass:: nni.retiarii.nn.pytorch.ChosenInputs
:members:
Graph Mutation APIs
-------------------
...
...
@@ -36,10 +42,10 @@ Graph Mutation APIs
Trainers
--------
.. autoclass:: nni.retiarii.trainer.PyTorchImageClassificationTrainer
.. autoclass:: nni.retiarii.trainer.
pytorch.
PyTorchImageClassificationTrainer
:members:
.. autoclass:: nni.retiarii.trainer.PyTorchMultiModelTrainer
.. autoclass:: nni.retiarii.trainer.
pytorch.
PyTorchMultiModelTrainer
:members:
Oneshot Trainers
...
...
nni/retiarii/converter/graph_gen.py
View file @
8c8db374
...
...
@@ -408,25 +408,33 @@ class GraphConverter:
self
.
merge_aten_slices
(
ir_graph
)
def
_handle_layerchoice
(
self
,
module
):
m_attrs
=
{}
candidates
=
module
.
op_candidates
choices
=
[]
for
cand
in
candidates
:
assert
id
(
cand
)
in
self
.
modules_arg
,
'id not exist: {}'
.
format
(
id
(
cand
))
for
cand
in
list
(
module
):
assert
id
(
cand
)
in
self
.
modules_arg
,
\
f
'Module not recorded:
{
id
(
cand
)
}
. '
\
'Try to import from `retiarii.nn` if you are using torch.nn module or '
\
'annotate your customized module with @blackbox_module.'
assert
isinstance
(
self
.
modules_arg
[
id
(
cand
)],
dict
)
cand_type
=
'__torch__.'
+
cand
.
__class__
.
__module__
+
'.'
+
cand
.
__class__
.
__name__
choices
.
append
({
'type'
:
cand_type
,
'parameters'
:
self
.
modules_arg
[
id
(
cand
)]})
m_attrs
[
f
'choices'
]
=
choices
m_attrs
[
'label'
]
=
module
.
label
return
m_attrs
return
{
'candidates'
:
choices
,
'label'
:
module
.
label
}
def
_handle_inputchoice
(
self
,
module
):
m_attrs
=
{}
m_attrs
[
'n_candidates'
]
=
module
.
n_candidates
m_attrs
[
'n_chosen'
]
=
module
.
n_chosen
m_attrs
[
'reduction'
]
=
module
.
reduction
m_attrs
[
'label'
]
=
module
.
label
return
m_attrs
return
{
'n_candidates'
:
module
.
n_candidates
,
'n_chosen'
:
module
.
n_chosen
,
'reduction'
:
module
.
reduction
,
'label'
:
module
.
label
}
def
_handle_valuechoice
(
self
,
module
):
return
{
'candidates'
:
module
.
candidates
,
'label'
:
module
.
label
}
def
convert_module
(
self
,
script_module
,
module
,
module_name
,
ir_model
):
"""
...
...
@@ -461,6 +469,8 @@ class GraphConverter:
m_attrs
=
self
.
_handle_layerchoice
(
module
)
elif
original_type_name
==
OpTypeName
.
InputChoice
:
m_attrs
=
self
.
_handle_inputchoice
(
module
)
elif
original_type_name
==
OpTypeName
.
ValueChoice
:
m_attrs
=
self
.
_handle_valuechoice
(
module
)
elif
original_type_name
==
OpTypeName
.
Placeholder
:
m_attrs
=
self
.
modules_arg
[
id
(
module
)]
elif
original_type_name
in
torch
.
nn
.
__dict__
:
...
...
nni/retiarii/execution/logical_optimizer/logical_plan.py
View file @
8c8db374
...
...
@@ -149,7 +149,7 @@ class LogicalPlan:
phy_model
.
training_config
.
kwargs
[
'model_cls'
]
=
phy_graph
.
name
phy_model
.
training_config
.
kwargs
[
'model_kwargs'
]
=
[]
# FIXME: allow user to specify
phy_model
.
training_config
.
module
=
'nni.retiarii.trainer.PyTorchMultiModelTrainer'
phy_model
.
training_config
.
module
=
'nni.retiarii.trainer.
pytorch.
PyTorchMultiModelTrainer'
# merge sub-graphs
for
model
in
multi_model_placement
:
...
...
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
View file @
8c8db374
...
...
@@ -6,7 +6,7 @@ from .interface import AbstractOptimizer
from
.logical_plan
import
(
AbstractLogicalNode
,
LogicalGraph
,
LogicalPlan
,
OriginNode
,
PhysicalDevice
)
_supported_training_modules
=
[
'nni.retiarii.trainer.PyTorchImageClassificationTrainer'
]
_supported_training_modules
=
[
'nni.retiarii.trainer.
pytorch.
PyTorchImageClassificationTrainer'
]
class
DedupInputNode
(
AbstractLogicalNode
):
...
...
nni/retiarii/experiment.py
View file @
8c8db374
...
...
@@ -15,15 +15,13 @@ from .graph import Model
from
.utils
import
get_records
from
.integration
import
RetiariiAdvisor
from
.converter
import
convert_to_graph
from
.mutator
import
Mutator
,
LayerChoiceMutator
,
InputChoiceMutator
from
.mutator
import
Mutator
from
.trainer.interface
import
BaseTrainer
,
BaseOneShotTrainer
from
.strategies.strategy
import
BaseStrategy
from
.trainer
.pytorch
import
DartsTrainer
,
EnasTrainer
,
ProxylessTrainer
,
RandomTrainer
,
SinglePath
Trainer
from
.trainer
import
BaseOneShot
Trainer
_logger
=
logging
.
getLogger
(
__name__
)
OneShotTrainers
=
(
DartsTrainer
,
EnasTrainer
,
ProxylessTrainer
,
RandomTrainer
,
SinglePathTrainer
)
@
dataclass
(
init
=
False
)
class
RetiariiExeConfig
(
ConfigBase
):
...
...
@@ -94,28 +92,10 @@ class RetiariiExperiment(Experiment):
self
.
_proc
:
Optional
[
Popen
]
=
None
self
.
_pipe
:
Optional
[
Pipe
]
=
None
def
_process_inline_mutation
(
self
,
base_model
):
"""
the mutators are order independent
"""
lc_nodes
=
base_model
.
get_nodes_by_type
(
'__torch__.nni.retiarii.nn.pytorch.nn.LayerChoice'
)
ic_nodes
=
base_model
.
get_nodes_by_type
(
'__torch__.nni.retiarii.nn.pytorch.nn.InputChoice'
)
if
not
lc_nodes
and
not
ic_nodes
:
return
None
applied_mutators
=
[]
for
node
in
lc_nodes
:
mutator
=
LayerChoiceMutator
(
node
.
name
,
node
.
operation
.
parameters
[
'choices'
])
applied_mutators
.
append
(
mutator
)
for
node
in
ic_nodes
:
mutator
=
InputChoiceMutator
(
node
.
name
,
node
.
operation
.
parameters
[
'n_candidates'
],
node
.
operation
.
parameters
[
'n_chosen'
],
node
.
operation
.
parameters
[
'reduction'
])
applied_mutators
.
append
(
mutator
)
return
applied_mutators
def
_start_strategy
(
self
):
import
torch
from
.nn.pytorch.mutator
import
process_inline_mutation
try
:
script_module
=
torch
.
jit
.
script
(
self
.
base_model
)
except
Exception
as
e
:
...
...
@@ -131,7 +111,7 @@ class RetiariiExperiment(Experiment):
base_model_ir
.
apply_trainer
(
trainer_config
[
'modulename'
],
trainer_config
[
'args'
])
# handle inline mutations
mutators
=
self
.
_
process_inline_mutation
(
base_model_ir
)
mutators
=
process_inline_mutation
(
base_model_ir
)
if
mutators
is
not
None
and
self
.
applied_mutators
:
raise
RuntimeError
(
'Have not supported mixed usage of LayerChoice/InputChoice and mutators,
\
do not use mutators when you use LayerChoice/InputChoice'
)
...
...
@@ -165,7 +145,7 @@ class RetiariiExperiment(Experiment):
Run the experiment.
This function will block until experiment finish or error.
"""
if
isinstance
(
self
.
trainer
,
OneShotTrainer
s
):
if
isinstance
(
self
.
trainer
,
Base
OneShotTrainer
):
self
.
trainer
.
fit
()
else
:
assert
config
is
not
None
,
'You are using classic search mode, config cannot be None!'
...
...
nni/retiarii/mutator.py
View file @
8c8db374
...
...
@@ -105,36 +105,3 @@ class _RecorderSampler(Sampler):
def
choice
(
self
,
candidates
:
List
[
Choice
],
*
args
)
->
Choice
:
self
.
recorded_candidates
.
append
(
candidates
)
return
candidates
[
0
]
# the following is for inline mutation
class
LayerChoiceMutator
(
Mutator
):
def
__init__
(
self
,
node_name
:
str
,
candidates
:
List
):
super
().
__init__
()
self
.
node_name
=
node_name
self
.
candidates
=
candidates
def
mutate
(
self
,
model
):
target
=
model
.
get_node_by_name
(
self
.
node_name
)
indexes
=
[
i
for
i
in
range
(
len
(
self
.
candidates
))]
chosen_index
=
self
.
choice
(
indexes
)
chosen_cand
=
self
.
candidates
[
chosen_index
]
target
.
update_operation
(
chosen_cand
[
'type'
],
chosen_cand
[
'parameters'
])
class
InputChoiceMutator
(
Mutator
):
def
__init__
(
self
,
node_name
:
str
,
n_candidates
:
int
,
n_chosen
:
int
,
reduction
:
str
):
super
().
__init__
()
self
.
node_name
=
node_name
self
.
n_candidates
=
n_candidates
self
.
n_chosen
=
n_chosen
self
.
reduction
=
reduction
def
mutate
(
self
,
model
):
target
=
model
.
get_node_by_name
(
self
.
node_name
)
candidates
=
[
i
for
i
in
range
(
self
.
n_candidates
)]
chosen
=
[
self
.
choice
(
candidates
)
for
_
in
range
(
self
.
n_chosen
)]
target
.
update_operation
(
'__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs'
,
{
'chosen'
:
chosen
,
'reduction'
:
self
.
reduction
})
nni/retiarii/nn/pytorch/__init__.py
View file @
8c8db374
from
.api
import
*
from
.nn
import
*
nni/retiarii/nn/pytorch/api.py
0 → 100644
View file @
8c8db374
from
collections
import
OrderedDict
from
typing
import
Any
,
List
,
Union
,
Dict
import
warnings
import
torch
import
torch.nn
as
nn
from
...utils
import
uid
,
add_record
,
del_record
__all__
=
[
'LayerChoice'
,
'InputChoice'
,
'ValueChoice'
,
'Placeholder'
,
'ChosenInputs'
]
class
LayerChoice
(
nn
.
Module
):
"""
Layer choice selects one of the ``candidates``, then apply it on inputs and return results.
Layer choice does not allow itself to be nested.
Parameters
----------
candidates : list of nn.Module or OrderedDict
A module list to be selected from.
label : str
Identifier of the layer choice.
Attributes
----------
length : int
Deprecated. Number of ops to choose from. ``len(layer_choice)`` is recommended.
names : list of str
Names of candidates.
choices : list of Module
Deprecated. A list of all candidate modules in the layer choice module.
``list(layer_choice)`` is recommended, which will serve the same purpose.
Notes
-----
``candidates`` can be a list of modules or a ordered dict of named modules, for example,
.. code-block:: python
self.op_choice = LayerChoice(OrderedDict([
("conv3x3", nn.Conv2d(3, 16, 128)),
("conv5x5", nn.Conv2d(5, 16, 128)),
("conv7x7", nn.Conv2d(7, 16, 128))
]))
Elements in layer choice can be modified or deleted. Use ``del self.op_choice["conv5x5"]`` or
``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
"""
def
__init__
(
self
,
candidates
:
Union
[
Dict
[
str
,
nn
.
Module
],
List
[
nn
.
Module
]],
label
:
str
=
None
,
**
kwargs
):
super
(
LayerChoice
,
self
).
__init__
()
if
'key'
in
kwargs
:
warnings
.
warn
(
f
'"key" is deprecated. Assuming label.'
)
label
=
kwargs
[
'key'
]
if
'return_mask'
in
kwargs
:
warnings
.
warn
(
f
'"return_mask" is deprecated. Ignoring...'
)
if
'reduction'
in
kwargs
:
warnings
.
warn
(
f
'"reduction" is deprecated. Ignoring...'
)
self
.
candidates
=
candidates
self
.
_label
=
label
if
label
is
not
None
else
f
'layerchoice_
{
uid
()
}
'
self
.
names
=
[]
if
isinstance
(
candidates
,
OrderedDict
):
for
name
,
module
in
candidates
.
items
():
assert
name
not
in
[
"length"
,
"reduction"
,
"return_mask"
,
"_key"
,
"key"
,
"names"
],
\
"Please don't use a reserved name '{}' for your module."
.
format
(
name
)
self
.
add_module
(
name
,
module
)
self
.
names
.
append
(
name
)
elif
isinstance
(
candidates
,
list
):
for
i
,
module
in
enumerate
(
candidates
):
self
.
add_module
(
str
(
i
),
module
)
self
.
names
.
append
(
str
(
i
))
else
:
raise
TypeError
(
"Unsupported candidates type: {}"
.
format
(
type
(
candidates
)))
@
property
def
key
(
self
):
return
self
.
_key
()
@
torch
.
jit
.
ignore
def
_key
(
self
):
warnings
.
warn
(
'Using key to access the identifier of LayerChoice is deprecated. Please use label instead.'
,
category
=
DeprecationWarning
)
return
self
.
_label
@
property
def
label
(
self
):
return
self
.
_label
def
__getitem__
(
self
,
idx
):
if
isinstance
(
idx
,
str
):
return
self
.
_modules
[
idx
]
return
list
(
self
)[
idx
]
def
__setitem__
(
self
,
idx
,
module
):
key
=
idx
if
isinstance
(
idx
,
str
)
else
self
.
names
[
idx
]
return
setattr
(
self
,
key
,
module
)
def
__delitem__
(
self
,
idx
):
if
isinstance
(
idx
,
slice
):
for
key
in
self
.
names
[
idx
]:
delattr
(
self
,
key
)
else
:
if
isinstance
(
idx
,
str
):
key
,
idx
=
idx
,
self
.
names
.
index
(
idx
)
else
:
key
=
self
.
names
[
idx
]
delattr
(
self
,
key
)
del
self
.
names
[
idx
]
def
__len__
(
self
):
return
len
(
self
.
names
)
def
__iter__
(
self
):
return
map
(
lambda
name
:
self
.
_modules
[
name
],
self
.
names
)
@
property
def
choices
(
self
):
return
self
.
_choices
()
@
torch
.
jit
.
ignore
def
_choices
(
self
):
warnings
.
warn
(
"layer_choice.choices is deprecated. Use `list(layer_choice)` instead."
,
category
=
DeprecationWarning
)
return
list
(
self
)
def
forward
(
self
,
x
):
warnings
.
warn
(
'You should not run forward of this module directly.'
)
return
x
class
InputChoice
(
nn
.
Module
):
"""
Input choice selects ``n_chosen`` inputs from ``choose_from`` (contains ``n_candidates`` keys).
Use ``reduction`` to specify how chosen inputs are reduced into one output. A few options are:
* ``none``: do nothing and return the list directly.
* ``sum``: summing all the chosen inputs.
* ``mean``: taking the average of all chosen inputs.
* ``concat``: concatenate all chosen inputs at dimension 1.
We don't support customizing reduction yet.
Parameters
----------
n_candidates : int
Number of inputs to choose from. It is required.
n_chosen : int
Recommended inputs to choose. If None, mutator is instructed to select any.
reduction : str
``mean``, ``concat``, ``sum`` or ``none``.
label : str
Identifier of the input choice.
"""
def
__init__
(
self
,
n_candidates
:
int
,
n_chosen
:
int
=
1
,
reduction
:
str
=
'sum'
,
label
:
str
=
None
,
**
kwargs
):
super
(
InputChoice
,
self
).
__init__
()
if
'key'
in
kwargs
:
warnings
.
warn
(
f
'"key" is deprecated. Assuming label.'
)
label
=
kwargs
[
'key'
]
if
'return_mask'
in
kwargs
:
warnings
.
warn
(
f
'"return_mask" is deprecated. Ignoring...'
)
if
'choose_from'
in
kwargs
:
warnings
.
warn
(
f
'"reduction" is deprecated. Ignoring...'
)
self
.
n_candidates
=
n_candidates
self
.
n_chosen
=
n_chosen
self
.
reduction
=
reduction
assert
self
.
reduction
in
[
'mean'
,
'concat'
,
'sum'
,
'none'
]
self
.
_label
=
label
if
label
is
not
None
else
f
'inputchoice_
{
uid
()
}
'
@
property
def
key
(
self
):
return
self
.
_key
()
@
torch
.
jit
.
ignore
def
_key
(
self
):
warnings
.
warn
(
'Using key to access the identifier of InputChoice is deprecated. Please use label instead.'
,
category
=
DeprecationWarning
)
return
self
.
_label
@
property
def
label
(
self
):
return
self
.
_label
def
forward
(
self
,
candidate_inputs
:
List
[
torch
.
Tensor
])
->
torch
.
Tensor
:
warnings
.
warn
(
'You should not run forward of this module directly.'
)
return
candidate_inputs
[
0
]
class
ValueChoice
(
nn
.
Module
):
"""
ValueChoice is to choose one from ``candidates``.
Should initialize the values to choose from in init and call the module in forward to get the chosen value.
A common use is to pass a mutable value to a functional API like ``torch.xxx`` or ``nn.functional.xxx```. For example,
.. code-block:: python
class Net(nn.Module):
def __init__(self):
super().__init__()
self.dropout_rate = nn.ValueChoice([0., 1.])
def forward(self, x):
return F.dropout(x, self.dropout_rate())
The following use case is currently not supported because ValueChoice cannot be called in ``__init__``.
Please use LayerChoice as a workaround.
.. code-block:: python
# in __init__ code
self.kernel_size = nn.ValueChoice([3, 5])
self.conv = nn.Conv2d(3, self.out_channels, kernel_size=self.kernel_size())
Parameters
----------
candidates : list
List of values to choose from.
label : str
Identifier of the value choice.
"""
def
__init__
(
self
,
candidates
:
List
[
Any
],
label
:
str
=
None
):
super
().
__init__
()
self
.
candidates
=
candidates
self
.
_label
=
label
if
label
is
not
None
else
f
'valuechoice_
{
uid
()
}
'
@
property
def
label
(
self
):
return
self
.
_label
def
forward
(
self
):
warnings
.
warn
(
'You should not run forward of this module directly.'
)
return
self
.
candidates
[
0
]
class
Placeholder
(
nn
.
Module
):
# TODO: docstring
def
__init__
(
self
,
label
,
related_info
):
add_record
(
id
(
self
),
related_info
)
self
.
label
=
label
self
.
related_info
=
related_info
super
(
Placeholder
,
self
).
__init__
()
def
forward
(
self
,
x
):
return
x
def
__del__
(
self
):
del_record
(
id
(
self
))
class
ChosenInputs
(
nn
.
Module
):
"""
A module that chooses from a tensor list and outputs a reduced tensor.
The already-chosen version of InputChoice.
"""
def
__init__
(
self
,
chosen
:
List
[
int
],
reduction
:
str
):
super
().
__init__
()
self
.
chosen
=
chosen
self
.
reduction
=
reduction
def
forward
(
self
,
candidate_inputs
):
return
self
.
_tensor_reduction
(
self
.
reduction
,
[
candidate_inputs
[
i
]
for
i
in
self
.
chosen
])
def
_tensor_reduction
(
self
,
reduction_type
,
tensor_list
):
if
reduction_type
==
'none'
:
return
tensor_list
if
not
tensor_list
:
return
None
# empty. return None for now
if
len
(
tensor_list
)
==
1
:
return
tensor_list
[
0
]
if
reduction_type
==
'sum'
:
return
sum
(
tensor_list
)
if
reduction_type
==
'mean'
:
return
sum
(
tensor_list
)
/
len
(
tensor_list
)
if
reduction_type
==
'concat'
:
return
torch
.
cat
(
tensor_list
,
dim
=
1
)
raise
ValueError
(
f
'Unrecognized reduction policy: "
{
reduction_type
}
"'
)
nni/retiarii/nn/pytorch/mutator.py
0 → 100644
View file @
8c8db374
from
typing
import
Any
,
List
,
Optional
from
...mutator
import
Mutator
from
...graph
import
Model
,
Node
class
LayerChoiceMutator
(
Mutator
):
def
__init__
(
self
,
nodes
:
List
[
Node
]):
super
().
__init__
()
self
.
nodes
=
nodes
def
mutate
(
self
,
model
):
n_candidates
=
len
(
self
.
nodes
[
0
].
operation
.
parameters
[
'candidates'
])
indices
=
list
(
range
(
n_candidates
))
chosen_index
=
self
.
choice
(
indices
)
for
node
in
self
.
nodes
:
target
=
model
.
get_node_by_name
(
node
.
name
)
chosen_cand
=
node
.
operation
.
parameters
[
'candidates'
][
chosen_index
]
target
.
update_operation
(
chosen_cand
[
'type'
],
chosen_cand
[
'parameters'
])
class
InputChoiceMutator
(
Mutator
):
def
__init__
(
self
,
nodes
:
List
[
Node
]):
super
().
__init__
()
self
.
nodes
=
nodes
def
mutate
(
self
,
model
):
n_candidates
=
self
.
nodes
[
0
].
operation
.
parameters
[
'n_candidates'
]
n_chosen
=
self
.
nodes
[
0
].
operation
.
parameters
[
'n_chosen'
]
candidates
=
list
(
range
(
n_candidates
))
chosen
=
[
self
.
choice
(
candidates
)
for
_
in
range
(
n_chosen
)]
for
node
in
self
.
nodes
:
target
=
model
.
get_node_by_name
(
node
.
name
)
target
.
update_operation
(
'__torch__.nni.retiarii.nn.pytorch.ChosenInputs'
,
{
'chosen'
:
chosen
,
'reduction'
:
node
.
operation
.
parameters
[
'reduction'
]})
class
ValueChoiceMutator
(
Mutator
):
def
__init__
(
self
,
nodes
:
List
[
Node
],
candidates
:
List
[
Any
]):
super
().
__init__
()
self
.
nodes
=
nodes
self
.
candidates
=
candidates
def
mutate
(
self
,
model
):
chosen
=
self
.
choice
(
self
.
candidates
)
for
node
in
self
.
nodes
:
target
=
model
.
get_node_by_name
(
node
.
name
)
target
.
update_operation
(
'prim::Constant'
,
{
'value'
:
chosen
})
def
process_inline_mutation
(
model
:
Model
)
->
Optional
[
List
[
Mutator
]]:
applied_mutators
=
[]
lc_nodes
=
_group_by_label
(
model
.
get_nodes_by_type
(
'__torch__.nni.retiarii.nn.pytorch.api.LayerChoice'
))
for
node_list
in
lc_nodes
:
assert
_is_all_equal
(
map
(
lambda
node
:
len
(
node
.
operation
.
parameters
[
'candidates'
]),
node_list
)),
\
'Layer choice with the same label must have the same number of candidates.'
mutator
=
LayerChoiceMutator
(
node_list
)
applied_mutators
.
append
(
mutator
)
ic_nodes
=
_group_by_label
(
model
.
get_nodes_by_type
(
'__torch__.nni.retiarii.nn.pytorch.api.InputChoice'
))
for
node_list
in
ic_nodes
:
assert
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'n_candidates'
],
node_list
))
and
\
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'n_chosen'
],
node_list
)),
\
'Input choice with the same label must have the same number of candidates.'
mutator
=
InputChoiceMutator
(
node_list
)
applied_mutators
.
append
(
mutator
)
vc_nodes
=
_group_by_label
(
model
.
get_nodes_by_type
(
'__torch__.nni.retiarii.nn.pytorch.api.ValueChoice'
))
for
node_list
in
vc_nodes
:
assert
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'candidates'
],
node_list
)),
\
'Value choice with the same label must have the same candidates.'
mutator
=
ValueChoiceMutator
(
node_list
,
node_list
[
0
].
operation
.
parameters
[
'candidates'
])
applied_mutators
.
append
(
mutator
)
if
applied_mutators
:
return
applied_mutators
return
None
def
_is_all_equal
(
lst
):
last
=
None
for
x
in
lst
:
if
last
is
not
None
and
last
!=
x
:
return
False
last
=
x
return
True
def
_group_by_label
(
nodes
:
List
[
Node
])
->
List
[
List
[
Node
]]:
result
=
{}
for
node
in
nodes
:
label
=
node
.
operation
.
parameters
[
'label'
]
if
label
not
in
result
:
result
[
label
]
=
[]
result
[
label
].
append
(
node
)
return
list
(
result
.
values
())
nni/retiarii/nn/pytorch/nn.py
View file @
8c8db374
import
logging
from
typing
import
Any
,
List
import
torch
import
torch.nn
as
nn
from
...utils
import
add_record
,
blackbox_module
,
del_record
,
uid
,
version_larger_equal
_logger
=
logging
.
getLogger
(
__name__
)
from
...utils
import
add_record
,
blackbox_module
,
del_record
,
version_larger_equal
# NOTE: support pytorch version >= 1.5.0
__all__
=
[
'LayerChoice'
,
'InputChoice'
,
'Placeholder'
,
'Module'
,
'Sequential'
,
'ModuleList'
,
# TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
'Identity'
,
'Linear'
,
'Conv1d'
,
'Conv2d'
,
'Conv3d'
,
'ConvTranspose1d'
,
'ConvTranspose2d'
,
'ConvTranspose3d'
,
'Threshold'
,
'ReLU'
,
'Hardtanh'
,
'ReLU6'
,
...
...
@@ -40,94 +34,6 @@ if version_larger_equal(torch.__version__, '1.7.0'):
__all__
.
extend
([
'Unflatten'
,
'SiLU'
,
'TripletMarginWithDistanceLoss'
])
class
LayerChoice
(
nn
.
Module
):
def
__init__
(
self
,
op_candidates
,
reduction
=
None
,
return_mask
=
False
,
key
=
None
):
super
(
LayerChoice
,
self
).
__init__
()
self
.
op_candidates
=
op_candidates
self
.
label
=
key
if
key
is
not
None
else
f
'layerchoice_
{
uid
()
}
'
self
.
key
=
self
.
label
# deprecated, for backward compatibility
for
i
,
module
in
enumerate
(
op_candidates
):
# deprecated, for backward compatibility
self
.
add_module
(
str
(
i
),
module
)
if
reduction
or
return_mask
:
_logger
.
warning
(
'input arguments `reduction` and `return_mask` are deprecated!'
)
def
forward
(
self
,
x
):
return
x
class
InputChoice
(
nn
.
Module
):
def
__init__
(
self
,
n_candidates
=
None
,
choose_from
=
None
,
n_chosen
=
1
,
reduction
=
"sum"
,
return_mask
=
False
,
key
=
None
):
super
(
InputChoice
,
self
).
__init__
()
self
.
n_candidates
=
n_candidates
self
.
n_chosen
=
n_chosen
self
.
reduction
=
reduction
self
.
label
=
key
if
key
is
not
None
else
f
'inputchoice_
{
uid
()
}
'
self
.
key
=
self
.
label
# deprecated, for backward compatibility
if
choose_from
or
return_mask
:
_logger
.
warning
(
'input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!'
)
def
forward
(
self
,
candidate_inputs
:
List
[
torch
.
Tensor
])
->
torch
.
Tensor
:
# fake return
return
torch
.
tensor
(
candidate_inputs
)
# pylint: disable=not-callable
class
ValueChoice
:
"""
The instance of this class can only be used as input argument,
when instantiating a pytorch module.
TODO: can also be used in training approach
"""
def
__init__
(
self
,
candidate_values
:
List
[
Any
]):
self
.
candidate_values
=
candidate_values
class
Placeholder
(
nn
.
Module
):
def
__init__
(
self
,
label
,
related_info
):
add_record
(
id
(
self
),
related_info
)
self
.
label
=
label
self
.
related_info
=
related_info
super
(
Placeholder
,
self
).
__init__
()
def
forward
(
self
,
x
):
return
x
def
__del__
(
self
):
del_record
(
id
(
self
))
class
ChosenInputs
(
nn
.
Module
):
"""
"""
def
__init__
(
self
,
chosen
:
List
[
int
],
reduction
:
str
):
super
().
__init__
()
self
.
chosen
=
chosen
self
.
reduction
=
reduction
def
forward
(
self
,
candidate_inputs
):
return
self
.
_tensor_reduction
(
self
.
reduction
,
[
candidate_inputs
[
i
]
for
i
in
self
.
chosen
])
def
_tensor_reduction
(
self
,
reduction_type
,
tensor_list
):
if
reduction_type
==
"none"
:
return
tensor_list
if
not
tensor_list
:
return
None
# empty. return None for now
if
len
(
tensor_list
)
==
1
:
return
tensor_list
[
0
]
if
reduction_type
==
"sum"
:
return
sum
(
tensor_list
)
if
reduction_type
==
"mean"
:
return
sum
(
tensor_list
)
/
len
(
tensor_list
)
if
reduction_type
==
"concat"
:
return
torch
.
cat
(
tensor_list
,
dim
=
1
)
raise
ValueError
(
"Unrecognized reduction policy:
\"
{}
\"
"
.
format
(
reduction_type
))
# the following are pytorch modules
Module
=
nn
.
Module
...
...
nni/retiarii/trainer/__init__.py
View file @
8c8db374
from
.interface
import
BaseTrainer
,
BaseOneShotTrainer
from
.pytorch
import
PyTorchImageClassificationTrainer
,
PyTorchMultiModelTrainer
test/retiarii_test/darts/test.py
View file @
8c8db374
...
...
@@ -6,7 +6,7 @@ from pathlib import Path
from
nni.retiarii.experiment
import
RetiariiExperiment
,
RetiariiExeConfig
from
nni.retiarii.strategies
import
TPEStrategy
,
RandomStrategy
from
nni.retiarii.trainer
import
PyTorchImageClassificationTrainer
from
nni.retiarii.trainer
.pytorch
import
PyTorchImageClassificationTrainer
from
darts_model
import
CNN
...
...
test/retiarii_test/mnasnet/test.py
View file @
8c8db374
...
...
@@ -3,7 +3,7 @@ import sys
import
torch
from
pathlib
import
Path
from
nni.retiarii.trainer
import
PyTorchImageClassificationTrainer
from
nni.retiarii.trainer
.pytorch
import
PyTorchImageClassificationTrainer
from
base_mnasnet
import
MNASNet
from
nni.retiarii.experiment
import
RetiariiExperiment
,
RetiariiExeConfig
...
...
test/retiarii_test/mnist/test.py
View file @
8c8db374
...
...
@@ -4,7 +4,7 @@ import nni.retiarii.nn.pytorch as nn
import
torch.nn.functional
as
F
from
nni.retiarii.experiment
import
RetiariiExeConfig
,
RetiariiExperiment
from
nni.retiarii.strategies
import
RandomStrategy
from
nni.retiarii.trainer
import
PyTorchImageClassificationTrainer
from
nni.retiarii.trainer
.pytorch
import
PyTorchImageClassificationTrainer
class
Net
(
nn
.
Module
):
...
...
test/ut/retiarii/test_cgo_engine.py
View file @
8c8db374
...
...
@@ -18,7 +18,7 @@ from nni.retiarii import Model, Node
from
nni.retiarii
import
Model
,
submit_models
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.integration
import
RetiariiAdvisor
from
nni.retiarii.trainer
import
PyTorchImageClassificationTrainer
,
PyTorchMultiModelTrainer
from
nni.retiarii.trainer
.pytorch
import
PyTorchImageClassificationTrainer
,
PyTorchMultiModelTrainer
from
nni.retiarii.utils
import
import_
...
...
test/ut/retiarii/test_dedup_input.py
View file @
8c8db374
...
...
@@ -17,7 +17,7 @@ from nni.retiarii import Model, Node
from
nni.retiarii
import
Model
,
submit_models
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.integration
import
RetiariiAdvisor
from
nni.retiarii.trainer
import
PyTorchImageClassificationTrainer
,
PyTorchMultiModelTrainer
from
nni.retiarii.trainer
.pytorch
import
PyTorchImageClassificationTrainer
,
PyTorchMultiModelTrainer
from
nni.retiarii.utils
import
import_
...
...
test/ut/retiarii/test_engine.py
View file @
8c8db374
...
...
@@ -9,7 +9,7 @@ import nni
from
nni.retiarii
import
Model
,
submit_models
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.integration
import
RetiariiAdvisor
,
register_advisor
from
nni.retiarii.trainer
import
PyTorchImageClassificationTrainer
from
nni.retiarii.trainer
.pytorch
import
PyTorchImageClassificationTrainer
from
nni.retiarii.utils
import
import_
...
...
test/ut/retiarii/test_highlevel_apis.py
0 → 100644
View file @
8c8db374
import
random
import
unittest
import
nni.retiarii.nn.pytorch
as
nn
import
torch
import
torch.nn.functional
as
F
from
nni.retiarii
import
Sampler
,
blackbox_module
from
nni.retiarii.converter
import
convert_to_graph
from
nni.retiarii.codegen
import
model_to_pytorch_script
from
nni.retiarii.nn.pytorch.mutator
import
process_inline_mutation
class
EnuemrateSampler
(
Sampler
):
def
__init__
(
self
):
self
.
index
=
0
def
choice
(
self
,
candidates
,
*
args
,
**
kwargs
):
choice
=
candidates
[
self
.
index
%
len
(
candidates
)]
self
.
index
+=
1
return
choice
class
RandomSampler
(
Sampler
):
def
__init__
(
self
):
self
.
counter
=
0
def
choice
(
self
,
candidates
,
*
args
,
**
kwargs
):
self
.
counter
+=
1
return
random
.
choice
(
candidates
)
@
blackbox_module
class
MutableConv
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
3
,
3
,
kernel_size
=
1
)
self
.
conv2
=
nn
.
Conv2d
(
3
,
5
,
kernel_size
=
1
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
index
:
int
):
if
index
==
0
:
return
self
.
conv1
(
x
)
else
:
return
self
.
conv2
(
x
)
class
TestHighLevelAPI
(
unittest
.
TestCase
):
def
_convert_to_ir
(
self
,
model
):
script_module
=
torch
.
jit
.
script
(
model
)
return
convert_to_graph
(
script_module
,
model
)
def
_get_converted_pytorch_model
(
self
,
model_ir
):
model_code
=
model_to_pytorch_script
(
model_ir
)
exec_vars
=
{}
exec
(
model_code
+
'
\n\n
converted_model = _model()'
,
exec_vars
)
return
exec_vars
[
'converted_model'
]
def
test_layer_choice
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
module
=
nn
.
LayerChoice
([
nn
.
Conv2d
(
3
,
3
,
kernel_size
=
1
),
nn
.
Conv2d
(
3
,
5
,
kernel_size
=
1
)
])
def
forward
(
self
,
x
):
return
self
.
module
(
x
)
model
=
self
.
_convert_to_ir
(
Net
())
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
1
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnuemrateSampler
())
model1
=
mutator
.
apply
(
model
)
model2
=
mutator
.
apply
(
model
)
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
5
,
3
,
3
]))
def
test_input_choice
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
3
,
3
,
kernel_size
=
1
)
self
.
conv2
=
nn
.
Conv2d
(
3
,
5
,
kernel_size
=
1
)
self
.
input
=
nn
.
InputChoice
(
2
)
def
forward
(
self
,
x
):
x1
=
self
.
conv1
(
x
)
x2
=
self
.
conv2
(
x
)
return
self
.
input
([
x1
,
x2
])
model
=
self
.
_convert_to_ir
(
Net
())
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
1
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnuemrateSampler
())
model1
=
mutator
.
apply
(
model
)
model2
=
mutator
.
apply
(
model
)
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
5
,
3
,
3
]))
def
test_chosen_inputs
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
,
reduction
):
super
().
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
3
,
3
,
kernel_size
=
1
)
self
.
conv2
=
nn
.
Conv2d
(
3
,
3
,
kernel_size
=
1
)
self
.
input
=
nn
.
InputChoice
(
2
,
n_chosen
=
2
,
reduction
=
reduction
)
def
forward
(
self
,
x
):
x1
=
self
.
conv1
(
x
)
x2
=
self
.
conv2
(
x
)
return
self
.
input
([
x1
,
x2
])
for
reduction
in
[
'none'
,
'sum'
,
'mean'
,
'concat'
]:
model
=
self
.
_convert_to_ir
(
Net
(
reduction
))
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
1
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnuemrateSampler
())
model
=
mutator
.
apply
(
model
)
result
=
self
.
_get_converted_pytorch_model
(
model
)(
torch
.
randn
(
1
,
3
,
3
,
3
))
if
reduction
==
'none'
:
self
.
assertEqual
(
len
(
result
),
2
)
self
.
assertEqual
(
result
[
0
].
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertEqual
(
result
[
1
].
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
elif
reduction
==
'concat'
:
self
.
assertEqual
(
result
.
size
(),
torch
.
Size
([
1
,
6
,
3
,
3
]))
else
:
self
.
assertEqual
(
result
.
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
def
test_value_choice
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
index
=
nn
.
ValueChoice
([
0
,
1
])
self
.
conv
=
MutableConv
()
def
forward
(
self
,
x
):
return
self
.
conv
(
x
,
self
.
index
())
model
=
self
.
_convert_to_ir
(
Net
())
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
1
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnuemrateSampler
())
model1
=
mutator
.
apply
(
model
)
model2
=
mutator
.
apply
(
model
)
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
5
,
3
,
3
]))
def
test_value_choice_in_functional
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
dropout_rate
=
nn
.
ValueChoice
([
0.
,
1.
])
def
forward
(
self
,
x
):
return
F
.
dropout
(
x
,
self
.
dropout_rate
())
model
=
self
.
_convert_to_ir
(
Net
())
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
1
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnuemrateSampler
())
model1
=
mutator
.
apply
(
model
)
model2
=
mutator
.
apply
(
model
)
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertAlmostEqual
(
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
abs
().
sum
().
item
(),
0
)
def
test_shared
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
,
shared
=
True
):
super
().
__init__
()
labels
=
[
'x'
,
'x'
]
if
shared
else
[
None
,
None
]
self
.
module1
=
nn
.
LayerChoice
([
nn
.
Conv2d
(
3
,
3
,
kernel_size
=
1
),
nn
.
Conv2d
(
3
,
5
,
kernel_size
=
1
)
],
label
=
labels
[
0
])
self
.
module2
=
nn
.
LayerChoice
([
nn
.
Conv2d
(
3
,
3
,
kernel_size
=
1
),
nn
.
Conv2d
(
3
,
5
,
kernel_size
=
1
)
],
label
=
labels
[
1
])
def
forward
(
self
,
x
):
return
self
.
module1
(
x
)
+
self
.
module2
(
x
)
model
=
self
.
_convert_to_ir
(
Net
())
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
1
)
sampler
=
RandomSampler
()
mutator
=
mutators
[
0
].
bind_sampler
(
sampler
)
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
mutator
.
apply
(
model
))(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(
0
),
1
)
self
.
assertEqual
(
sampler
.
counter
,
1
)
model
=
self
.
_convert_to_ir
(
Net
(
shared
=
False
))
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
2
)
sampler
=
RandomSampler
()
# repeat test. Expectation: sometimes succeeds, sometimes fails.
failed_count
=
0
for
i
in
range
(
30
):
for
mutator
in
mutators
:
model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model
)
self
.
assertEqual
(
sampler
.
counter
,
2
*
(
i
+
1
))
try
:
self
.
_get_converted_pytorch_model
(
model
)(
torch
.
randn
(
1
,
3
,
3
,
3
))
except
RuntimeError
:
failed_count
+=
1
self
.
assertGreater
(
failed_count
,
0
)
self
.
assertLess
(
failed_count
,
30
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment