OpenDAS / nni · Commits

Commit ae50ed14 (unverified)
Refactor wrap module as "blackbox_module" (#3238)
Authored Dec 31, 2020 by Yuge Zhang, committed via GitHub on Dec 31, 2020
Parent: 15da19d3

Showing 15 changed files with 766 additions and 743 deletions (+766 −743)
nni/retiarii/__init__.py                     +1    -1
nni/retiarii/codegen/pytorch.py              +2    -2
nni/retiarii/converter/graph_gen.py          +459  -479
nni/retiarii/experiment.py                   +10   -6
nni/retiarii/nn/pytorch/nn.py                +117  -149
nni/retiarii/trainer/pytorch/base.py         +2    -2
nni/retiarii/utils.py                        +66   -49
test/.gitignore                              +1    -0
test/retiarii_test/darts/darts_model.py      +2    -5
test/retiarii_test/darts/ops.py              +15   -15
test/retiarii_test/darts/test.py             +5    -5
test/retiarii_test/darts/test_oneshot.py     +2    -2
test/retiarii_test/mnasnet/base_mnasnet.py   +28   -22
test/retiarii_test/mnasnet/test.py           +6    -6
test/retiarii_test/mnist/test.py             +50   -0
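The central rename in this commit is wrap_module → blackbox_module: wrapping a module records its constructor arguments so the graph converter can keep it as a single opaque node instead of tracing into its forward (see the "marked as blackbox" branch added in graph_gen.py below). A rough usage sketch — the user class MyOp and its arguments are invented for illustration, and it assumes blackbox_module can be applied to user-defined modules the same way nn.py applies it to the torch.nn classes:

    import torch
    import nni.retiarii.nn.pytorch as nn       # Retiarii's wrapped layer aliases
    from nni.retiarii import blackbox_module   # new export added by this commit

    @blackbox_module           # records __init__ args; the converter will not parse inside
    class MyOp(nn.Module):     # hypothetical user module
        def __init__(self, in_ch, out_ch):
            super().__init__()
            self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

        def forward(self, x):
            return torch.relu(self.conv(x))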
nni/retiarii/__init__.py

@@ -2,4 +2,4 @@ from .operation import Operation
 from .graph import *
 from .execution import *
 from .mutator import *
-from .utils import register_module
+from .utils import blackbox, blackbox_module, register_trainer
\ No newline at end of file
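For reference, the new public surface exposed by the package root after this change (register_module is gone) — a minimal sketch restating only what the line above exports:

    from nni.retiarii import blackbox, blackbox_module, register_trainer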
nni/retiarii/codegen/pytorch.py

@@ -19,10 +19,10 @@ def model_to_pytorch_script(model: Model, placement=None) -> str:

 def _sorted_incoming_edges(node: Node) -> List[Edge]:
     edges = [edge for edge in node.graph.edges if edge.tail is node]
-    _logger.info('sorted_incoming_edges: %s', str(edges))
+    _logger.debug('sorted_incoming_edges: %s', str(edges))
     if not edges:
         return []
-    _logger.info('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
+    _logger.debug('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
     if all(edge.tail_slot is None for edge in edges):
         return edges
     if all(isinstance(edge.tail_slot, int) for edge in edges):
     ...
nni/retiarii/converter/graph_gen.py

@@ -6,518 +6,501 @@ import torch

 from ..graph import Graph, Model, Node
 from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
 from ..operation import Cell
+from ..utils import get_records
 from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
 from .utils import _convert_name, build_full_name

 _logger = logging.getLogger(__name__)

-global_seq = 0
-global_graph_id = 0
-modules_arg = None
+
+class GraphConverter:
+    def __init__(self):
+        self.global_seq = 0
+        self.global_graph_id = 0
+        self.modules_arg = get_records()

The former module-level helpers become methods of GraphConverter: each body gains one level
of indentation, the mutable globals become instance attributes, and helper calls go through
self. Only the substantive line changes are repeated below; unchanged bodies are elided
with "...".

-def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
+    def _add_edge(self, ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
     """
     Parameters
     ----------
     ir_graph : Graph
     node : torch._C.Node
     graph_inputs : List[torch._C.Value]
         a list of a script graph's inputs
     node_index : Dict
     new_node : Node
         newly created ir node corresponding to `node`
     output_remap : Dict
     ignore_first : bool
         if it is true, skip the first input
     """
     ...

-def create_prim_constant_node(ir_graph, node, module_name):
-    global global_seq
+    def create_prim_constant_node(self, ir_graph, node, module_name):
     attrs = {}
     if node.outputsAt(0).toIValue() is not None:
         attrs = {'value': node.outputsAt(0).toIValue()}
-    global_seq += 1
-    new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, global_seq),
-                                 node.kind(), attrs)
+        self.global_seq += 1
+        new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, self.global_seq),
+                                     node.kind(), attrs)
     return new_node

-def handle_prim_attr_node(node):
+    def handle_prim_attr_node(self, node):
     assert node.hasAttribute('name')
     attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
     return node.kind(), attrs

-def _remove_mangle(module_type_str):
+    def _remove_mangle(self, module_type_str):
     return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)

-def remove_unconnected_nodes(ir_graph, targeted_type=None):
+    def remove_unconnected_nodes(self, ir_graph, targeted_type=None):
     """
     Parameters
     ----------
     ir_graph : Graph
         our ir graph representation
     targeted_type : str
         nodes with ```targeted_type``` will be removed from graph if their fanout is 0.
         ```None``` means removing all the nodes whose fanout is 0.
     """
     ...

-def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
+    def handle_graph_nodes(self, script_module, sm_graph, module, module_name, ir_model, ir_graph):
     """
     Convert torch script node to our node ir, and build our graph ir
     Parameters
     ----------
     script_module : torch.jit.RecursiveScriptModule
         the torch script of ```module```
     sm_graph : torch._C.Graph
         the graph in torch script
     module : nn.Module
         the targeted pytorch module
     module_name : str
         ```module```'s name
     ir_model : Model
         the whole graph ir
     ir_graph : Graph
         the graph ir of ```module```
     Returns
     -------
     dict
         the mapping from graph node to our graph ir node
     """
     ...
     (the inner handle_if_condition / handle_if_node / handle_single_node helpers are unchanged,
     except that submodules are converted via self.convert_module, edges added via self._add_edge,
     constants and attributes handled via self.create_prim_constant_node / self.handle_prim_attr_node,
     and the sequence counter is self.global_seq)

-def merge_aten_slices(ir_graph):
+    def merge_aten_slices(self, ir_graph):
     """
     if there is aten::slice node, merge the consecutive ones together.
     ```x[:, :, 1:, 1:]``` in python code will be converted into 4 node in torch script,
     each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
     """
     ...

-def refine_graph(ir_graph):
+    def refine_graph(self, ir_graph):
     """
     Do the following process to simplify graph:
     1. remove unconnected constant node
     2. remove unconnected getattr node
     """
     # some constant is not used, for example, function name as prim::Constant
-    remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
-    remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
-    merge_aten_slices(ir_graph)
+        self.remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
+        self.remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
+        self.merge_aten_slices(ir_graph)

-def _handle_layerchoice(module):
-    global modules_arg
+    def _handle_layerchoice(self, module):
     m_attrs = {}
-    candidates = module.candidate_ops
+        candidates = module.op_candidates
     choices = []
     for cand in candidates:
-        assert id(cand) in modules_arg, 'id not exist: {}'.format(id(cand))
-        assert isinstance(modules_arg[id(cand)], dict)
+            assert id(cand) in self.modules_arg, 'id not exist: {}'.format(id(cand))
+            assert isinstance(self.modules_arg[id(cand)], dict)
         cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
-        choices.append({'type': cand_type, 'parameters': modules_arg[id(cand)]})
+            choices.append({'type': cand_type, 'parameters': self.modules_arg[id(cand)]})
     m_attrs[f'choices'] = choices
     m_attrs['label'] = module.label
     return m_attrs

-def _handle_inputchoice(module):
+    def _handle_inputchoice(self, module):
     m_attrs = {}
     m_attrs['n_candidates'] = module.n_candidates
     m_attrs['n_chosen'] = module.n_chosen
     m_attrs['reduction'] = module.reduction
     m_attrs['label'] = module.label
     return m_attrs

-def convert_module(script_module, module, module_name, ir_model):
+    def convert_module(self, script_module, module, module_name, ir_model):
     """
     Convert a module to its graph ir (i.e., Graph) along with its input arguments
     Parameters
     ----------
     script_module : torch.jit.RecursiveScriptModule
         the script module of ```module``` obtained with torch.jit.script
     module : nn.Module
         the targeted module instance
     module_name : str
         the constructed name space of ```module```
     ir_model : Model
         the whole graph ir
     Returns
     -------
     Graph
         the built graph ir from module, ```None``` means do not further parse the module
     dict
         the input arguments of this module
     """
-    global global_graph_id
-    global modules_arg
     # NOTE: have not supported nested LayerChoice, i.e., a candidate module
     # also has LayerChoice or InputChoice or ValueChoice
     original_type_name = script_module.original_name
-    if original_type_name == OpTypeName.LayerChoice:
-        m_attrs = _handle_layerchoice(module)
-        return None, m_attrs
-    if original_type_name == OpTypeName.InputChoice:
-        m_attrs = _handle_inputchoice(module)
-        return None, m_attrs
-    if original_type_name == OpTypeName.Placeholder:
-        m_attrs = modules_arg[id(module)]
-        return None, m_attrs
-    if original_type_name in torch.nn.__dict__ and original_type_name not in MODULE_EXCEPT_LIST:
-        # this is a basic module from pytorch, no need to parse its graph
-        assert id(module) in modules_arg, f'{original_type_name} arguments are not recorded'
-        m_attrs = modules_arg[id(module)]
-        return None, m_attrs
+        m_attrs = None
+        if original_type_name in MODULE_EXCEPT_LIST:
+            pass  # do nothing
+        elif original_type_name == OpTypeName.LayerChoice:
+            m_attrs = self._handle_layerchoice(module)
+        elif original_type_name == OpTypeName.InputChoice:
+            m_attrs = self._handle_inputchoice(module)
+        elif original_type_name == OpTypeName.Placeholder:
+            m_attrs = self.modules_arg[id(module)]
+        elif original_type_name in torch.nn.__dict__:
+            # this is a basic module from pytorch, no need to parse its graph
+            assert id(module) in self.modules_arg, f'{original_type_name} arguments are not recorded'
+            m_attrs = self.modules_arg[id(module)]
+        elif id(module) in self.modules_arg:
+            # this module is marked as blackbox, won't continue to parse
+            m_attrs = self.modules_arg[id(module)]
+        if m_attrs is not None:
+            return None, m_attrs

     # handle TorchScript graph
     sm_graph = script_module.graph
-    global_graph_id += 1
-    ir_graph = Graph(model=ir_model, graph_id=global_graph_id, name=module_name, _internal=True)
+        self.global_graph_id += 1
+        ir_graph = Graph(model=ir_model, graph_id=self.global_graph_id, name=module_name, _internal=True)

     # handle graph nodes
-    node_index = handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph)
+        node_index = self.handle_graph_nodes(script_module, sm_graph, module,
+                                             module_name, ir_model, ir_graph)

     # handle graph outputs
     for _output in sm_graph.outputs():
         ir_graph._add_output(_convert_name(_output.debugName()))
         predecessor_node_outputs = [o for o in _output.node().outputs()]
         if len(predecessor_node_outputs) == 1:
             src_node_idx = None
         else:
             src_node_idx = predecessor_node_outputs.index(_output)
         ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
                           tail=(ir_graph.output_node, None))

-    refine_graph(ir_graph)
+        self.refine_graph(ir_graph)
     ir_graph._register()
-
-    if id(module) not in modules_arg:
-        raise RuntimeError(f'{original_type_name} arguments are not recorded, \
-you might have forgotten to decorate this class with @register_module()')
-    # TODO: if we parse this module, it means we will create a graph (module class)
-    # for this module. Then it is not necessary to record this module's arguments
-    # return ir_graph, modules_arg[id(module)].
-    # That is, we can refactor this part, to allow users to annotate which module
-    # should not be parsed further.
     return ir_graph, {}

-def convert_to_graph(script_module, module, recorded_modules_arg):
+def convert_to_graph(script_module, module):
     """
     Convert module to our graph ir, i.e., build a ```Model``` type
 ...
@@ -527,18 +510,15 @@ def convert_to_graph(script_module, module, recorded_modules_arg):
         the script module obtained with torch.jit.script
     module : nn.Module
         the targeted module instance
-    recorded_modules_arg : dict
-        the recorded args of each module in the module
     Returns
     -------
     Model
         the constructed IR model
     """
-    global modules_arg
-    modules_arg = recorded_modules_arg
     model = Model(_internal=True)
     module_name = '_model'
-    convert_module(script_module, module, module_name, model)
+    GraphConverter().convert_module(script_module, module, module_name, model)
     return model
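The practical effect of the refactoring above is that conversion state (the sequence counters and the recorded module arguments) now lives on a GraphConverter instance rather than module-level globals, so each conversion starts from a clean slate. A minimal sketch of the new entry point, using only names that appear in this diff; the toy Net model is invented for illustration:

    import torch
    import nni.retiarii.nn.pytorch as nn                 # wrapped layers record their init args
    from nni.retiarii.converter.graph_gen import convert_to_graph

    class Net(nn.Module):                                # toy model for illustration
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(8, 2)

        def forward(self, x):
            return self.fc(x)

    base_model = Net()
    script_module = torch.jit.script(base_model)
    # Old signature: convert_to_graph(script_module, base_model, recorded_modules_arg)
    # New: recorded arguments are fetched internally via get_records()
    model_ir = convert_to_graph(script_module, base_model)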
nni/retiarii/experiment.py

@@ -29,6 +29,7 @@ _logger = logging.getLogger(__name__)

 OneShotTrainers = (DartsTrainer, EnasTrainer, ProxylessTrainer, RandomTrainer, SinglePathTrainer)

 @dataclass(init=False)
 class RetiariiExeConfig(ConfigBase):
     experiment_name: Optional[str] = None
 ...
@@ -125,14 +126,17 @@ class RetiariiExperiment(Experiment):
         except Exception as e:
             _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
             raise e
-        base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args)
-        assert id(self.trainer) in self.recorded_module_args
-        trainer_config = self.recorded_module_args[id(self.trainer)]
-        base_model.apply_trainer(trainer_config['modulename'], trainer_config['args'])
+        base_model_ir = convert_to_graph(script_module, self.base_model)
+        recorded_module_args = get_records()
+        if id(self.trainer) not in recorded_module_args:
+            raise KeyError('Your trainer is not found in registered classes. You might have forgotten to \
+register your customized trainer with @register_trainer decorator.')
+        trainer_config = recorded_module_args[id(self.trainer)]
+        base_model_ir.apply_trainer(trainer_config['modulename'], trainer_config['args'])

         # handle inline mutations
-        mutators = self._process_inline_mutation(base_model)
+        mutators = self._process_inline_mutation(base_model_ir)
         if mutators is not None and self.applied_mutators:
             raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, \
 do not use mutators when you use LayerChoice/InputChoice')
 ...
@@ -140,7 +144,7 @@ class RetiariiExperiment(Experiment):
             self.applied_mutators = mutators

         _logger.info('Starting strategy...')
-        Thread(target=self.strategy.run, args=(base_model, self.applied_mutators)).start()
+        Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators)).start()
         _logger.info('Strategy started!')

     def start(self, port: int = 8080, debug: bool = False) -> None:
 ...
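Because the trainer's arguments are now looked up via get_records() rather than a dict passed around by the experiment, a customized trainer has to be registered before it is instantiated, otherwise the KeyError above fires. A hedged sketch of the expected pattern (MyTrainer and its arguments are invented; the real trainer base class and API are not shown in this diff):

    from nni.retiarii import register_trainer

    @register_trainer        # records MyTrainer's __init__ args so get_records() can find them
    class MyTrainer:         # hypothetical custom trainer
        def __init__(self, model_cls, learning_rate=1e-3, max_epochs=10):
            self.model_cls = model_cls
            self.learning_rate = learning_rate
            self.max_epochs = max_epochs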
nni/retiarii/nn/pytorch/nn.py
View file @
ae50ed14
import
inspect
import
logging
import
logging
from
typing
import
Any
,
List
from
typing
import
Any
,
List
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
...utils
import
add_record
,
version_larger_equal
from
...utils
import
add_record
,
blackbox_module
,
uid
,
version_larger_equal
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -40,16 +39,13 @@ if version_larger_equal(torch.__version__, '1.6.0'):
...
@@ -40,16 +39,13 @@ if version_larger_equal(torch.__version__, '1.6.0'):
if
version_larger_equal
(
torch
.
__version__
,
'1.7.0'
):
if
version_larger_equal
(
torch
.
__version__
,
'1.7.0'
):
__all__
.
extend
([
'Unflatten'
,
'SiLU'
,
'TripletMarginWithDistanceLoss'
])
__all__
.
extend
([
'Unflatten'
,
'SiLU'
,
'TripletMarginWithDistanceLoss'
])
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'ChannelShuffle'
class
LayerChoice
(
nn
.
Module
):
class
LayerChoice
(
nn
.
Module
):
def
__init__
(
self
,
op_candidates
,
reduction
=
None
,
return_mask
=
False
,
key
=
None
):
def
__init__
(
self
,
op_candidates
,
reduction
=
None
,
return_mask
=
False
,
key
=
None
):
super
(
LayerChoice
,
self
).
__init__
()
super
(
LayerChoice
,
self
).
__init__
()
self
.
candidate
_op
s
=
op_candidates
self
.
op_
candidates
=
op_candidates
self
.
label
=
key
self
.
label
=
key
if
key
is
not
None
else
f
'layerchoice_
{
uid
()
}
'
self
.
key
=
key
# deprecated, for backward compatibility
self
.
key
=
self
.
label
# deprecated, for backward compatibility
for
i
,
module
in
enumerate
(
op_candidates
):
# deprecated, for backward compatibility
for
i
,
module
in
enumerate
(
op_candidates
):
# deprecated, for backward compatibility
self
.
add_module
(
str
(
i
),
module
)
self
.
add_module
(
str
(
i
),
module
)
if
reduction
or
return_mask
:
if
reduction
or
return_mask
:
...
@@ -66,8 +62,8 @@ class InputChoice(nn.Module):
...
@@ -66,8 +62,8 @@ class InputChoice(nn.Module):
self
.
n_candidates
=
n_candidates
self
.
n_candidates
=
n_candidates
self
.
n_chosen
=
n_chosen
self
.
n_chosen
=
n_chosen
self
.
reduction
=
reduction
self
.
reduction
=
reduction
self
.
label
=
key
self
.
label
=
key
if
key
is
not
None
else
f
'inputchoice_
{
uid
()
}
'
self
.
key
=
key
# deprecated, for backward compatibility
self
.
key
=
self
.
label
# deprecated, for backward compatibility
if
choose_from
or
return_mask
:
if
choose_from
or
return_mask
:
_logger
.
warning
(
'input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!'
)
_logger
.
warning
(
'input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!'
)
...
@@ -101,6 +97,7 @@ class Placeholder(nn.Module):
...
@@ -101,6 +97,7 @@ class Placeholder(nn.Module):
class
ChosenInputs
(
nn
.
Module
):
class
ChosenInputs
(
nn
.
Module
):
"""
"""
"""
"""
def
__init__
(
self
,
chosen
:
List
[
int
],
reduction
:
str
):
def
__init__
(
self
,
chosen
:
List
[
int
],
reduction
:
str
):
super
().
__init__
()
super
().
__init__
()
self
.
chosen
=
chosen
self
.
chosen
=
chosen
...
@@ -128,9 +125,7 @@ class ChosenInputs(nn.Module):
...
@@ -128,9 +125,7 @@ class ChosenInputs(nn.Module):
# the following are pytorch modules
# the following are pytorch modules
class
Module
(
nn
.
Module
):
Module
=
nn
.
Module
def
__init__
(
self
):
super
(
Module
,
self
).
__init__
()
class
Sequential
(
nn
.
Sequential
):
class
Sequential
(
nn
.
Sequential
):
...
@@ -145,143 +140,116 @@ class ModuleList(nn.ModuleList):
...
@@ -145,143 +140,116 @@ class ModuleList(nn.ModuleList):
super
(
ModuleList
,
self
).
__init__
(
*
args
)
super
(
ModuleList
,
self
).
__init__
(
*
args
)
def
wrap_module
(
original_class
):
Identity
=
blackbox_module
(
nn
.
Identity
)
orig_init
=
original_class
.
__init__
Linear
=
blackbox_module
(
nn
.
Linear
)
argname_list
=
list
(
inspect
.
signature
(
original_class
).
parameters
.
keys
())
Conv1d
=
blackbox_module
(
nn
.
Conv1d
)
# Make copy of original __init__, so we can call it without recursion
Conv2d
=
blackbox_module
(
nn
.
Conv2d
)
Conv3d
=
blackbox_module
(
nn
.
Conv3d
)
def
__init__
(
self
,
*
args
,
**
kws
):
ConvTranspose1d
=
blackbox_module
(
nn
.
ConvTranspose1d
)
full_args
=
{}
ConvTranspose2d
=
blackbox_module
(
nn
.
ConvTranspose2d
)
full_args
.
update
(
kws
)
ConvTranspose3d
=
blackbox_module
(
nn
.
ConvTranspose3d
)
for
i
,
arg
in
enumerate
(
args
):
Threshold
=
blackbox_module
(
nn
.
Threshold
)
full_args
[
argname_list
[
i
]]
=
arg
ReLU
=
blackbox_module
(
nn
.
ReLU
)
add_record
(
id
(
self
),
full_args
)
Hardtanh
=
blackbox_module
(
nn
.
Hardtanh
)
ReLU6
=
blackbox_module
(
nn
.
ReLU6
)
orig_init
(
self
,
*
args
,
**
kws
)
# Call the original __init__
Sigmoid
=
blackbox_module
(
nn
.
Sigmoid
)
Tanh
=
blackbox_module
(
nn
.
Tanh
)
original_class
.
__init__
=
__init__
# Set the class' __init__ to the new one
Softmax
=
blackbox_module
(
nn
.
Softmax
)
return
original_class
Softmax2d
=
blackbox_module
(
nn
.
Softmax2d
)
LogSoftmax
=
blackbox_module
(
nn
.
LogSoftmax
)
ELU
=
blackbox_module
(
nn
.
ELU
)
Identity
=
wrap_module
(
nn
.
Identity
)
SELU
=
blackbox_module
(
nn
.
SELU
)
Linear
=
wrap_module
(
nn
.
Linear
)
CELU
=
blackbox_module
(
nn
.
CELU
)
Conv1d
=
wrap_module
(
nn
.
Conv1d
)
GLU
=
blackbox_module
(
nn
.
GLU
)
Conv2d
=
wrap_module
(
nn
.
Conv2d
)
GELU
=
blackbox_module
(
nn
.
GELU
)
Conv3d
=
wrap_module
(
nn
.
Conv3d
)
Hardshrink
=
blackbox_module
(
nn
.
Hardshrink
)
ConvTranspose1d
=
wrap_module
(
nn
.
ConvTranspose1d
)
LeakyReLU
=
blackbox_module
(
nn
.
LeakyReLU
)
ConvTranspose2d
=
wrap_module
(
nn
.
ConvTranspose2d
)
LogSigmoid
=
blackbox_module
(
nn
.
LogSigmoid
)
ConvTranspose3d
=
wrap_module
(
nn
.
ConvTranspose3d
)
Softplus
=
blackbox_module
(
nn
.
Softplus
)
Threshold
=
wrap_module
(
nn
.
Threshold
)
Softshrink
=
blackbox_module
(
nn
.
Softshrink
)
ReLU
=
wrap_module
(
nn
.
ReLU
)
MultiheadAttention
=
blackbox_module
(
nn
.
MultiheadAttention
)
Hardtanh
=
wrap_module
(
nn
.
Hardtanh
)
PReLU
=
blackbox_module
(
nn
.
PReLU
)
ReLU6
=
wrap_module
(
nn
.
ReLU6
)
Softsign
=
blackbox_module
(
nn
.
Softsign
)
Sigmoid
=
wrap_module
(
nn
.
Sigmoid
)
Softmin
=
blackbox_module
(
nn
.
Softmin
)
Tanh
=
wrap_module
(
nn
.
Tanh
)
Tanhshrink
=
blackbox_module
(
nn
.
Tanhshrink
)
Softmax
=
wrap_module
(
nn
.
Softmax
)
RReLU
=
blackbox_module
(
nn
.
RReLU
)
Softmax2d
=
wrap_module
(
nn
.
Softmax2d
)
AvgPool1d
=
blackbox_module
(
nn
.
AvgPool1d
)
LogSoftmax
=
wrap_module
(
nn
.
LogSoftmax
)
AvgPool2d
=
blackbox_module
(
nn
.
AvgPool2d
)
ELU
=
wrap_module
(
nn
.
ELU
)
AvgPool3d
=
blackbox_module
(
nn
.
AvgPool3d
)
SELU
=
wrap_module
(
nn
.
SELU
)
MaxPool1d
=
blackbox_module
(
nn
.
MaxPool1d
)
CELU
=
wrap_module
(
nn
.
CELU
)
MaxPool2d
=
blackbox_module
(
nn
.
MaxPool2d
)
GLU
=
wrap_module
(
nn
.
GLU
)
MaxPool3d
=
blackbox_module
(
nn
.
MaxPool3d
)
GELU
=
wrap_module
(
nn
.
GELU
)
MaxUnpool1d
=
blackbox_module
(
nn
.
MaxUnpool1d
)
Hardshrink
=
wrap_module
(
nn
.
Hardshrink
)
MaxUnpool2d
=
blackbox_module
(
nn
.
MaxUnpool2d
)
LeakyReLU
=
wrap_module
(
nn
.
LeakyReLU
)
MaxUnpool3d
=
blackbox_module
(
nn
.
MaxUnpool3d
)
LogSigmoid
=
wrap_module
(
nn
.
LogSigmoid
)
FractionalMaxPool2d
=
blackbox_module
(
nn
.
FractionalMaxPool2d
)
Softplus
=
wrap_module
(
nn
.
Softplus
)
FractionalMaxPool3d
=
blackbox_module
(
nn
.
FractionalMaxPool3d
)
Softshrink
=
wrap_module
(
nn
.
Softshrink
)
LPPool1d
=
blackbox_module
(
nn
.
LPPool1d
)
MultiheadAttention
=
wrap_module
(
nn
.
MultiheadAttention
)
LPPool2d
=
blackbox_module
(
nn
.
LPPool2d
)
PReLU
=
wrap_module
(
nn
.
PReLU
)
LocalResponseNorm
=
blackbox_module
(
nn
.
LocalResponseNorm
)
Softsign
=
wrap_module
(
nn
.
Softsign
)
BatchNorm1d
=
blackbox_module
(
nn
.
BatchNorm1d
)
Softmin
=
wrap_module
(
nn
.
Softmin
)
BatchNorm2d
=
blackbox_module
(
nn
.
BatchNorm2d
)
Tanhshrink
=
wrap_module
(
nn
.
Tanhshrink
)
BatchNorm3d
=
blackbox_module
(
nn
.
BatchNorm3d
)
RReLU
=
wrap_module
(
nn
.
RReLU
)
InstanceNorm1d
=
blackbox_module
(
nn
.
InstanceNorm1d
)
AvgPool1d
=
wrap_module
(
nn
.
AvgPool1d
)
InstanceNorm2d
=
blackbox_module
(
nn
.
InstanceNorm2d
)
AvgPool2d
=
wrap_module
(
nn
.
AvgPool2d
)
InstanceNorm3d
=
blackbox_module
(
nn
.
InstanceNorm3d
)
AvgPool3d
=
wrap_module
(
nn
.
AvgPool3d
)
LayerNorm
=
blackbox_module
(
nn
.
LayerNorm
)
MaxPool1d
=
wrap_module
(
nn
.
MaxPool1d
)
GroupNorm
=
blackbox_module
(
nn
.
GroupNorm
)
MaxPool2d
=
wrap_module
(
nn
.
MaxPool2d
)
SyncBatchNorm
=
blackbox_module
(
nn
.
SyncBatchNorm
)
MaxPool3d
=
wrap_module
(
nn
.
MaxPool3d
)
Dropout
=
blackbox_module
(
nn
.
Dropout
)
MaxUnpool1d
=
wrap_module
(
nn
.
MaxUnpool1d
)
Dropout2d
=
blackbox_module
(
nn
.
Dropout2d
)
MaxUnpool2d
=
wrap_module
(
nn
.
MaxUnpool2d
)
Dropout3d
=
blackbox_module
(
nn
.
Dropout3d
)
MaxUnpool3d
=
wrap_module
(
nn
.
MaxUnpool3d
)
AlphaDropout
=
blackbox_module
(
nn
.
AlphaDropout
)
FractionalMaxPool2d
=
wrap_module
(
nn
.
FractionalMaxPool2d
)
FeatureAlphaDropout
=
blackbox_module
(
nn
.
FeatureAlphaDropout
)
FractionalMaxPool3d
=
wrap_module
(
nn
.
FractionalMaxPool3d
)
ReflectionPad1d
=
blackbox_module
(
nn
.
ReflectionPad1d
)
LPPool1d
=
wrap_module
(
nn
.
LPPool1d
)
ReflectionPad2d
=
blackbox_module
(
nn
.
ReflectionPad2d
)
LPPool2d
=
wrap_module
(
nn
.
LPPool2d
)
ReplicationPad2d
=
blackbox_module
(
nn
.
ReplicationPad2d
)
LocalResponseNorm
=
wrap_module
(
nn
.
LocalResponseNorm
)
ReplicationPad1d
=
blackbox_module
(
nn
.
ReplicationPad1d)
-BatchNorm1d = wrap_module(nn.BatchNorm1d)
+BatchNorm1d = blackbox_module(nn.BatchNorm1d)
-BatchNorm2d = wrap_module(nn.BatchNorm2d)
+BatchNorm2d = blackbox_module(nn.BatchNorm2d)
-BatchNorm3d = wrap_module(nn.BatchNorm3d)
+BatchNorm3d = blackbox_module(nn.BatchNorm3d)
-InstanceNorm1d = wrap_module(nn.InstanceNorm1d)
+InstanceNorm1d = blackbox_module(nn.InstanceNorm1d)
-InstanceNorm2d = wrap_module(nn.InstanceNorm2d)
+InstanceNorm2d = blackbox_module(nn.InstanceNorm2d)
-InstanceNorm3d = wrap_module(nn.InstanceNorm3d)
+InstanceNorm3d = blackbox_module(nn.InstanceNorm3d)
-LayerNorm = wrap_module(nn.LayerNorm)
+LayerNorm = blackbox_module(nn.LayerNorm)
-GroupNorm = wrap_module(nn.GroupNorm)
+GroupNorm = blackbox_module(nn.GroupNorm)
-SyncBatchNorm = wrap_module(nn.SyncBatchNorm)
+SyncBatchNorm = blackbox_module(nn.SyncBatchNorm)
-Dropout = wrap_module(nn.Dropout)
+Dropout = blackbox_module(nn.Dropout)
-Dropout2d = wrap_module(nn.Dropout2d)
+Dropout2d = blackbox_module(nn.Dropout2d)
-Dropout3d = wrap_module(nn.Dropout3d)
+Dropout3d = blackbox_module(nn.Dropout3d)
-AlphaDropout = wrap_module(nn.AlphaDropout)
+AlphaDropout = blackbox_module(nn.AlphaDropout)
-FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout)
+FeatureAlphaDropout = blackbox_module(nn.FeatureAlphaDropout)
-ReflectionPad1d = wrap_module(nn.ReflectionPad1d)
+ReflectionPad1d = blackbox_module(nn.ReflectionPad1d)
-ReflectionPad2d = wrap_module(nn.ReflectionPad2d)
+ReflectionPad2d = blackbox_module(nn.ReflectionPad2d)
-ReplicationPad2d = wrap_module(nn.ReplicationPad2d)
+ReplicationPad2d = blackbox_module(nn.ReplicationPad2d)
-ReplicationPad1d = wrap_module(nn.ReplicationPad1d)
+ReplicationPad1d = blackbox_module(nn.ReplicationPad1d)
-ReplicationPad3d = wrap_module(nn.ReplicationPad3d)
+ReplicationPad3d = blackbox_module(nn.ReplicationPad3d)
-CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d)
+CrossMapLRN2d = blackbox_module(nn.CrossMapLRN2d)
-Embedding = wrap_module(nn.Embedding)
+Embedding = blackbox_module(nn.Embedding)
-EmbeddingBag = wrap_module(nn.EmbeddingBag)
+EmbeddingBag = blackbox_module(nn.EmbeddingBag)
-RNNBase = wrap_module(nn.RNNBase)
+RNNBase = blackbox_module(nn.RNNBase)
-RNN = wrap_module(nn.RNN)
+RNN = blackbox_module(nn.RNN)
-LSTM = wrap_module(nn.LSTM)
+LSTM = blackbox_module(nn.LSTM)
-GRU = wrap_module(nn.GRU)
+GRU = blackbox_module(nn.GRU)
-RNNCellBase = wrap_module(nn.RNNCellBase)
+RNNCellBase = blackbox_module(nn.RNNCellBase)
-RNNCell = wrap_module(nn.RNNCell)
+RNNCell = blackbox_module(nn.RNNCell)
-LSTMCell = wrap_module(nn.LSTMCell)
+LSTMCell = blackbox_module(nn.LSTMCell)
-GRUCell = wrap_module(nn.GRUCell)
+GRUCell = blackbox_module(nn.GRUCell)
-PixelShuffle = wrap_module(nn.PixelShuffle)
+PixelShuffle = blackbox_module(nn.PixelShuffle)
-Upsample = wrap_module(nn.Upsample)
+Upsample = blackbox_module(nn.Upsample)
-UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d)
+UpsamplingNearest2d = blackbox_module(nn.UpsamplingNearest2d)
-UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d)
+UpsamplingBilinear2d = blackbox_module(nn.UpsamplingBilinear2d)
-PairwiseDistance = wrap_module(nn.PairwiseDistance)
+PairwiseDistance = blackbox_module(nn.PairwiseDistance)
-AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d)
+AdaptiveMaxPool1d = blackbox_module(nn.AdaptiveMaxPool1d)
-AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d)
+AdaptiveMaxPool2d = blackbox_module(nn.AdaptiveMaxPool2d)
-AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d)
+AdaptiveMaxPool3d = blackbox_module(nn.AdaptiveMaxPool3d)
-AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d)
+AdaptiveAvgPool1d = blackbox_module(nn.AdaptiveAvgPool1d)
-AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
+AdaptiveAvgPool2d = blackbox_module(nn.AdaptiveAvgPool2d)
-AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d)
+AdaptiveAvgPool3d = blackbox_module(nn.AdaptiveAvgPool3d)
-TripletMarginLoss = wrap_module(nn.TripletMarginLoss)
+TripletMarginLoss = blackbox_module(nn.TripletMarginLoss)
-ZeroPad2d = wrap_module(nn.ZeroPad2d)
+ZeroPad2d = blackbox_module(nn.ZeroPad2d)
-ConstantPad1d = wrap_module(nn.ConstantPad1d)
+ConstantPad1d = blackbox_module(nn.ConstantPad1d)
-ConstantPad2d = wrap_module(nn.ConstantPad2d)
+ConstantPad2d = blackbox_module(nn.ConstantPad2d)
-ConstantPad3d = wrap_module(nn.ConstantPad3d)
+ConstantPad3d = blackbox_module(nn.ConstantPad3d)
-Bilinear = wrap_module(nn.Bilinear)
+Bilinear = blackbox_module(nn.Bilinear)
-CosineSimilarity = wrap_module(nn.CosineSimilarity)
+CosineSimilarity = blackbox_module(nn.CosineSimilarity)
-Unfold = wrap_module(nn.Unfold)
+Unfold = blackbox_module(nn.Unfold)
-Fold = wrap_module(nn.Fold)
+Fold = blackbox_module(nn.Fold)
-AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
+AdaptiveLogSoftmaxWithLoss = blackbox_module(nn.AdaptiveLogSoftmaxWithLoss)
-TransformerEncoder = wrap_module(nn.TransformerEncoder)
+TransformerEncoder = blackbox_module(nn.TransformerEncoder)
-TransformerDecoder = wrap_module(nn.TransformerDecoder)
+TransformerDecoder = blackbox_module(nn.TransformerDecoder)
-TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
+TransformerEncoderLayer = blackbox_module(nn.TransformerEncoderLayer)
-TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
+TransformerDecoderLayer = blackbox_module(nn.TransformerDecoderLayer)
-Transformer = wrap_module(nn.Transformer)
+Transformer = blackbox_module(nn.Transformer)
-Flatten = wrap_module(nn.Flatten)
+Flatten = blackbox_module(nn.Flatten)
-Hardsigmoid = wrap_module(nn.Hardsigmoid)
+Hardsigmoid = blackbox_module(nn.Hardsigmoid)

 if version_larger_equal(torch.__version__, '1.6.0'):
-    Hardswish = wrap_module(nn.Hardswish)
+    Hardswish = blackbox_module(nn.Hardswish)

 if version_larger_equal(torch.__version__, '1.7.0'):
-    SiLU = wrap_module(nn.SiLU)
+    SiLU = blackbox_module(nn.SiLU)
-    Unflatten = wrap_module(nn.Unflatten)
+    Unflatten = blackbox_module(nn.Unflatten)
-    TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
+    TripletMarginWithDistanceLoss = blackbox_module(nn.TripletMarginWithDistanceLoss)

 #LazyLinear = wrap_module(nn.LazyLinear)
 #LazyConv1d = wrap_module(nn.LazyConv1d)
 #LazyConv2d = wrap_module(nn.LazyConv2d)
 #LazyConv3d = wrap_module(nn.LazyConv3d)
 #LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d)
 #LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d)
 #LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d)
 #ChannelShuffle = wrap_module(nn.ChannelShuffle)
\ No newline at end of file
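The practical effect of these re-exports is that model code imports its layers from nni.retiarii.nn.pytorch rather than torch.nn; each wrapped layer then records the arguments it was constructed with, so the graph converter can recover the call later. A minimal usage sketch (ConvBlock is a made-up example class, and the wrapped layers are assumed to behave exactly like their torch.nn counterparts at runtime):

import torch
import nni.retiarii.nn.pytorch as nn   # blackbox-wrapped re-exports from the file above

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # each wrapped module records its init arguments when instantiated
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

block = ConvBlock(3, 16)
out = block(torch.randn(1, 3, 32, 32))  # behaves like a plain PyTorch module at runtime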
nni/retiarii/trainer/pytorch/base.py View file @ ae50ed14
@@ -43,7 +43,7 @@ def get_default_transform(dataset: str) -> Any:
     return None

-@register_trainer()
+@register_trainer
 class PyTorchImageClassificationTrainer(BaseTrainer):
     """
     Image classification trainer for PyTorch.
@@ -80,7 +80,7 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
             Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
             only the key ``max_epochs`` is useful.
         """
-        super(PyTorchImageClassificationTrainer, self).__init__()
+        super().__init__()
         self._use_cuda = torch.cuda.is_available()
         self.model = model
         if self._use_cuda:
nni/retiarii/utils.py View file @ ae50ed14
 import inspect
+import warnings
 from collections import defaultdict
 from typing import Any
@@ -10,12 +11,14 @@ def import_(target: str, allow_none: bool = False) -> Any:
     module = __import__(path, globals(), locals(), [identifier])
     return getattr(module, identifier)

 def version_larger_equal(a: str, b: str) -> bool:
     # TODO: refactor later
     a = a.split('+')[0]
     b = b.split('+')[0]
     return tuple(map(int, a.split('.'))) >= tuple(map(int, b.split('.')))

 _records = {}
@@ -29,73 +32,87 @@ def add_record(key, value):
     """
     global _records
     if _records is not None:
-        assert key not in _records, '{} already in _records'.format(key)
+        # assert key not in _records, '{} already in _records'.format(key)
         _records[key] = value

-def _register_module(original_class):
-    orig_init = original_class.__init__
-    argname_list = list(inspect.signature(original_class).parameters.keys())
-    # Make copy of original __init__, so we can call it without recursion
-
-    def __init__(self, *args, **kws):
-        full_args = {}
-        full_args.update(kws)
-        for i, arg in enumerate(args):
-            full_args[argname_list[i]] = arg
-        add_record(id(self), full_args)
-        orig_init(self, *args, **kws)  # Call the original __init__
-
-    original_class.__init__ = __init__  # Set the class' __init__ to the new one
-    return original_class
-
-def register_module():
-    """
-    Register a module.
-    """
-    # use it as a decorator: @register_module()
-    def _register(cls):
-        m = _register_module(original_class=cls)
-        return m
-    return _register
-
-def _register_trainer(original_class):
-    orig_init = original_class.__init__
-    # Make copy of original __init__, so we can call it without recursion
-    argname_list = list(inspect.signature(original_class).parameters.keys())
-    full_class_name = original_class.__module__ + '.' + original_class.__name__
-
-    def __init__(self, *args, **kws):
-        full_args = {}
-        full_args.update(kws)
-        for i, arg in enumerate(args):
-            # TODO: support both pytorch and tensorflow
-            from .nn.pytorch import Module
-            if isinstance(args[i], Module):
-                # ignore the base model object
-                continue
-            full_args[argname_list[i]] = arg
-        add_record(id(self), {'modulename': full_class_name, 'args': full_args})
-        orig_init(self, *args, **kws)  # Call the original __init__
-
-    original_class.__init__ = __init__  # Set the class' __init__ to the new one
-    return original_class
-
-def register_trainer():
-    def _register(cls):
-        m = _register_trainer(original_class=cls)
-        return m
-    return _register
+def del_record(key):
+    global _records
+    if _records is not None:
+        _records.pop(key, None)
+
+def _blackbox_cls(cls, module_name, register_format=None):
+    class wrapper(cls):
+        def __init__(self, *args, **kwargs):
+            argname_list = list(inspect.signature(cls).parameters.keys())
+            full_args = {}
+            full_args.update(kwargs)
+
+            assert len(args) <= len(argname_list), f'Length of {args} is greater than length of {argname_list}.'
+            for argname, value in zip(argname_list, args):
+                full_args[argname] = value
+
+            # eject un-serializable arguments
+            for k in list(full_args.keys()):
+                # The list is not complete and does not support nested cases.
+                if not isinstance(full_args[k], (int, float, str, dict, list)):
+                    if not (register_format == 'full' and k == 'model'):
+                        # no warning if it is base model in trainer
+                        warnings.warn(f'{cls} has un-serializable arguments {k} whose value is {full_args[k]}. \
+                            This is not supported. You can ignore this warning if you are passing the model to trainer.')
+                    full_args.pop(k)
+
+            if register_format == 'args':
+                add_record(id(self), full_args)
+            elif register_format == 'full':
+                full_class_name = cls.__module__ + '.' + cls.__name__
+                add_record(id(self), {'modulename': full_class_name, 'args': full_args})
+
+            super().__init__(*args, **kwargs)
+
+        def __del__(self):
+            del_record(id(self))
+
+    # using module_name instead of cls.__module__ because it's more natural to see where the module gets wrapped
+    # instead of simply putting torch.nn or etc.
+    wrapper.__module__ = module_name
+    wrapper.__name__ = cls.__name__
+    wrapper.__qualname__ = cls.__qualname__
+    wrapper.__init__.__doc__ = cls.__init__.__doc__
+    return wrapper
+
+def blackbox(cls, *args, **kwargs):
+    """
+    To create an blackbox instance inline without decorator. For example,
+    .. code-block:: python
+        self.op = blackbox(MyCustomOp, hidden_units=128)
+    """
+    # get caller module name
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'args')(*args, **kwargs)
+
+def blackbox_module(cls):
+    """
+    Register a module. Use it as a decorator.
+    """
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'args')
+
+def register_trainer(cls):
+    """
+    Register a trainer. Use it as a decorator.
+    """
+    frm = inspect.stack()[1]
+    module_name = inspect.getmodule(frm[0]).__name__
+    return _blackbox_cls(cls, module_name, 'full')

 _last_uid = defaultdict(int)
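Taken together, the new helpers replace the old register_module / _register_trainer pair: _blackbox_cls returns a subclass whose __init__ records the serializable constructor arguments (and whose __del__ drops the record), blackbox_module applies that wrapping as a decorator, and blackbox wraps and instantiates a class inline. A rough usage sketch based on the docstrings above (MyOp and MyCustomOp are made-up names, not part of this commit):

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox, blackbox_module

@blackbox_module                  # decorator form: wrap the class once, at definition time
class MyOp(nn.Module):
    def __init__(self, hidden_units):
        super().__init__()
        self.fc = nn.Linear(hidden_units, hidden_units)

    def forward(self, x):
        return self.fc(x)

op = MyOp(64)                     # {'hidden_units': 64} is recorded via add_record(id(op), ...)

class MyCustomOp(nn.Module):      # an undecorated class can instead be wrapped inline
    def __init__(self, hidden_units):
        super().__init__()
        self.fc = nn.Linear(hidden_units, hidden_units)

    def forward(self, x):
        return self.fc(x)

op2 = blackbox(MyCustomOp, hidden_units=128)   # the example given in the blackbox() docstring

register_trainer follows the same pattern but records the fully qualified class name alongside the arguments, which is why PyTorchImageClassificationTrainer above is now decorated with @register_trainer without parentheses.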
test/.gitignore View file @ ae50ed14
@@ -5,6 +5,7 @@ tuner_result.txt
 assessor_result.txt
 _generated_model.py
+_generated_model_*.py
 data
 generated
test/retiarii_test/darts/darts_model.py View file @ ae50ed14
@@ -7,9 +7,9 @@ import torch.nn as torch_nn
 import ops
 import nni.retiarii.nn.pytorch as nn
-from nni.retiarii import register_module
+from nni.retiarii import blackbox_module

+@blackbox_module
 class AuxiliaryHead(nn.Module):
     """ Auxiliary head in 2/3 place of network to let the gradient flow well """
@@ -35,7 +35,6 @@ class AuxiliaryHead(nn.Module):
         logits = self.linear(out)
         return logits

-@register_module()
 class Node(nn.Module):
     def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
         super().__init__()
@@ -66,7 +65,6 @@ class Node(nn.Module):
         #out = [self.drop_path(o) if o is not None else None for o in out]
         return self.input_switch(out)

-@register_module()
 class Cell(nn.Module):
     def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
@@ -100,7 +98,6 @@ class Cell(nn.Module):
         output = torch.cat(new_tensors, dim=1)
         return output

-@register_module()
 class CNN(nn.Module):
     def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
test/retiarii_test/darts/ops.py View file @ ae50ed14
 import torch
 import nni.retiarii.nn.pytorch as nn
-from nni.retiarii import register_module
+from nni.retiarii import blackbox_module

-@register_module()
+@blackbox_module
 class DropPath(nn.Module):
     def __init__(self, p=0.):
         """
@@ -12,7 +12,7 @@ class DropPath(nn.Module):
         p : float
             Probability of an path to be zeroed.
         """
-        super(DropPath, self).__init__()
+        super().__init__()
         self.p = p

     def forward(self, x):
@@ -24,13 +24,13 @@ class DropPath(nn.Module):
         return x

-@register_module()
+@blackbox_module
 class PoolBN(nn.Module):
     """
     AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
     """
     def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
-        super(PoolBN, self).__init__()
+        super().__init__()
         if pool_type.lower() == 'max':
             self.pool = nn.MaxPool2d(kernel_size, stride, padding)
         elif pool_type.lower() == 'avg':
@@ -45,13 +45,13 @@ class PoolBN(nn.Module):
         out = self.bn(out)
         return out

-@register_module()
+@blackbox_module
 class StdConv(nn.Module):
     """
     Standard conv: ReLU - Conv - BN
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
-        super(StdConv, self).__init__()
+        super().__init__()
         self.net = nn.Sequential(
             nn.ReLU(),
             nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
@@ -61,13 +61,13 @@ class StdConv(nn.Module):
     def forward(self, x):
         return self.net(x)

-@register_module()
+@blackbox_module
 class FacConv(nn.Module):
     """
     Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
     """
     def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
-        super(FacConv, self).__init__()
+        super().__init__()
         self.net = nn.Sequential(
             nn.ReLU(),
             nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
@@ -78,7 +78,7 @@ class FacConv(nn.Module):
     def forward(self, x):
         return self.net(x)

-@register_module()
+@blackbox_module
 class DilConv(nn.Module):
     """
     (Dilated) depthwise separable conv.
@@ -86,7 +86,7 @@ class DilConv(nn.Module):
     If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
-        super(DilConv, self).__init__()
+        super().__init__()
         self.net = nn.Sequential(
             nn.ReLU(),
             nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
@@ -98,14 +98,14 @@ class DilConv(nn.Module):
     def forward(self, x):
         return self.net(x)

-@register_module()
+@blackbox_module
 class SepConv(nn.Module):
     """
     Depthwise separable conv.
     DilConv(dilation=1) * 2.
     """
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
-        super(SepConv, self).__init__()
+        super().__init__()
         self.net = nn.Sequential(
             DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
             DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
@@ -114,13 +114,13 @@ class SepConv(nn.Module):
     def forward(self, x):
         return self.net(x)

-@register_module()
+@blackbox_module
 class FactorizedReduce(nn.Module):
     """
     Reduce feature map size by factorized pointwise (stride=2).
     """
     def __init__(self, C_in, C_out, affine=True):
-        super(FactorizedReduce, self).__init__()
+        super().__init__()
         self.relu = nn.ReLU()
         self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
         self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
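Note the shape of the decorator change throughout this file: the old register_module() was a decorator factory and had to be called, while blackbox_module takes the class directly and returns the wrapped class, so it is applied without parentheses. A small before/after sketch (OpExample is a made-up class, not part of the commit):

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import blackbox_module

# old style, removed in this commit:
#   @register_module()
#   class OpExample(nn.Module): ...

# new style:
@blackbox_module              # no parentheses: blackbox_module(OpExample) is the wrapped class
class OpExample(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, x):
        return self.bn(x)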
test/retiarii_test/darts/test.py View file @ ae50ed14
@@ -13,10 +13,10 @@ from darts_model import CNN
 if __name__ == '__main__':
     base_model = CNN(32, 3, 16, 10, 8)
     trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10",
                                                 dataset_kwargs={"root": "data/cifar10", "download": True},
                                                 dataloader_kwargs={"batch_size": 32},
                                                 optimizer_kwargs={"lr": 1e-3},
                                                 trainer_kwargs={"max_epochs": 1})
     #simple_startegy = TPEStrategy()
     simple_startegy = RandomStrategy()
@@ -31,4 +31,4 @@ if __name__ == '__main__':
     exp_config.training_service.use_active_gpu = True
     exp_config.training_service.gpu_indices = [1, 2]
-    exp.run(exp_config, 8081, debug=True)
+    exp.run(exp_config, 8081)
test/retiarii_test/darts/test_oneshot.py View file @ ae50ed14
@@ -56,8 +56,8 @@ def get_dataset(cls, cutout_length=0):
     valid_transform = transforms.Compose(normalize)
     if cls == "cifar10":
-        dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
-        dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
+        dataset_train = CIFAR10(root="./data/cifar10", train=True, download=True, transform=train_transform)
+        dataset_valid = CIFAR10(root="./data/cifar10", train=False, download=True, transform=valid_transform)
     else:
         raise NotImplementedError
     return dataset_train, dataset_valid
test/retiarii_test/mnasnet/base_mnasnet.py View file @ ae50ed14
+from nni.retiarii import blackbox_module
+import nni.retiarii.nn.pytorch as nn
 import warnings

 import torch
@@ -8,8 +10,6 @@ import torch.nn.functional as F
 import sys
 from pathlib import Path
 sys.path.append(str(Path(__file__).resolve().parents[2]))
-import nni.retiarii.nn.pytorch as nn
-from nni.retiarii import register_module

 # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
 # 1.0 - tensorflow.
@@ -27,6 +27,7 @@ class _ResidualBlock(nn.Module):
     def forward(self, x):
         return self.net(x) + x

 class _InvertedResidual(nn.Module):
     def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, skip, bn_momentum=0.1):
@@ -110,7 +111,7 @@ def _get_depths(depths, alpha):
    rather than down. """
     return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]

-@register_module()
 class MNASNet(nn.Module):
     """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
     implements the B1 variant of the model.
@@ -127,7 +128,7 @@ class MNASNet(nn.Module):
     def __init__(self, alpha, depths, convops, kernel_sizes, num_layers,
                  skips, num_classes=1000, dropout=0.2):
-        super(MNASNet, self).__init__()
+        super().__init__()
         assert alpha > 0.0
         assert len(depths) == len(convops) == len(kernel_sizes) == len(num_layers) == len(skips) == 7
         self.alpha = alpha
@@ -143,22 +144,22 @@ class MNASNet(nn.Module):
             nn.ReLU(inplace=True),
         ]
         count = 0
-        #for conv, prev_depth, depth, ks, skip, stride, repeat, exp_ratio in \
+        # for conv, prev_depth, depth, ks, skip, stride, repeat, exp_ratio in \
         #    zip(convops, depths[:-1], depths[1:], kernel_sizes, skips, strides, num_layers, exp_ratios):
         for filter_size, exp_ratio, stride in zip(base_filter_sizes, exp_ratios, strides):
             # TODO: restrict that "choose" can only be used within mutator
             ph = nn.Placeholder(label=f'mutable_{count}', related_info={
                 'kernel_size_options': [1, 3, 5],
                 'n_layer_options': [1, 2, 3, 4],
                 'op_type_options': ['__mutated__.base_mnasnet.RegularConv',
                                     '__mutated__.base_mnasnet.DepthwiseConv',
                                     '__mutated__.base_mnasnet.MobileConv'],
-                #'se_ratio_options': [0, 0.25],
+                # 'se_ratio_options': [0, 0.25],
                 'skip_options': ['identity', 'no'],
                 'n_filter_options': [int(filter_size * x) for x in [0.75, 1.0, 1.25]],
                 'exp_ratio': exp_ratio,
                 'stride': stride,
                 'in_ch': depths[0] if count == 0 else None
             })
             layers.append(ph)
             '''if conv == "mconv":
@@ -185,7 +186,7 @@ class MNASNet(nn.Module):
         #self.for_test = 10

     def forward(self, x):
-        #if self.for_test == 10:
+        # if self.for_test == 10:
         x = self.layers(x)
         # Equivalent to global avgpool and removing H and W dimensions.
         x = x.mean([2, 3])
@@ -196,7 +197,7 @@ class MNASNet(nn.Module):
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
                 torch_nn.init.kaiming_normal_(m.weight, mode="fan_out",
                                               nonlinearity="relu")
                 if m.bias is not None:
                     torch_nn.init.zeros_(m.bias)
             elif isinstance(m, nn.BatchNorm2d):
@@ -204,16 +205,18 @@ class MNASNet(nn.Module):
                 torch_nn.init.zeros_(m.bias)
             elif isinstance(m, nn.Linear):
                 torch_nn.init.kaiming_uniform_(m.weight, mode="fan_out",
                                                nonlinearity="sigmoid")
                 torch_nn.init.zeros_(m.bias)

 def test_model(model):
     model(torch.randn(2, 3, 224, 224))

-#====================definition of candidate op classes
+# ====================definition of candidate op classes
 BN_MOMENTUM = 1 - 0.9997

 class RegularConv(nn.Module):
     def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
         super().__init__()
@@ -234,6 +237,7 @@ class RegularConv(nn.Module):
             out = out + x
         return out

 class DepthwiseConv(nn.Module):
     def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
         super().__init__()
@@ -257,6 +261,7 @@ class DepthwiseConv(nn.Module):
             out = out + x
         return out

 class MobileConv(nn.Module):
     def __init__(self, kernel_size, in_ch, out_ch, skip, exp_ratio, stride):
         super().__init__()
@@ -274,7 +279,7 @@ class MobileConv(nn.Module):
             nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM),
             nn.ReLU(inplace=True),
             # Depthwise
             nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=(kernel_size - 1) // 2,
                       stride=stride, groups=mid_ch, bias=False),
             nn.BatchNorm2d(mid_ch, momentum=BN_MOMENTUM),
             nn.ReLU(inplace=True),
@@ -288,5 +293,6 @@ class MobileConv(nn.Module):
             out = out + x
         return out

 # mnasnet0_5
 ir_module = _InvertedResidual(16, 16, 3, 1, 1, True)
\ No newline at end of file
test/retiarii_test/mnasnet/test.py View file @ ae50ed14
@@ -19,12 +19,12 @@ if __name__ == '__main__':
     _DEFAULT_NUM_LAYERS = [1, 3, 3, 3, 2, 4, 1]
     base_model = MNASNet(0.5, _DEFAULT_DEPTHS, _DEFAULT_CONVOPS, _DEFAULT_KERNEL_SIZES,
                          _DEFAULT_NUM_LAYERS, _DEFAULT_SKIPS)
     trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10",
                                                 dataset_kwargs={"root": "data/cifar10", "download": True},
                                                 dataloader_kwargs={"batch_size": 32},
                                                 optimizer_kwargs={"lr": 1e-3},
                                                 trainer_kwargs={"max_epochs": 1})
     # new interface
     applied_mutators = []
@@ -41,4 +41,4 @@ if __name__ == '__main__':
     exp_config.max_trial_number = 10
     exp_config.training_service.use_active_gpu = False
-    exp.run(exp_config, 8081, debug=True)
+    exp.run(exp_config, 8081)
test/retiarii_test/mnist/test.py 0 → 100644 View file @ ae50ed14
import random

import nni.retiarii.nn.pytorch as nn
import torch.nn.functional as F
from nni.retiarii.experiment import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.strategies import RandomStrategy
from nni.retiarii.trainer import PyTorchImageClassificationTrainer


class Net(nn.Module):
    def __init__(self, hidden_size):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.LayerChoice([
            nn.Linear(4 * 4 * 50, hidden_size),
            nn.Linear(4 * 4 * 50, hidden_size, bias=False)
        ])
        self.fc2 = nn.Linear(hidden_size, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


if __name__ == '__main__':
    base_model = Net(128)
    trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="MNIST",
                                                dataset_kwargs={"root": "data/mnist", "download": True},
                                                dataloader_kwargs={"batch_size": 32},
                                                optimizer_kwargs={"lr": 1e-3},
                                                trainer_kwargs={"max_epochs": 1})
    simple_startegy = RandomStrategy()

    exp = RetiariiExperiment(base_model, trainer, [], simple_startegy)

    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'mnist_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 10
    exp_config.training_service.use_active_gpu = False

    exp.run(exp_config, 8081 + random.randint(0, 100))