OpenDAS / nni · Commit ae50ed14 (Unverified)

Refactor wrap module as "blackbox_module" (#3238)

Authored Dec 31, 2020 by Yuge Zhang; committed by GitHub on Dec 31, 2020.
Parent: 15da19d3

Showing 15 changed files with 766 additions and 743 deletions (+766 -743).
nni/retiarii/__init__.py                     +1    -1
nni/retiarii/codegen/pytorch.py              +2    -2
nni/retiarii/converter/graph_gen.py          +459  -479
nni/retiarii/experiment.py                   +10   -6
nni/retiarii/nn/pytorch/nn.py                +117  -149
nni/retiarii/trainer/pytorch/base.py         +2    -2
nni/retiarii/utils.py                        +66   -49
test/.gitignore                              +1    -0
test/retiarii_test/darts/darts_model.py      +2    -5
test/retiarii_test/darts/ops.py              +15   -15
test/retiarii_test/darts/test.py             +5    -5
test/retiarii_test/darts/test_oneshot.py     +2    -2
test/retiarii_test/mnasnet/base_mnasnet.py   +28   -22
test/retiarii_test/mnasnet/test.py           +6    -6
test/retiarii_test/mnist/test.py             +50   -0
nni/retiarii/__init__.py

@@ -2,4 +2,4 @@ from .operation import Operation
 from .graph import *
 from .execution import *
 from .mutator import *
-from .utils import register_module
\ No newline at end of file
+from .utils import blackbox, blackbox_module, register_trainer
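The import change above defines the package's new public surface: `register_module` is gone, replaced by the three helpers introduced in nni/retiarii/utils.py further down. After this commit a user imports them as:

    from nni.retiarii import blackbox, blackbox_module, register_trainer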
nni/retiarii/codegen/pytorch.py

@@ -19,10 +19,10 @@ def model_to_pytorch_script(model: Model, placement=None) -> str:
 def _sorted_incoming_edges(node: Node) -> List[Edge]:
     edges = [edge for edge in node.graph.edges if edge.tail is node]
-    _logger.info('sorted_incoming_edges: %s', str(edges))
+    _logger.debug('sorted_incoming_edges: %s', str(edges))
     if not edges:
         return []
-    _logger.info('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
+    _logger.debug('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
     if all(edge.tail_slot is None for edge in edges):
         return edges
     if all(isinstance(edge.tail_slot, int) for edge in edges):
nni/retiarii/converter/graph_gen.py

@@ -6,17 +6,20 @@ import torch
 from ..graph import Graph, Model, Node
 from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
 from ..operation import Cell
+from ..utils import get_records
 from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
 from .utils import _convert_name, build_full_name

 _logger = logging.getLogger(__name__)

-global_seq = 0
-global_graph_id = 0
-modules_arg = None
+class GraphConverter:
+    def __init__(self):
+        self.global_seq = 0
+        self.global_graph_id = 0
+        self.modules_arg = get_records()

-def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
+    def _add_edge(self, ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
         """
         Parameters
         ----------

@@ -76,29 +79,24 @@ def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap,
             new_node_input_idx += 1

-def create_prim_constant_node(ir_graph, node, module_name):
-    global global_seq
+    def create_prim_constant_node(self, ir_graph, node, module_name):
         attrs = {}
         if node.outputsAt(0).toIValue() is not None:
             attrs = {'value': node.outputsAt(0).toIValue()}
-    global_seq += 1
-    new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, global_seq),
+        self.global_seq += 1
+        new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, self.global_seq),
                                      node.kind(), attrs)
         return new_node

-def handle_prim_attr_node(node):
+    def handle_prim_attr_node(self, node):
         assert node.hasAttribute('name')
         attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
         return node.kind(), attrs

-def _remove_mangle(module_type_str):
+    def _remove_mangle(self, module_type_str):
         return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)

-def remove_unconnected_nodes(ir_graph, targeted_type=None):
+    def remove_unconnected_nodes(self, ir_graph, targeted_type=None):
         """
         Parameters
         ----------

@@ -126,8 +124,7 @@ def remove_unconnected_nodes(ir_graph, targeted_type=None):
         for hidden_node in to_removes:
             hidden_node.remove()

-def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
+    def handle_graph_nodes(self, script_module, sm_graph, module, module_name, ir_model, ir_graph):
         """
         Convert torch script node to our node ir, and build our graph ir

@@ -234,13 +231,12 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
             Node
                 the created node ir
             """
-            global global_seq
             if node.kind() == 'prim::CallMethod':
                 # get and handle the first input, which should be an nn.Module
                 assert node.hasAttribute('name')
                 if node.s('name') == 'forward':
                     # node.inputsAt(0).type() is <class 'torch._C.ClassType'>
-                    submodule_type_str = _remove_mangle(node.inputsAt(0).type().str())
+                    submodule_type_str = self._remove_mangle(node.inputsAt(0).type().str())
                     submodule = node.inputsAt(0).node()
                     assert submodule.kind() == 'prim::GetAttr'
                     assert submodule.hasAttribute('name')

@@ -258,7 +254,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
                         submodule_full_name = build_full_name(module_name, submodule_name)
                         submodule_obj = getattr(module, submodule_name)
-                        subgraph, sub_m_attrs = convert_module(script_module._modules[submodule_name],
+                        subgraph, sub_m_attrs = self.convert_module(script_module._modules[submodule_name],
                                                                submodule_obj, submodule_full_name, ir_model)
                     else:

@@ -276,7 +272,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
                             submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name])
                             predecessor_obj = getattr(module, predecessor_name)
                             submodule_obj = getattr(predecessor_obj, submodule_name)
-                            subgraph, sub_m_attrs = convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
+                            subgraph, sub_m_attrs = self.convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
                                                                    submodule_obj, submodule_full_name, ir_model)
                         else:
                             raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))

@@ -296,45 +292,45 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
                     subcell = ir_graph.add_node(submodule_full_name, new_cell)
                     node_index[node] = subcell
                     # connect the cell into graph
-                    _add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
+                    self._add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
                 else:
                     raise RuntimeError('unsupported CallMethod {}'.format(node.s('name')))
             elif node.kind() == 'prim::CallFunction':
-                func_type_str = _remove_mangle(node.inputsAt(0).type().str())
+                func_type_str = self._remove_mangle(node.inputsAt(0).type().str())
                 func = node.inputsAt(0).node()
                 assert func.kind() == 'prim::Constant'
                 assert func.hasAttribute('name')
                 func_name = func.s('name')
                 # create node for func
-                global_seq += 1
-                func_node = ir_graph.add_node(build_full_name(module_name, func_name, global_seq),
+                self.global_seq += 1
+                func_node = ir_graph.add_node(build_full_name(module_name, func_name, self.global_seq),
                                               '{}.{}'.format(func_type_str, func_name))
                 node_index[node] = func_node
-                _add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
+                self._add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
             elif node.kind() == 'prim::Constant':
-                new_node = create_prim_constant_node(ir_graph, node, module_name)
+                new_node = self.create_prim_constant_node(ir_graph, node, module_name)
                 node_index[node] = new_node
             elif node.kind() == 'prim::ListConstruct':
-                global_seq += 1
-                new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, global_seq), node.kind())
+                self.global_seq += 1
+                new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, self.global_seq), node.kind())
                 node_index[node] = new_node
-                _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
+                self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
             elif node.kind() == 'aten::append':
-                global_seq += 1
-                aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
+                self.global_seq += 1
+                aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
                 node_index[node] = aten_node
-                _add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
+                self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
                 output_remap[node.inputsAt(0)] = node
             elif node.kind().startswith('aten::'):
                 # handle aten::XXX
-                global_seq += 1
-                aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
+                self.global_seq += 1
+                aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
                 node_index[node] = aten_node
-                _add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
+                self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
             elif node.kind() == 'prim::GetAttr':
-                node_type, attrs = handle_prim_attr_node(node)
-                global_seq += 1
-                new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, global_seq),
+                node_type, attrs = self.handle_prim_attr_node(node)
+                self.global_seq += 1
+                new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, self.global_seq),
                                              node_type, attrs)
                 node_index[node] = new_node
             elif node.kind() == 'prim::If':

@@ -354,8 +350,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
         return node_index

-def merge_aten_slices(ir_graph):
+    def merge_aten_slices(self, ir_graph):
         """
         if there is aten::slice node, merge the consecutive ones together.
         ```x[:, :, 1:, 1:]``` in python code will be converted into 4 node in torch script,

@@ -401,36 +396,31 @@ def merge_aten_slices(ir_graph):
                 edge.head = new_slice_node
             ir_graph.hidden_nodes.remove(node)

-def refine_graph(ir_graph):
+    def refine_graph(self, ir_graph):
         """
         Do the following process to simplify graph:
         1. remove unconnected constant node
         2. remove unconnected getattr node
         """
         # some constant is not used, for example, function name as prim::Constant
-    remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
-    remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
-    merge_aten_slices(ir_graph)
+        self.remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
+        self.remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
+        self.merge_aten_slices(ir_graph)

-def _handle_layerchoice(module):
-    global modules_arg
+    def _handle_layerchoice(self, module):
         m_attrs = {}
-    candidates = module.candidate_ops
+        candidates = module.op_candidates
         choices = []
         for cand in candidates:
-        assert id(cand) in modules_arg, 'id not exist: {}'.format(id(cand))
-        assert isinstance(modules_arg[id(cand)], dict)
+            assert id(cand) in self.modules_arg, 'id not exist: {}'.format(id(cand))
+            assert isinstance(self.modules_arg[id(cand)], dict)
             cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
-        choices.append({'type': cand_type, 'parameters': modules_arg[id(cand)]})
+            choices.append({'type': cand_type, 'parameters': self.modules_arg[id(cand)]})
         m_attrs[f'choices'] = choices
         m_attrs['label'] = module.label
         return m_attrs

-def _handle_inputchoice(module):
+    def _handle_inputchoice(self, module):
         m_attrs = {}
         m_attrs['n_candidates'] = module.n_candidates
         m_attrs['n_chosen'] = module.n_chosen

@@ -438,8 +428,7 @@ def _handle_inputchoice(module):
         m_attrs['label'] = module.label
         return m_attrs

-def convert_module(script_module, module, module_name, ir_model):
+    def convert_module(self, script_module, module, module_name, ir_model):
         """
         Convert a module to its graph ir (i.e., Graph) along with its input arguments

@@ -461,34 +450,36 @@ def convert_module(script_module, module, module_name, ir_model):
         dict
             the input arguments of this module
         """
-    global global_graph_id
-    global modules_arg
         # NOTE: have not supported nested LayerChoice, i.e., a candidate module
         # also has LayerChoice or InputChoice or ValueChoice
         original_type_name = script_module.original_name
-    if original_type_name == OpTypeName.LayerChoice:
-        m_attrs = _handle_layerchoice(module)
-        return None, m_attrs
-    if original_type_name == OpTypeName.InputChoice:
-        m_attrs = _handle_inputchoice(module)
-        return None, m_attrs
-    if original_type_name == OpTypeName.Placeholder:
-        m_attrs = modules_arg[id(module)]
-        return None, m_attrs
-    if original_type_name in torch.nn.__dict__ and original_type_name not in MODULE_EXCEPT_LIST:
-        # this is a basic module from pytorch, no need to parse its graph
-        assert id(module) in modules_arg, f'{original_type_name} arguments are not recorded'
-        m_attrs = modules_arg[id(module)]
+        m_attrs = None
+        if original_type_name in MODULE_EXCEPT_LIST:
+            pass  # do nothing
+        elif original_type_name == OpTypeName.LayerChoice:
+            m_attrs = self._handle_layerchoice(module)
+        elif original_type_name == OpTypeName.InputChoice:
+            m_attrs = self._handle_inputchoice(module)
+        elif original_type_name == OpTypeName.Placeholder:
+            m_attrs = self.modules_arg[id(module)]
+        elif original_type_name in torch.nn.__dict__:
+            # this is a basic module from pytorch, no need to parse its graph
+            assert id(module) in self.modules_arg, f'{original_type_name} arguments are not recorded'
+            m_attrs = self.modules_arg[id(module)]
+        elif id(module) in self.modules_arg:
+            # this module is marked as blackbox, won't continue to parse
+            m_attrs = self.modules_arg[id(module)]
+        if m_attrs is not None:
+            return None, m_attrs

         # handle TorchScript graph
         sm_graph = script_module.graph
-    global_graph_id += 1
-    ir_graph = Graph(model=ir_model, graph_id=global_graph_id, name=module_name, _internal=True)
+        self.global_graph_id += 1
+        ir_graph = Graph(model=ir_model, graph_id=self.global_graph_id, name=module_name, _internal=True)

         # handle graph nodes
-    node_index = handle_graph_nodes(script_module, sm_graph, module,
+        node_index = self.handle_graph_nodes(script_module, sm_graph, module,
                                         module_name, ir_model, ir_graph)

         # handle graph outputs

@@ -502,22 +493,14 @@ def convert_module(script_module, module, module_name, ir_model):
             ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx), tail=(ir_graph.output_node, None))

-    refine_graph(ir_graph)
+        self.refine_graph(ir_graph)

         ir_graph._register()

-    if id(module) not in modules_arg:
-        raise RuntimeError(f'{original_type_name} arguments are not recorded, \
-you might have forgotten to decorate this class with @register_module()')
+        # TODO: if we parse this module, it means we will create a graph (module class)
+        # for this module. Then it is not necessary to record this module's arguments
+        # return ir_graph, modules_arg[id(module)].
+        # That is, we can refactor this part, to allow users to annotate which module
+        # should not be parsed further.
         return ir_graph, {}

-def convert_to_graph(script_module, module, recorded_modules_arg):
+def convert_to_graph(script_module, module):
     """
     Convert module to our graph ir, i.e., build a ```Model``` type

@@ -527,18 +510,15 @@ def convert_to_graph(script_module, module):
         the script module obtained with torch.jit.script
     module : nn.Module
         the targeted module instance
-    recorded_modules_arg : dict
-        the recorded args of each module in the module

     Returns
     -------
     Model
         the constructed IR model
     """
-    global modules_arg
-    modules_arg = recorded_modules_arg
     model = Model(_internal=True)
     module_name = '_model'
-    convert_module(script_module, module, module_name, model)
+    GraphConverter().convert_module(script_module, module, module_name, model)
     return model
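The bulk of this file's diff is one mechanical pattern: module-level state (global_seq, global_graph_id, modules_arg) becomes instance state on the new GraphConverter class, and every former free function becomes a method called through self. A minimal standalone sketch of why this matters — hypothetical names, not NNI code:

    # Before: any two conversions share (and keep incrementing) the same counter.
    global_seq = 0

    def add_node_old(name):
        global global_seq
        global_seq += 1
        return f'{name}_{global_seq}'

    # After: state lives on the instance, so GraphConverter().convert_module(...)
    # starts from a clean counter every time.
    class Converter:
        def __init__(self):
            self.global_seq = 0

        def add_node(self, name):
            self.global_seq += 1
            return f'{name}_{self.global_seq}'

    print(add_node_old('conv'), add_node_old('conv'))                  # conv_1 conv_2 (state leaks across calls)
    print(Converter().add_node('conv'), Converter().add_node('conv'))  # conv_1 conv_1 (state is isolated)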
nni/retiarii/experiment.py

@@ -29,6 +29,7 @@ _logger = logging.getLogger(__name__)
 OneShotTrainers = (DartsTrainer, EnasTrainer, ProxylessTrainer, RandomTrainer, SinglePathTrainer)


 @dataclass(init=False)
 class RetiariiExeConfig(ConfigBase):
     experiment_name: Optional[str] = None

@@ -125,14 +126,17 @@ class RetiariiExperiment(Experiment):
         except Exception as e:
             _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
             raise e
-        base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args)
-        assert id(self.trainer) in self.recorded_module_args
-        trainer_config = self.recorded_module_args[id(self.trainer)]
-        base_model.apply_trainer(trainer_config['modulename'], trainer_config['args'])
+        base_model_ir = convert_to_graph(script_module, self.base_model)
+        recorded_module_args = get_records()
+        if id(self.trainer) not in recorded_module_args:
+            raise KeyError('Your trainer is not found in registered classes. You might have forgotten to \
+register your customized trainer with @register_trainer decorator.')
+        trainer_config = recorded_module_args[id(self.trainer)]
+        base_model_ir.apply_trainer(trainer_config['modulename'], trainer_config['args'])

         # handle inline mutations
-        mutators = self._process_inline_mutation(base_model)
+        mutators = self._process_inline_mutation(base_model_ir)
         if mutators is not None and self.applied_mutators:
             raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, \
do not use mutators when you use LayerChoice/InputChoice')

@@ -140,7 +144,7 @@ class RetiariiExperiment(Experiment):
             self.applied_mutators = mutators

         _logger.info('Starting strategy...')
-        Thread(target=self.strategy.run, args=(base_model, self.applied_mutators)).start()
+        Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators)).start()
         _logger.info('Strategy started!')

     def start(self, port: int = 8080, debug: bool = False) -> None:
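The experiment now pulls trainer arguments from the global record store (get_records()) instead of a recorded_module_args field, and raises a KeyError pointing at @register_trainer when the trainer was never recorded. A hedged sketch of the registration that message asks for (hypothetical trainer class; the decorator name is real and, per this diff, now applied bare, without parentheses):

    from nni.retiarii import register_trainer

    @register_trainer
    class MyTrainer:
        def __init__(self, model, learning_rate=1e-3):
            # 'model' is dropped from the recorded args without a warning
            # (register_format='full'), per the un-serializable-argument
            # handling in nni/retiarii/utils.py below.
            self.model = model
            self.learning_rate = learning_rate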
nni/retiarii/nn/pytorch/nn.py

 import inspect
 import logging
 from typing import Any, List

 import torch
 import torch.nn as nn

-from ...utils import add_record, version_larger_equal
+from ...utils import add_record, blackbox_module, uid, version_larger_equal

 _logger = logging.getLogger(__name__)

@@ -40,16 +39,13 @@ if version_larger_equal(torch.__version__, '1.6.0'):
 if version_larger_equal(torch.__version__, '1.7.0'):
     __all__.extend(['Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss'])
     #'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
     #'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
     #'ChannelShuffle'


 class LayerChoice(nn.Module):
     def __init__(self, op_candidates, reduction=None, return_mask=False, key=None):
         super(LayerChoice, self).__init__()
-        self.candidate_ops = op_candidates
-        self.label = key
-        self.key = key  # deprecated, for backward compatibility
+        self.op_candidates = op_candidates
+        self.label = key if key is not None else f'layerchoice_{uid()}'
+        self.key = self.label  # deprecated, for backward compatibility
         for i, module in enumerate(op_candidates):  # deprecated, for backward compatibility
             self.add_module(str(i), module)
         if reduction or return_mask:

@@ -66,8 +62,8 @@ class InputChoice(nn.Module):
         self.n_candidates = n_candidates
         self.n_chosen = n_chosen
         self.reduction = reduction
-        self.label = key
-        self.key = key  # deprecated, for backward compatibility
+        self.label = key if key is not None else f'inputchoice_{uid()}'
+        self.key = self.label  # deprecated, for backward compatibility
         if choose_from or return_mask:
             _logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!')

@@ -101,6 +97,7 @@ class Placeholder(nn.Module):
 class ChosenInputs(nn.Module):
     """
     """
     def __init__(self, chosen: List[int], reduction: str):
         super().__init__()
         self.chosen = chosen

@@ -128,9 +125,7 @@ class ChosenInputs(nn.Module):
 # the following are pytorch modules

-class Module(nn.Module):
-    def __init__(self):
-        super(Module, self).__init__()
+Module = nn.Module


 class Sequential(nn.Sequential):

@@ -145,143 +140,116 @@ class ModuleList(nn.ModuleList):
         super(ModuleList, self).__init__(*args)


-def wrap_module(original_class):
-    orig_init = original_class.__init__
-    argname_list = list(inspect.signature(original_class).parameters.keys())
-    # Make copy of original __init__, so we can call it without recursion
-    def __init__(self, *args, **kws):
-        full_args = {}
-        full_args.update(kws)
-        for i, arg in enumerate(args):
-            full_args[argname_list[i]] = arg
-        add_record(id(self), full_args)
-        orig_init(self, *args, **kws)  # Call the original __init__
-    original_class.__init__ = __init__  # Set the class' __init__ to the new one
-    return original_class
-
-Identity = wrap_module(nn.Identity)
-Linear = wrap_module(nn.Linear)
-Conv1d = wrap_module(nn.Conv1d)
-Conv2d = wrap_module(nn.Conv2d)
-Conv3d = wrap_module(nn.Conv3d)
-ConvTranspose1d = wrap_module(nn.ConvTranspose1d)
-ConvTranspose2d = wrap_module(nn.ConvTranspose2d)
-ConvTranspose3d = wrap_module(nn.ConvTranspose3d)
-Threshold = wrap_module(nn.Threshold)
-ReLU = wrap_module(nn.ReLU)
-Hardtanh = wrap_module(nn.Hardtanh)
-ReLU6 = wrap_module(nn.ReLU6)
-Sigmoid = wrap_module(nn.Sigmoid)
-Tanh = wrap_module(nn.Tanh)
-Softmax = wrap_module(nn.Softmax)
-Softmax2d = wrap_module(nn.Softmax2d)
-LogSoftmax = wrap_module(nn.LogSoftmax)
-ELU = wrap_module(nn.ELU)
-SELU = wrap_module(nn.SELU)
-CELU = wrap_module(nn.CELU)
-GLU = wrap_module(nn.GLU)
-GELU = wrap_module(nn.GELU)
-Hardshrink = wrap_module(nn.Hardshrink)
-LeakyReLU = wrap_module(nn.LeakyReLU)
-LogSigmoid = wrap_module(nn.LogSigmoid)
-Softplus = wrap_module(nn.Softplus)
-Softshrink = wrap_module(nn.Softshrink)
-MultiheadAttention = wrap_module(nn.MultiheadAttention)
-PReLU = wrap_module(nn.PReLU)
-Softsign = wrap_module(nn.Softsign)
-Softmin = wrap_module(nn.Softmin)
-Tanhshrink = wrap_module(nn.Tanhshrink)
-RReLU = wrap_module(nn.RReLU)
-AvgPool1d = wrap_module(nn.AvgPool1d)
-AvgPool2d = wrap_module(nn.AvgPool2d)
-AvgPool3d = wrap_module(nn.AvgPool3d)
-MaxPool1d = wrap_module(nn.MaxPool1d)
-MaxPool2d = wrap_module(nn.MaxPool2d)
-MaxPool3d = wrap_module(nn.MaxPool3d)
-MaxUnpool1d = wrap_module(nn.MaxUnpool1d)
-MaxUnpool2d = wrap_module(nn.MaxUnpool2d)
-MaxUnpool3d = wrap_module(nn.MaxUnpool3d)
-FractionalMaxPool2d = wrap_module(nn.FractionalMaxPool2d)
-FractionalMaxPool3d = wrap_module(nn.FractionalMaxPool3d)
-LPPool1d = wrap_module(nn.LPPool1d)
-LPPool2d = wrap_module(nn.LPPool2d)
-LocalResponseNorm = wrap_module(nn.LocalResponseNorm)
-BatchNorm1d = wrap_module(nn.BatchNorm1d)
-BatchNorm2d = wrap_module(nn.BatchNorm2d)
-BatchNorm3d = wrap_module(nn.BatchNorm3d)
-InstanceNorm1d = wrap_module(nn.InstanceNorm1d)
-InstanceNorm2d = wrap_module(nn.InstanceNorm2d)
-InstanceNorm3d = wrap_module(nn.InstanceNorm3d)
-LayerNorm = wrap_module(nn.LayerNorm)
-GroupNorm = wrap_module(nn.GroupNorm)
-SyncBatchNorm = wrap_module(nn.SyncBatchNorm)
-Dropout = wrap_module(nn.Dropout)
-Dropout2d = wrap_module(nn.Dropout2d)
-Dropout3d = wrap_module(nn.Dropout3d)
-AlphaDropout = wrap_module(nn.AlphaDropout)
-FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout)
-ReflectionPad1d = wrap_module(nn.ReflectionPad1d)
-ReflectionPad2d = wrap_module(nn.ReflectionPad2d)
-ReplicationPad2d = wrap_module(nn.ReplicationPad2d)
-ReplicationPad1d = wrap_module(nn.ReplicationPad1d)
-ReplicationPad3d = wrap_module(nn.ReplicationPad3d)
-CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d)
-Embedding = wrap_module(nn.Embedding)
-EmbeddingBag = wrap_module(nn.EmbeddingBag)
-RNNBase = wrap_module(nn.RNNBase)
-RNN = wrap_module(nn.RNN)
-LSTM = wrap_module(nn.LSTM)
-GRU = wrap_module(nn.GRU)
-RNNCellBase = wrap_module(nn.RNNCellBase)
-RNNCell = wrap_module(nn.RNNCell)
-LSTMCell = wrap_module(nn.LSTMCell)
-GRUCell = wrap_module(nn.GRUCell)
-PixelShuffle = wrap_module(nn.PixelShuffle)
-Upsample = wrap_module(nn.Upsample)
-UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d)
-UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d)
-PairwiseDistance = wrap_module(nn.PairwiseDistance)
-AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d)
-AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d)
-AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d)
-AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d)
-AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
-AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d)
-TripletMarginLoss = wrap_module(nn.TripletMarginLoss)
-ZeroPad2d = wrap_module(nn.ZeroPad2d)
-ConstantPad1d = wrap_module(nn.ConstantPad1d)
-ConstantPad2d = wrap_module(nn.ConstantPad2d)
-ConstantPad3d = wrap_module(nn.ConstantPad3d)
-Bilinear = wrap_module(nn.Bilinear)
-CosineSimilarity = wrap_module(nn.CosineSimilarity)
-Unfold = wrap_module(nn.Unfold)
-Fold = wrap_module(nn.Fold)
-AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
-TransformerEncoder = wrap_module(nn.TransformerEncoder)
-TransformerDecoder = wrap_module(nn.TransformerDecoder)
-TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
-TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
-Transformer = wrap_module(nn.Transformer)
-Flatten = wrap_module(nn.Flatten)
-Hardsigmoid = wrap_module(nn.Hardsigmoid)
+Identity = blackbox_module(nn.Identity)
+Linear = blackbox_module(nn.Linear)
+Conv1d = blackbox_module(nn.Conv1d)
+Conv2d = blackbox_module(nn.Conv2d)
+Conv3d = blackbox_module(nn.Conv3d)
+ConvTranspose1d = blackbox_module(nn.ConvTranspose1d)
+ConvTranspose2d = blackbox_module(nn.ConvTranspose2d)
+ConvTranspose3d = blackbox_module(nn.ConvTranspose3d)
+Threshold = blackbox_module(nn.Threshold)
+ReLU = blackbox_module(nn.ReLU)
+Hardtanh = blackbox_module(nn.Hardtanh)
+ReLU6 = blackbox_module(nn.ReLU6)
+Sigmoid = blackbox_module(nn.Sigmoid)
+Tanh = blackbox_module(nn.Tanh)
+Softmax = blackbox_module(nn.Softmax)
+Softmax2d = blackbox_module(nn.Softmax2d)
+LogSoftmax = blackbox_module(nn.LogSoftmax)
+ELU = blackbox_module(nn.ELU)
+SELU = blackbox_module(nn.SELU)
+CELU = blackbox_module(nn.CELU)
+GLU = blackbox_module(nn.GLU)
+GELU = blackbox_module(nn.GELU)
+Hardshrink = blackbox_module(nn.Hardshrink)
+LeakyReLU = blackbox_module(nn.LeakyReLU)
+LogSigmoid = blackbox_module(nn.LogSigmoid)
+Softplus = blackbox_module(nn.Softplus)
+Softshrink = blackbox_module(nn.Softshrink)
+MultiheadAttention = blackbox_module(nn.MultiheadAttention)
+PReLU = blackbox_module(nn.PReLU)
+Softsign = blackbox_module(nn.Softsign)
+Softmin = blackbox_module(nn.Softmin)
+Tanhshrink = blackbox_module(nn.Tanhshrink)
+RReLU = blackbox_module(nn.RReLU)
+AvgPool1d = blackbox_module(nn.AvgPool1d)
+AvgPool2d = blackbox_module(nn.AvgPool2d)
+AvgPool3d = blackbox_module(nn.AvgPool3d)
+MaxPool1d = blackbox_module(nn.MaxPool1d)
+MaxPool2d = blackbox_module(nn.MaxPool2d)
+MaxPool3d = blackbox_module(nn.MaxPool3d)
+MaxUnpool1d = blackbox_module(nn.MaxUnpool1d)
+MaxUnpool2d = blackbox_module(nn.MaxUnpool2d)
+MaxUnpool3d = blackbox_module(nn.MaxUnpool3d)
+FractionalMaxPool2d = blackbox_module(nn.FractionalMaxPool2d)
+FractionalMaxPool3d = blackbox_module(nn.FractionalMaxPool3d)
+LPPool1d = blackbox_module(nn.LPPool1d)
+LPPool2d = blackbox_module(nn.LPPool2d)
+LocalResponseNorm = blackbox_module(nn.LocalResponseNorm)
+BatchNorm1d = blackbox_module(nn.BatchNorm1d)
+BatchNorm2d = blackbox_module(nn.BatchNorm2d)
+BatchNorm3d = blackbox_module(nn.BatchNorm3d)
+InstanceNorm1d = blackbox_module(nn.InstanceNorm1d)
+InstanceNorm2d = blackbox_module(nn.InstanceNorm2d)
+InstanceNorm3d = blackbox_module(nn.InstanceNorm3d)
+LayerNorm = blackbox_module(nn.LayerNorm)
+GroupNorm = blackbox_module(nn.GroupNorm)
+SyncBatchNorm = blackbox_module(nn.SyncBatchNorm)
+Dropout = blackbox_module(nn.Dropout)
+Dropout2d = blackbox_module(nn.Dropout2d)
+Dropout3d = blackbox_module(nn.Dropout3d)
+AlphaDropout = blackbox_module(nn.AlphaDropout)
+FeatureAlphaDropout = blackbox_module(nn.FeatureAlphaDropout)
+ReflectionPad1d = blackbox_module(nn.ReflectionPad1d)
+ReflectionPad2d = blackbox_module(nn.ReflectionPad2d)
+ReplicationPad2d = blackbox_module(nn.ReplicationPad2d)
+ReplicationPad1d = blackbox_module(nn.ReplicationPad1d)
+ReplicationPad3d = blackbox_module(nn.ReplicationPad3d)
+CrossMapLRN2d = blackbox_module(nn.CrossMapLRN2d)
+Embedding = blackbox_module(nn.Embedding)
+EmbeddingBag = blackbox_module(nn.EmbeddingBag)
+RNNBase = blackbox_module(nn.RNNBase)
+RNN = blackbox_module(nn.RNN)
+LSTM = blackbox_module(nn.LSTM)
+GRU = blackbox_module(nn.GRU)
+RNNCellBase = blackbox_module(nn.RNNCellBase)
+RNNCell = blackbox_module(nn.RNNCell)
+LSTMCell = blackbox_module(nn.LSTMCell)
+GRUCell = blackbox_module(nn.GRUCell)
+PixelShuffle = blackbox_module(nn.PixelShuffle)
+Upsample = blackbox_module(nn.Upsample)
+UpsamplingNearest2d = blackbox_module(nn.UpsamplingNearest2d)
+UpsamplingBilinear2d = blackbox_module(nn.UpsamplingBilinear2d)
+PairwiseDistance = blackbox_module(nn.PairwiseDistance)
+AdaptiveMaxPool1d = blackbox_module(nn.AdaptiveMaxPool1d)
+AdaptiveMaxPool2d = blackbox_module(nn.AdaptiveMaxPool2d)
+AdaptiveMaxPool3d = blackbox_module(nn.AdaptiveMaxPool3d)
+AdaptiveAvgPool1d = blackbox_module(nn.AdaptiveAvgPool1d)
+AdaptiveAvgPool2d = blackbox_module(nn.AdaptiveAvgPool2d)
+AdaptiveAvgPool3d = blackbox_module(nn.AdaptiveAvgPool3d)
+TripletMarginLoss = blackbox_module(nn.TripletMarginLoss)
+ZeroPad2d = blackbox_module(nn.ZeroPad2d)
+ConstantPad1d = blackbox_module(nn.ConstantPad1d)
+ConstantPad2d = blackbox_module(nn.ConstantPad2d)
+ConstantPad3d = blackbox_module(nn.ConstantPad3d)
+Bilinear = blackbox_module(nn.Bilinear)
+CosineSimilarity = blackbox_module(nn.CosineSimilarity)
+Unfold = blackbox_module(nn.Unfold)
+Fold = blackbox_module(nn.Fold)
+AdaptiveLogSoftmaxWithLoss = blackbox_module(nn.AdaptiveLogSoftmaxWithLoss)
+TransformerEncoder = blackbox_module(nn.TransformerEncoder)
+TransformerDecoder = blackbox_module(nn.TransformerDecoder)
+TransformerEncoderLayer = blackbox_module(nn.TransformerEncoderLayer)
+TransformerDecoderLayer = blackbox_module(nn.TransformerDecoderLayer)
+Transformer = blackbox_module(nn.Transformer)
+Flatten = blackbox_module(nn.Flatten)
+Hardsigmoid = blackbox_module(nn.Hardsigmoid)

 if version_larger_equal(torch.__version__, '1.6.0'):
-    Hardswish = wrap_module(nn.Hardswish)
+    Hardswish = blackbox_module(nn.Hardswish)

 if version_larger_equal(torch.__version__, '1.7.0'):
-    SiLU = wrap_module(nn.SiLU)
-    Unflatten = wrap_module(nn.Unflatten)
-    TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
-    #LazyLinear = wrap_module(nn.LazyLinear)
-    #LazyConv1d = wrap_module(nn.LazyConv1d)
-    #LazyConv2d = wrap_module(nn.LazyConv2d)
-    #LazyConv3d = wrap_module(nn.LazyConv3d)
-    #LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d)
-    #LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d)
-    #LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d)
-    #ChannelShuffle = wrap_module(nn.ChannelShuffle)
\ No newline at end of file
+    SiLU = blackbox_module(nn.SiLU)
+    Unflatten = blackbox_module(nn.Unflatten)
+    TripletMarginWithDistanceLoss = blackbox_module(nn.TripletMarginWithDistanceLoss)
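All of these aliases serve one purpose: a model authored against nni.retiarii.nn.pytorch records every submodule's constructor arguments as a side effect of construction, which the converter later looks up via id(module). A hedged usage sketch (hypothetical model; the import path is the one this diff re-exports):

    import nni.retiarii.nn.pytorch as nn

    class Block(nn.Module):
        def __init__(self, channels):
            super().__init__()
            # nn.Conv2d here is blackbox_module(torch.nn.Conv2d): constructing it
            # records its arguments for the graph converter.
            self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
            self.act = nn.ReLU()

        def forward(self, x):
            return self.act(self.conv(x))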
nni/retiarii/trainer/pytorch/base.py

@@ -43,7 +43,7 @@ def get_default_transform(dataset: str) -> Any:
     return None


-@register_trainer()
+@register_trainer
 class PyTorchImageClassificationTrainer(BaseTrainer):
     """
     Image classification trainer for PyTorch.

@@ -80,7 +80,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
             Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
             only the key ``max_epochs`` is useful.
         """
-        super(PyTorchImageClassificationTrainer, self).__init__()
+        super().__init__()
         self._use_cuda = torch.cuda.is_available()
         self.model = model
         if self._use_cuda:
nni/retiarii/utils.py

 import inspect
+import warnings
 from collections import defaultdict
 from typing import Any

@@ -10,12 +11,14 @@ def import_(target: str, allow_none: bool = False) -> Any:
     module = __import__(path, globals(), locals(), [identifier])
     return getattr(module, identifier)


 def version_larger_equal(a: str, b: str) -> bool:
+    # TODO: refactor later
+    a = a.split('+')[0]
+    b = b.split('+')[0]
     return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.')))


 _records = {}

@@ -29,73 +32,87 @@ def add_record(key, value):
     """
     global _records
     if _records is not None:
-        # assert key not in _records, '{} already in _records'.format(key)
+        assert key not in _records, '{} already in _records'.format(key)
         _records[key] = value


-def _register_module(original_class):
-    orig_init = original_class.__init__
-    argname_list = list(inspect.signature(original_class).parameters.keys())
-    # Make copy of original __init__, so we can call it without recursion
-    def __init__(self, *args, **kws):
-        full_args = {}
-        full_args.update(kws)
-        for i, arg in enumerate(args):
-            full_args[argname_list[i]] = arg
-        add_record(id(self), full_args)
-        orig_init(self, *args, **kws)  # Call the original __init__
-    original_class.__init__ = __init__  # Set the class' __init__ to the new one
-    return original_class
+def del_record(key):
+    global _records
+    if _records is not None:
+        _records.pop(key, None)


-def register_module():
-    """
-    Register a module.
-    """
-    # use it as a decorator: @register_module()
-    def _register(cls):
-        m = _register_module(original_class=cls)
-        return m
-    return _register
+def _blackbox_cls(cls, module_name, register_format=None):
+    class wrapper(cls):
+        def __init__(self, *args, **kwargs):
+            argname_list = list(inspect.signature(cls).parameters.keys())
+            full_args = {}
+            full_args.update(kwargs)
+            assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
+            for argname, value in zip(argname_list, args):
+                full_args[argname] = value
+
+            # eject un-serializable arguments
+            for k in list(full_args.keys()):
+                # The list is not complete and does not support nested cases.
+                if not isinstance(full_args[k], (int, float, str, dict, list)):
+                    if not (register_format == 'full' and k == 'model'):
+                        # no warning if it is base model in trainer
+                        warnings.warn(f'{cls} has un-serializable arguments {k} whose value is {full_args[k]}. \
+This is not supported. You can ignore this warning if you are passing the model to trainer.')
+                    full_args.pop(k)
+
+            if register_format == 'args':
+                add_record(id(self), full_args)
+            elif register_format == 'full':
+                full_class_name = cls.__module__ + '.' + cls.__name__
+                add_record(id(self), {'modulename': full_class_name, 'args': full_args})
+
+            super().__init__(*args, **kwargs)
+
+        def __del__(self):
+            del_record(id(self))
+
+    # using module_name instead of cls.__module__ because it's more natural to see where the module gets wrapped
+    # instead of simply putting torch.nn or etc.
+    wrapper.__module__ = module_name
+    wrapper.__name__ = cls.__name__
+    wrapper.__qualname__ = cls.__qualname__
+    wrapper.__init__.__doc__ = cls.__init__.__doc__
+    return wrapper


-def _register_trainer(original_class):
-    orig_init = original_class.__init__
-    argname_list = list(inspect.signature(original_class).parameters.keys())
-    # Make copy of original __init__, so we can call it without recursion
-    full_class_name = original_class.__module__ + '.' + original_class.__name__
-    def __init__(self, *args, **kws):
-        full_args = {}
-        full_args.update(kws)
-        for i, arg in enumerate(args):
-            # TODO: support both pytorch and tensorflow
-            from .nn.pytorch import Module
-            if isinstance(args[i], Module):
-                # ignore the base model object
-                continue
-            full_args[argname_list[i]] = arg
-        add_record(id(self), {'modulename': full_class_name, 'args': full_args})
-        orig_init(self, *args, **kws)  # Call the original __init__
-    original_class.__init__ = __init__  # Set the class' __init__ to the new one
-    return original_class
+def blackbox(cls, *args, **kwargs):
+    """
+    To create an blackbox instance inline without decorator. For example,
+
+    .. code-block:: python
+        self.op = blackbox(MyCustomOp, hidden_units=128)
+    """
+    # get caller module name
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'args')(*args, **kwargs)


-def register_trainer():
-    def _register(cls):
-        m = _register_trainer(original_class=cls)
-        return m
-    return _register
+def blackbox_module(cls):
+    """
+    Register a module. Use it as a decorator.
+    """
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'args')
+
+
+def register_trainer(cls):
+    """
+    Register a trainer. Use it as a decorator.
+    """
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'full')


 _last_uid = defaultdict(int)
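The key technique in the new _blackbox_cls is worth isolating: instead of monkey-patching __init__ on the original class (the old wrap_module/_register_module approach), it derives a subclass whose __init__ records the bound arguments and then defers to the original, leaving the wrapped class untouched. A self-contained toy of the same idea (hypothetical names, simplified — no serializability filtering or __del__ cleanup):

    import inspect

    _records = {}

    def blackbox_cls(cls):
        class wrapper(cls):
            def __init__(self, *args, **kwargs):
                names = list(inspect.signature(cls).parameters.keys())
                full_args = dict(zip(names, args))  # bind positionals to parameter names
                full_args.update(kwargs)
                _records[id(self)] = full_args      # keyed by instance id, as in the diff
                super().__init__(*args, **kwargs)
        wrapper.__name__ = cls.__name__             # keep the wrapped class presentable
        return wrapper

    @blackbox_cls
    class Op:
        def __init__(self, hidden_units, dropout=0.0):
            self.hidden_units = hidden_units
            self.dropout = dropout

    op = Op(128, dropout=0.1)
    print(_records[id(op)])  # {'hidden_units': 128, 'dropout': 0.1}

Because the wrapper is a subclass, isinstance checks against the original class still pass, which the old in-place patching achieved only by mutating the class globally.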
test/.gitignore

@@ -5,6 +5,7 @@ tuner_result.txt
 assessor_result.txt
 _generated_model.py
 _generated_model_*.py
 data
+generated
test/retiarii_test/darts/darts_model.py

@@ -7,9 +7,9 @@ import torch.nn as torch_nn
 import ops
 import nni.retiarii.nn.pytorch as nn
-from nni.retiarii import register_module
+from nni.retiarii import blackbox_module


+@blackbox_module
 class AuxiliaryHead(nn.Module):
     """ Auxiliary head in 2/3 place of network to let the gradient flow well """

@@ -35,7 +35,6 @@ class AuxiliaryHead(nn.Module):
         logits = self.linear(out)
         return logits


-@register_module()
 class Node(nn.Module):
     def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
         super().__init__()

@@ -66,7 +65,6 @@ class Node(nn.Module):
         #out = [self.drop_path(o) if o is not None else None for o in out]
         return self.input_switch(out)


-@register_module()
 class Cell(nn.Module):
     def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):

@@ -100,7 +98,6 @@ class Cell(nn.Module):
         output = torch.cat(new_tensors, dim=1)
         return output


-@register_module()
 class CNN(nn.Module):
     def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
test/retiarii_test/darts/ops.py

 import torch
 import nni.retiarii.nn.pytorch as nn
-from nni.retiarii import register_module
+from nni.retiarii import blackbox_module


-@register_module()
+@blackbox_module
 class DropPath(nn.Module):
     def __init__(self, p=0.):
         """

@@ -12,7 +12,7 @@ class DropPath(nn.Module):
         p : float
             Probability of an path to be zeroed.
         """
-        super(DropPath, self).__init__()
+        super().__init__()
         self.p = p

     def forward(self, x):

@@ -24,13 +24,13 @@ class DropPath(nn.Module):
         return x


-@register_module()
+@blackbox_module
 class PoolBN(nn.Module):
     """
     AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
     """
     def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
-        super(PoolBN, self).__init__()
+        super().__init__()
         if pool_type.lower() == 'max':
             self.pool = nn.MaxPool2d(kernel_size, stride, padding)
         elif pool_type.lower() == 'avg':

@@ -45,13 +45,13 @@ class PoolBN(nn.Module):
         out = self.bn(out)
         return out


-@register_module()
+@blackbox_module
 class StdConv(nn.Module):
     """
     Standard conv: ReLU - Conv - BN
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
-        super(StdConv, self).__init__()
+        super().__init__()
         self.net = nn.Sequential(
             nn.ReLU(),
             nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),

@@ -61,13 +61,13 @@ class StdConv(nn.Module):
     def forward(self, x):
         return self.net(x)


-@register_module()
+@blackbox_module
 class FacConv(nn.Module):
     """
     Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
     """
     def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
-        super(FacConv, self).__init__()
+        super().__init__()
         self.net = nn.Sequential(
             nn.ReLU(),
             nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),

@@ -78,7 +78,7 @@ class FacConv(nn.Module):
     def forward(self, x):
         return self.net(x)


-@register_module()
+@blackbox_module
 class DilConv(nn.Module):
     """
     (Dilated) depthwise separable conv.

@@ -86,7 +86,7 @@ class DilConv(nn.Module):
     If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
-        super(DilConv, self).__init__()
+        super().__init__()
         self.net = nn.Sequential(
             nn.ReLU(),
             nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,

@@ -98,14 +98,14 @@ class DilConv(nn.Module):
     def forward(self, x):
         return self.net(x)


-@register_module()
+@blackbox_module
 class SepConv(nn.Module):
     """
     Depthwise separable conv.
     DilConv(dilation=1) * 2.
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
-        super(SepConv, self).__init__()
+        super().__init__()
         self.net = nn.Sequential(
             DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
             DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)

@@ -114,13 +114,13 @@ class SepConv(nn.Module):
     def forward(self, x):
         return self.net(x)


-@register_module()
+@blackbox_module
 class FactorizedReduce(nn.Module):
     """
     Reduce feature map size by factorized pointwise (stride=2).
     """
     def __init__(self, C_in, C_out, affine=True):
-        super(FactorizedReduce, self).__init__()
+        super().__init__()
         self.relu = nn.ReLU()
         self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
         self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
test/retiarii_test/darts/test.py

@@ -31,4 +31,4 @@ if __name__ == '__main__':
     exp_config.training_service.use_active_gpu = True
     exp_config.training_service.gpu_indices = [1, 2]

-    exp.run(exp_config, 8081, debug=True)
+    exp.run(exp_config, 8081)
test/retiarii_test/darts/test_oneshot.py

@@ -56,8 +56,8 @@ def get_dataset(cls, cutout_length=0):
     valid_transform = transforms.Compose(normalize)

     if cls == "cifar10":
-        dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
-        dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
+        dataset_train = CIFAR10(root="./data/cifar10", train=True, download=True, transform=train_transform)
+        dataset_valid = CIFAR10(root="./data/cifar10", train=False, download=True, transform=valid_transform)
     else:
         raise NotImplementedError
     return dataset_train, dataset_valid
test/retiarii_test/mnasnet/base_mnasnet.py
View file @
ae50ed14
from
nni.retiarii
import
blackbox_module
import
nni.retiarii.nn.pytorch
as
nn
import
warnings
import
torch
...
...
@@ -8,8 +10,6 @@ import torch.nn.functional as F
import
sys
from
pathlib
import
Path
sys
.
path
.
append
(
str
(
Path
(
__file__
).
resolve
().
parents
[
2
]))
import
nni.retiarii.nn.pytorch
as
nn
from
nni.retiarii
import
register_module
# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
# 1.0 - tensorflow.
...
...
@@ -27,6 +27,7 @@ class _ResidualBlock(nn.Module):
def
forward
(
self
,
x
):
return
self
.
net
(
x
)
+
x
class
_InvertedResidual
(
nn
.
Module
):
def
__init__
(
self
,
in_ch
,
out_ch
,
kernel_size
,
stride
,
expansion_factor
,
skip
,
bn_momentum
=
0.1
):
...
...
@@ -110,7 +111,7 @@ def _get_depths(depths, alpha):
rather than down. """
return
[
_round_to_multiple_of
(
depth
*
alpha
,
8
)
for
depth
in
depths
]
@
register_module
()
class
MNASNet
(
nn
.
Module
):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
implements the B1 variant of the model.
...
...
@@ -127,7 +128,7 @@ class MNASNet(nn.Module):
def
__init__
(
self
,
alpha
,
depths
,
convops
,
kernel_sizes
,
num_layers
,
skips
,
num_classes
=
1000
,
dropout
=
0.2
):
super
(
MNASNet
,
self
).
__init__
()
super
().
__init__
()
assert
alpha
>
0.0
assert
len
(
depths
)
==
len
(
convops
)
==
len
(
kernel_sizes
)
==
len
(
num_layers
)
==
len
(
skips
)
==
7
self
.
alpha
=
alpha
...
...
@@ -143,7 +144,7 @@ class MNASNet(nn.Module):
nn
.
ReLU
(
inplace
=
True
),
]
count
=
0
#for conv, prev_depth, depth, ks, skip, stride, repeat, exp_ratio in \
#
for conv, prev_depth, depth, ks, skip, stride, repeat, exp_ratio in \
# zip(convops, depths[:-1], depths[1:], kernel_sizes, skips, strides, num_layers, exp_ratios):
for
filter_size
,
exp_ratio
,
stride
in
zip
(
base_filter_sizes
,
exp_ratios
,
strides
):
# TODO: restrict that "choose" can only be used within mutator
...
...
@@ -153,7 +154,7 @@ class MNASNet(nn.Module):
'op_type_options'
:
[
'__mutated__.base_mnasnet.RegularConv'
,
'__mutated__.base_mnasnet.DepthwiseConv'
,
'__mutated__.base_mnasnet.MobileConv'
],
#'se_ratio_options': [0, 0.25],
#
'se_ratio_options': [0, 0.25],
'skip_options'
:
[
'identity'
,
'no'
],
'n_filter_options'
:
[
int
(
filter_size
*
x
)
for
x
in
[
0.75
,
1.0
,
1.25
]],
'exp_ratio'
:
exp_ratio
,
...
...
@@ -185,7 +186,7 @@ class MNASNet(nn.Module):
#self.for_test = 10
def
forward
(
self
,
x
):
#if self.for_test == 10:
#
if self.for_test == 10:
x
=
self
.
layers
(
x
)
# Equivalent to global avgpool and removing H and W dimensions.
x
=
x
.
mean
([
2
,
3
])
...
...
@@ -211,9 +212,11 @@ class MNASNet(nn.Module):
def
test_model
(
model
):
model
(
torch
.
randn
(
2
,
3
,
224
,
224
))
#====================definition of candidate op classes
# ====================definition of candidate op classes
BN_MOMENTUM
=
1
-
0.9997
class
RegularConv
(
nn
.
Module
):
def
__init__
(
self
,
kernel_size
,
in_ch
,
out_ch
,
skip
,
exp_ratio
,
stride
):
super
().
__init__
()
...
...
@@ -234,6 +237,7 @@ class RegularConv(nn.Module):
out
=
out
+
x
return
out
class
DepthwiseConv
(
nn
.
Module
):
def
__init__
(
self
,
kernel_size
,
in_ch
,
out_ch
,
skip
,
exp_ratio
,
stride
):
super
().
__init__
()
...
...
@@ -257,6 +261,7 @@ class DepthwiseConv(nn.Module):
out
=
out
+
x
return
out
class
MobileConv
(
nn
.
Module
):
def
__init__
(
self
,
kernel_size
,
in_ch
,
out_ch
,
skip
,
exp_ratio
,
stride
):
super
().
__init__
()
...
...
@@ -274,7 +279,7 @@ class MobileConv(nn.Module):
nn
.
BatchNorm2d
(
mid_ch
,
momentum
=
BN_MOMENTUM
),
nn
.
ReLU
(
inplace
=
True
),
# Depthwise
nn
.
Conv2d
(
mid_ch
,
mid_ch
,
kernel_size
,
padding
=
(
kernel_size
-
1
)
//
2
,
nn
.
Conv2d
(
mid_ch
,
mid_ch
,
kernel_size
,
padding
=
(
kernel_size
-
1
)
//
2
,
stride
=
stride
,
groups
=
mid_ch
,
bias
=
False
),
nn
.
BatchNorm2d
(
mid_ch
,
momentum
=
BN_MOMENTUM
),
nn
.
ReLU
(
inplace
=
True
),
...
...
@@ -288,5 +293,6 @@ class MobileConv(nn.Module):
out
=
out
+
x
return
out
# mnasnet0_5
ir_module
=
_InvertedResidual
(
16
,
16
,
3
,
1
,
1
,
True
)
test/retiarii_test/mnasnet/test.py

@@ -41,4 +41,4 @@ if __name__ == '__main__':
     exp_config.max_trial_number = 10
     exp_config.training_service.use_active_gpu = False

-    exp.run(exp_config, 8081, debug=True)
+    exp.run(exp_config, 8081)
test/retiarii_test/mnist/test.py (new file, 0 → 100644)

import random

import nni.retiarii.nn.pytorch as nn
import torch.nn.functional as F
from nni.retiarii.experiment import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.strategies import RandomStrategy
from nni.retiarii.trainer import PyTorchImageClassificationTrainer


class Net(nn.Module):
    def __init__(self, hidden_size):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.LayerChoice([
            nn.Linear(4 * 4 * 50, hidden_size),
            nn.Linear(4 * 4 * 50, hidden_size, bias=False)
        ])
        self.fc2 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


if __name__ == '__main__':
    base_model = Net(128)
    trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="MNIST",
                                                dataset_kwargs={"root": "data/mnist", "download": True},
                                                dataloader_kwargs={"batch_size": 32},
                                                optimizer_kwargs={"lr": 1e-3},
                                                trainer_kwargs={"max_epochs": 1})
    simple_startegy = RandomStrategy()

    exp = RetiariiExperiment(base_model, trainer, [], simple_startegy)

    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'mnist_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 10
    exp_config.training_service.use_active_gpu = False

    exp.run(exp_config, 8081 + random.randint(0, 100))