OpenDAS / nni

Commit 349ead41, authored Jan 14, 2021 by liuzhe
Merge branch 'v2.0' into master
Parents: 25db55ca, 649ee597

Showing 20 changed files with 836 additions and 794 deletions
nni/nas/pytorch/mutator.py                  +2    -2
nni/retiarii/__init__.py                    +1    -1
nni/retiarii/codegen/pytorch.py             +2    -2
nni/retiarii/converter/graph_gen.py         +456  -475
nni/retiarii/converter/op_types.py          +4    -0
nni/retiarii/execution/api.py               +18   -15
nni/retiarii/execution/base.py              +21   -16
nni/retiarii/execution/cgo_engine.py        +1    -1
nni/retiarii/execution/interface.py         +2    -9
nni/retiarii/execution/listener.py          +0    -11
nni/retiarii/experiment.py                  +50   -57
nni/retiarii/graph.py                       +4    -4
nni/retiarii/integration.py                 +14   -32
nni/retiarii/integration_api.py             +36   -0
nni/retiarii/mutator.py                     +11   -6
nni/retiarii/nn/pytorch/nn.py               +155  -152
nni/retiarii/operation.py                   +5    -2
nni/retiarii/strategies/__init__.py         +1    -0
nni/retiarii/strategies/random_strategy.py  +32   -0
nni/retiarii/strategies/tpe_strategy.py     +21   -9
nni/nas/pytorch/mutator.py

@@ -147,7 +147,7 @@ class Mutator(BaseMutator):
         Parameters
         ----------
-        mutable : LayerChoice
+        mutable : nni.nas.pytorch.mutables.LayerChoice
            Layer choice module.
        args : list of torch.Tensor
            Inputs
@@ -180,7 +180,7 @@ class Mutator(BaseMutator):
         Parameters
         ----------
-        mutable : InputChoice
+        mutable : nni.nas.pytorch.mutables.InputChoice
            Input choice module.
        tensor_list : list of torch.Tensor
            Tensor list to apply the decision on.
nni/retiarii/__init__.py

@@ -2,4 +2,4 @@ from .operation import Operation
 from .graph import *
 from .execution import *
 from .mutator import *
-from .utils import register_module
\ No newline at end of file
+from .utils import blackbox, blackbox_module, register_trainer
nni/retiarii/codegen/pytorch.py

@@ -19,10 +19,10 @@ def model_to_pytorch_script(model: Model, placement=None) -> str:
 def _sorted_incoming_edges(node: Node) -> List[Edge]:
     edges = [edge for edge in node.graph.edges if edge.tail is node]
-    _logger.info('sorted_incoming_edges: %s', str(edges))
+    _logger.debug('sorted_incoming_edges: %s', str(edges))
     if not edges:
         return []
-    _logger.info('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
+    _logger.debug('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
     if all(edge.tail_slot is None for edge in edges):
         return edges
     if all(isinstance(edge.tail_slot, int) for edge in edges):
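The hunk ends mid-function, so the remainder of _sorted_incoming_edges is not shown. Presumably, when every tail_slot is an integer, the edges are ordered by slot so that generated call arguments line up positionally; a minimal sketch of that assumption (not the code from this commit):

import torch  # only to mirror the project's environment; not used below

# Hedged sketch: order incoming edges by destination slot when slots are
# integers, so codegen can emit call arguments in positional order.
def _sorted_incoming_edges_sketch(edges):
    if all(edge.tail_slot is None for edge in edges):
        return edges
    if all(isinstance(edge.tail_slot, int) for edge in edges):
        return sorted(edges, key=lambda edge: edge.tail_slot)
    raise RuntimeError('Invalid edge configuration')  # mixed None/int slots (assumed)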
nni/retiarii/converter/graph_gen.py

@@ -6,517 +6,501 @@ import torch

This hunk is the bulk of the merge. The module-level conversion functions, which shared the globals global_seq = 0, global_graph_id = 0 and modules_arg = None, become methods of a new GraphConverter class that keeps this state on the instance and obtains modules_arg from get_records(). Besides the refactoring, _handle_layerchoice now reads module.op_candidates (previously module.candidate_ops), _handle_inputchoice additionally records n_candidates, and convert_module treats any module whose arguments were recorded as a blackbox instead of raising a "you might have forgotten to decorate this class with @register_module()" error. The old function bodies are otherwise identical to the new methods, which read:

from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
from ..operation import Cell
from ..utils import get_records
from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
from .utils import _convert_name, build_full_name

_logger = logging.getLogger(__name__)


class GraphConverter:
    def __init__(self):
        self.global_seq = 0
        self.global_graph_id = 0
        self.modules_arg = get_records()

    def _add_edge(self, ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
        """
        Parameters
        ----------
        ir_graph : Graph
        node : torch._C.Node
        graph_inputs : List[torch._C.Value]
            a list of a script graph's inputs
        node_index : Dict
        new_node : Node
            newly created ir node corresponding to `node`
        output_remap : Dict
        ignore_first : bool
            if it is true, skip the first input
        """
        is_single_input = (len([_input for _input in node.inputs()]) - (1 if ignore_first else 0)) == 1
        new_node_input_idx = 0
        for _input in node.inputs():
            if ignore_first:
                ignore_first = False
                continue
            # handle source node
            if _input in graph_inputs:
                idx = graph_inputs.index(_input)
                src_node = ir_graph.input_node
                src_node_idx = idx
            elif _input in output_remap:
                assert output_remap[_input].kind() == 'aten::append'
                predecessor_node = output_remap[_input]
                assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
                src_node_idx = None
                src_node = node_index[predecessor_node]
                assert isinstance(src_node, Node)
            else:
                predecessor_node = _input.node()
                assert predecessor_node in node_index, 'predecessor node: {}'.format(predecessor_node)
                # find out the index of _input in the outputs of predecessor_node
                predecessor_outputs = [_output for _output in predecessor_node.outputs()]
                if len(predecessor_outputs) == 1:
                    idx = None
                else:
                    idx = predecessor_outputs.index(_input)
                ir_predecessor_node = node_index[predecessor_node]
                src_node_idx = idx
                assert isinstance(ir_predecessor_node, Node)
                src_node = ir_predecessor_node

            # handle destination node
            dst_node = new_node
            if is_single_input:
                dst_node_idx = None
            else:
                dst_node_idx = new_node_input_idx

            # create edge
            ir_graph.add_edge(head=(src_node, src_node_idx), tail=(dst_node, dst_node_idx))

            new_node_input_idx += 1

    def create_prim_constant_node(self, ir_graph, node, module_name):
        attrs = {}
        if node.outputsAt(0).toIValue() is not None:
            attrs = {'value': node.outputsAt(0).toIValue()}
        self.global_seq += 1
        new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, self.global_seq),
                                     node.kind(), attrs)
        return new_node

    def handle_prim_attr_node(self, node):
        assert node.hasAttribute('name')
        attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
        return node.kind(), attrs

    def _remove_mangle(self, module_type_str):
        return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)

    def remove_unconnected_nodes(self, ir_graph, targeted_type=None):
        """
        Parameters
        ----------
        ir_graph : Graph
            our ir graph representation
        targeted_type : str
            nodes with ```targeted_type``` will be removed from graph if their fanout is 0.
            ```None``` means removing all the nodes whose fanout is 0.
        """
        # build index of outputs of Node(s)
        node_fanout = set()
        for edge in ir_graph.edges:
            if edge.head.id not in node_fanout:
                node_fanout.add(edge.head.id)

        to_removes = []
        for hidden_node in ir_graph.hidden_nodes:
            if hidden_node.id not in node_fanout:
                assert isinstance(hidden_node, Node)
                if targeted_type is None:
                    to_removes.append(hidden_node)
                elif hidden_node.operation.type == targeted_type:
                    to_removes.append(hidden_node)

        for hidden_node in to_removes:
            hidden_node.remove()

    def handle_graph_nodes(self, script_module, sm_graph, module, module_name, ir_model, ir_graph):
        """
        Convert torch script node to our node ir, and build our graph ir

        Parameters
        ----------
        script_module : torch.jit.RecursiveScriptModule
            the torch script of ```module```
        sm_graph : torch._C.Graph
            the graph in torch script
        module : nn.Module
            the targeted pytorch module
        module_name : str
            ```module```'s name
        ir_model : Model
            the whole graph ir
        ir_graph : Graph
            the graph ir of ```module```

        Returns
        -------
        dict
            the mapping from graph node to our graph ir node
        """
        # handle inputs
        graph_inputs = []
        for _input in sm_graph.inputs():
            if _input.debugName() == 'self':
                assert _input.unique() == 0
                continue
            graph_inputs.append(_input)
            # TODO: add scope name
            ir_graph._add_input(_convert_name(_input.debugName()))

        node_index = {}  # graph node to graph ir node

        # some node does not have output but it modifies a variable, for example aten::append
        # %17 : Tensor[] = aten::append(%out.1, %16)
        # %out.1 is updated, and %17 is None
        # we add output to this type of node and connect it to the following node which uses %out.1
        # key: tensor (%out.1), value: node (this node)
        output_remap = {}

        def handle_if_condition(cond_tensor):
            """
            to calculate the condition, we only deal with the following op types by tracing back
            `prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`

            generate the expression using recursive calls

            NOTE: do not support dynamic graph
            """
            def _generate_expr(tensor):
                if tensor.node().kind() == 'prim::GetAttr':
                    return f'({getattr(module, tensor.node().s("name"))})'
                elif tensor.node().kind() == 'aten::__getitem__':
                    t = _generate_expr(tensor.node().inputsAt(0))
                    idx = _generate_expr(tensor.node().inputsAt(1))
                    return f'({t}[{idx}])'
                elif tensor.node().kind() == 'prim::Constant':
                    return f'{tensor.toIValue()}'
                elif tensor.node().kind() == 'aten::eq':
                    left = _generate_expr(tensor.node().inputsAt(0))
                    right = _generate_expr(tensor.node().inputsAt(1))
                    return f'({left} == {right})'
                else:
                    raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition')
            expr = _generate_expr(cond_tensor)
            return eval(expr)

        def handle_if_node(node):
            """
            Parameters
            ----------
            node : torch._C.Node
                the node from TorchScript graph

            Returns
            -------
            Node
                the created node ir
            """
            # only deal with input of prim::If is constant or attribute for now
            # will support constant expression in future
            inputs = [i for i in node.inputs()]
            assert len(inputs) == 1
            cond = handle_if_condition(inputs[0])
            chosen_block = 0 if cond else 1
            blocks = [block for block in node.blocks()]
            assert len(blocks) == 2
            last_block_node = None
            for node in blocks[chosen_block].nodes():
                last_block_node = handle_single_node(node)
            return last_block_node

        def handle_single_node(node):
            """
            Parameters
            ----------
            node : torch._C.Node
                the node from TorchScript graph

            Returns
            -------
            Node
                the created node ir
            """
            if node.kind() == 'prim::CallMethod':
                # get and handle the first input, which should be an nn.Module
                assert node.hasAttribute('name')
                if node.s('name') == 'forward':
                    # node.inputsAt(0).type() is <class 'torch._C.ClassType'>
                    submodule_type_str = self._remove_mangle(node.inputsAt(0).type().str())
                    submodule = node.inputsAt(0).node()
                    assert submodule.kind() == 'prim::GetAttr'
                    assert submodule.hasAttribute('name')
                    submodule_name = submodule.s('name')

                    if submodule.inputsAt(0).debugName() == 'self':
                        # module is usually instantiated in __init__.
                        # when calling a module in forward,
                        # prim::GetAttr is used to obtain the module in torch script.
                        # therefore, we do this check for a module. example below:
                        # %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
                        # %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
                        assert submodule_name in script_module._modules, \
                            "submodule_name: {} not in script_module {}".format(submodule_name, script_module._modules.keys())

                        submodule_full_name = build_full_name(module_name, submodule_name)
                        submodule_obj = getattr(module, submodule_name)
                        subgraph, sub_m_attrs = self.convert_module(script_module._modules[submodule_name],
                                                                    submodule_obj,
                                                                    submodule_full_name, ir_model)
                    else:
                        # %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
                        # %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
                        # %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4)
                        if submodule.inputsAt(0).type().name() == 'ModuleList':
                            # handle ModuleList
                            predecessor = submodule.inputsAt(0).node()
                            assert predecessor.kind() == 'prim::GetAttr'
                            assert predecessor.hasAttribute('name')
                            assert predecessor.inputsAt(0).debugName() == 'self'
                            predecessor_name = predecessor.s('name')
                            # FIXME: exchange
                            submodule_full_name = build_full_name(module_name, [submodule_name, predecessor_name])
                            predecessor_obj = getattr(module, predecessor_name)
                            submodule_obj = getattr(predecessor_obj, submodule_name)
                            subgraph, sub_m_attrs = self.convert_module(
                                script_module._modules[predecessor_name]._modules[submodule_name],
                                submodule_obj, submodule_full_name, ir_model)
                        else:
                            raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))

                    # TODO: match subgraph with maintained graphs
                    # build cell
                    if subgraph is None:
                        # if we do not parse this module's graph, we create Node for this module
                        subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
                        if isinstance(submodule_obj, Placeholder):
                            subcell.update_label(submodule_obj.label)
                        elif isinstance(submodule_obj, (LayerChoice, InputChoice)):
                            subcell.update_label(sub_m_attrs['label'])
                    else:
                        # Graph already created, create Cell for it
                        new_cell = Cell(cell_name=submodule_full_name, parameters=sub_m_attrs)
                        subcell = ir_graph.add_node(submodule_full_name, new_cell)
                    node_index[node] = subcell
                    # connect the cell into graph
                    self._add_edge(ir_graph, node, graph_inputs, node_index, subcell, output_remap, ignore_first=True)
                else:
                    raise RuntimeError('unsupported CallMethod {}'.format(node.s('name')))
            elif node.kind() == 'prim::CallFunction':
                func_type_str = self._remove_mangle(node.inputsAt(0).type().str())
                func = node.inputsAt(0).node()
                assert func.kind() == 'prim::Constant'
                assert func.hasAttribute('name')
                func_name = func.s('name')
                # create node for func
                self.global_seq += 1
                func_node = ir_graph.add_node(build_full_name(module_name, func_name, self.global_seq),
                                              '{}.{}'.format(func_type_str, func_name))
                node_index[node] = func_node
                self._add_edge(ir_graph, node, graph_inputs, node_index, func_node, output_remap, ignore_first=True)
            elif node.kind() == 'prim::Constant':
                new_node = self.create_prim_constant_node(ir_graph, node, module_name)
                node_index[node] = new_node
            elif node.kind() == 'prim::ListConstruct':
                self.global_seq += 1
                new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, self.global_seq), node.kind())
                node_index[node] = new_node
                self._add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
            elif node.kind() == 'aten::append':
                self.global_seq += 1
                aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
                node_index[node] = aten_node
                self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
                output_remap[node.inputsAt(0)] = node
            elif node.kind().startswith('aten::'):
                # handle aten::XXX
                self.global_seq += 1
                aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], self.global_seq), node.kind())
                node_index[node] = aten_node
                self._add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
            elif node.kind() == 'prim::GetAttr':
                node_type, attrs = self.handle_prim_attr_node(node)
                self.global_seq += 1
                new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, self.global_seq),
                                             node_type, attrs)
                node_index[node] = new_node
            elif node.kind() == 'prim::If':
                last_block_node = handle_if_node(node)
                # last_block_node is None means no node in the branch block
                node_index[node] = last_block_node
            elif node.kind() == 'prim::Loop':
                # refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
                raise RuntimeError('Loop has not been supported yet!')
            else:
                raise RuntimeError('Unsupported kind: {}'.format(node.kind()))

            return node_index[node]

        for node in sm_graph.nodes():
            handle_single_node(node)

        return node_index

    def merge_aten_slices(self, ir_graph):
        """
        if there is aten::slice node, merge the consecutive ones together.
        ```x[:, :, 1:, 1:]``` in python code will be converted into 4 node in torch script,
        each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
        """
        head_slice_nodes = []
        has_slice_node = False
        for node in ir_graph.hidden_nodes:
            if node.operation.type == 'aten::slice':
                has_slice_node = True
                for pred in node.predecessors:
                    if pred.operation.type not in ['aten::slice', 'prim::Constant']:
                        head_slice_nodes.append(node)
                        break
        if has_slice_node:
            assert head_slice_nodes

        for head_node in head_slice_nodes:
            slot = 0
            new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
            if len(head_node.incoming_edges) == 4:
                # when slice is for one dimension list, there are only 4 inputs, thus merge is not needed
                break
            assert len(head_node.incoming_edges) == 5
            for edge in head_node.incoming_edges:
                edge.tail = new_slice_node
            slot += 5
            node = head_node

            while len(node.successors) == 1 and node.successors[0].operation.type == 'aten::slice':
                suc_node = node.successors[0]
                assert len(suc_node.incoming_edges) == 5
                for edge in suc_node.incoming_edges:
                    if edge.tail_slot == 0:
                        edge.remove()
                    else:
                        edge.tail = new_slice_node
                        edge.tail_slot = slot + edge.tail_slot - 1
                slot += 4
                ir_graph.hidden_nodes.remove(node)
                node = suc_node

            for edge in node.outgoing_edges:
                edge.head = new_slice_node
            ir_graph.hidden_nodes.remove(node)

    def refine_graph(self, ir_graph):
        """
        Do the following process to simplify graph:
        1. remove unconnected constant node
        2. remove unconnected getattr node
        """
        # some constant is not used, for example, function name as prim::Constant
        self.remove_unconnected_nodes(ir_graph, targeted_type='prim::Constant')
        self.remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
        self.merge_aten_slices(ir_graph)

    def _handle_layerchoice(self, module):
        m_attrs = {}
        candidates = module.op_candidates
        choices = []
        for cand in candidates:
            assert id(cand) in self.modules_arg, 'id not exist: {}'.format(id(cand))
            assert isinstance(self.modules_arg[id(cand)], dict)
            cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
            choices.append({'type': cand_type, 'parameters': self.modules_arg[id(cand)]})
        m_attrs['choices'] = choices
        m_attrs['label'] = module.label
        return m_attrs

    def _handle_inputchoice(self, module):
        m_attrs = {}
        m_attrs['n_candidates'] = module.n_candidates
        m_attrs['n_chosen'] = module.n_chosen
        m_attrs['reduction'] = module.reduction
        m_attrs['label'] = module.label
        return m_attrs

    def convert_module(self, script_module, module, module_name, ir_model):
        """
        Convert a module to its graph ir (i.e., Graph) along with its input arguments

        Parameters
        ----------
        script_module : torch.jit.RecursiveScriptModule
            the script module of ```module``` obtained with torch.jit.script
        module : nn.Module
            the targeted module instance
        module_name : str
            the constructed name space of ```module```
        ir_model : Model
            the whole graph ir

        Returns
        -------
        Graph
            the built graph ir from module, ```None``` means do not further parse the module
        dict
            the input arguments of this module
        """
        # NOTE: have not supported nested LayerChoice, i.e., a candidate module
        # also has LayerChoice or InputChoice or ValueChoice
        original_type_name = script_module.original_name
        m_attrs = None
        if original_type_name in MODULE_EXCEPT_LIST:
            pass  # do nothing
        elif original_type_name == OpTypeName.LayerChoice:
            m_attrs = self._handle_layerchoice(module)
        elif original_type_name == OpTypeName.InputChoice:
            m_attrs = self._handle_inputchoice(module)
        elif original_type_name == OpTypeName.Placeholder:
            m_attrs = self.modules_arg[id(module)]
        elif original_type_name in torch.nn.__dict__:
            # this is a basic module from pytorch, no need to parse its graph
            assert id(module) in self.modules_arg, f'{original_type_name} arguments are not recorded'
            m_attrs = self.modules_arg[id(module)]
        elif id(module) in self.modules_arg:
            # this module is marked as blackbox, won't continue to parse
            m_attrs = self.modules_arg[id(module)]
        if m_attrs is not None:
            return None, m_attrs

        # handle TorchScript graph
        sm_graph = script_module.graph
        self.global_graph_id += 1
        ir_graph = Graph(model=ir_model, graph_id=self.global_graph_id, name=module_name, _internal=True)

        # handle graph nodes
        node_index = self.handle_graph_nodes(script_module, sm_graph, module,
                                             module_name, ir_model, ir_graph)

        # handle graph outputs
        for _output in sm_graph.outputs():
            ir_graph._add_output(_convert_name(_output.debugName()))
            predecessor_node_outputs = [o for o in _output.node().outputs()]
            if len(predecessor_node_outputs) == 1:
                src_node_idx = None
            else:
                src_node_idx = predecessor_node_outputs.index(_output)
            ir_graph.add_edge(head=(node_index[_output.node()], src_node_idx),
                              tail=(ir_graph.output_node, None))

        self.refine_graph(ir_graph)

        ir_graph._register()

        # TODO: if we parse this module, it means we will create a graph (module class)
        # for this module. Then it is not necessary to record this module's arguments
        # return ir_graph, modules_arg[id(module)].
        # That is, we can refactor this part, to allow users to annotate which module
        # should not be parsed further.
        return ir_graph, {}
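As the merge_aten_slices docstring above notes, x[:, :, 1:, 1:] scripts into a chain of aten::slice nodes. This can be observed directly with stock torch.jit (independent of this commit):

import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    return x[:, :, 1:, 1:]

# Printing the graph shows one aten::slice node per sliced dimension, each
# taking five inputs (tensor, dim, start, end, step) -- the shape that
# merge_aten_slices collapses into a single MergedSlice node.
print(f.graph)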
-def convert_to_graph(script_module, module, recorded_modules_arg):
+def convert_to_graph(script_module, module):
     """
     Convert module to our graph ir, i.e., build a ```Model``` type

@@ -526,18 +510,15 @@ def convert_to_graph(script_module, module, recorded_modules_arg):
         the script module obtained with torch.jit.script
     module : nn.Module
         the targeted module instance
-    recorded_modules_arg : dict
-        the recorded args of each module in the module

     Returns
     -------
     Model
         the constructed IR model
     """
-    global modules_arg
-    modules_arg = recorded_modules_arg
     model = Model(_internal=True)
     module_name = '_model'
-    convert_module(script_module, module, module_name, model)
+    GraphConverter().convert_module(script_module, module, module_name, model)
     return model
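Putting the converter together, the downstream call pattern (see experiment.py below for the real call site) is roughly: script the model with torch.jit.script, then hand it to convert_to_graph. A minimal sketch; Net is illustrative, and in Retiarii the submodules' constructor arguments would have been recorded via the nn.pytorch wrappers:

import torch
import torch.nn as nn

from nni.retiarii.converter import convert_to_graph

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model = Net()
script_module = torch.jit.script(model)            # TorchScript front end
model_ir = convert_to_graph(script_module, model)  # build the graph IR (Model)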
nni/retiarii/converter/op_types.py

@@ -30,6 +30,10 @@ BasicOpsPT = {
     'aten::size': 'Size',
     'aten::view': 'View',
     'aten::eq': 'Eq',
     'aten::Bool': 'Bool',
+    'aten::empty': 'Empty',
+    'aten::zeros': 'Zeros',
+    'aten::chunk': 'Chunk',
+    'aten::add_': 'Add_'  # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
 }
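These entries give aten:: kinds human-readable aliases when the converter names IR nodes (see handle_single_node in graph_gen.py above). A short sketch of how the table is consumed; the exact name format depends on build_full_name:

BasicOpsPT = {'aten::chunk': 'Chunk'}   # excerpt of the table
kind = 'aten::chunk'                    # what node.kind() returns
alias = BasicOpsPT[kind]                # -> 'Chunk'
# ir_graph.add_node(build_full_name(module_name, alias, seq), kind)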
nni/retiarii/execution/api.py

 import time
-import os
 from typing import List

 from ..graph import Model, ModelStatus
-from .base import BaseExecutionEngine
-from .cgo_engine import CGOExecutionEngine
-from .interface import AbstractExecutionEngine, WorkerInfo
+from .interface import AbstractExecutionEngine
 from .listener import DefaultListener

 _execution_engine = None
 _default_listener = None

-__all__ = ['get_execution_engine', 'get_and_register_default_listener',
-           'submit_models', 'wait_models', 'query_available_resources']
+__all__ = ['get_execution_engine', 'get_and_register_default_listener',
+           'submit_models', 'wait_models', 'query_available_resources',
+           'set_execution_engine', 'is_stopped_exec']

+def set_execution_engine(engine) -> None:
+    global _execution_engine
+    if _execution_engine is None:
+        _execution_engine = engine
+    else:
+        raise RuntimeError('execution engine is already set')

-def get_execution_engine() -> BaseExecutionEngine:
+def get_execution_engine() -> AbstractExecutionEngine:
     """
     Currently we assume the default execution engine is BaseExecutionEngine.
     """
     global _execution_engine
-    if _execution_engine is None:
-        if os.environ.get('CGO') == 'true':
-            _execution_engine = CGOExecutionEngine()
-        else:
-            _execution_engine = BaseExecutionEngine()
     return _execution_engine

@@ -51,6 +49,11 @@ def wait_models(*models: Model) -> None:
         break

-def query_available_resources() -> List[WorkerInfo]:
-    listener = get_and_register_default_listener(get_execution_engine())
-    return listener.resources
+def query_available_resources() -> int:
+    engine = get_execution_engine()
+    resources = engine.query_available_resource()
+    return resources if isinstance(resources, int) else len(resources)

+def is_stopped_exec(model: Model) -> bool:
+    return model.status in (ModelStatus.Trained, ModelStatus.Failed)
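The engine is now injected once instead of lazily constructed here. A hedged sketch of the intended call pattern; in practice set_execution_engine is called by RetiariiAdvisor (see integration.py below):

from nni.retiarii.execution.base import BaseExecutionEngine
from nni.retiarii.execution.api import (
    set_execution_engine, query_available_resources, submit_models, wait_models)

set_execution_engine(BaseExecutionEngine())   # allowed exactly once;
# a second call raises RuntimeError('execution engine is already set')

# a strategy can then drive trials through the module-level API, e.g.:
# if query_available_resources() > 0:
#     submit_models(model)
#     wait_models(model)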
nni/retiarii/execution/base.py

 import logging
 import os
+import random
+import string
 from typing import Dict, Any, List

-from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
+from .interface import AbstractExecutionEngine, AbstractGraphListener
 from .. import codegen, utils
 from ..graph import Model, ModelStatus, MetricData
-from ..integration import send_trial, receive_trial_parameters, get_advisor
+from ..integration_api import send_trial, receive_trial_parameters, get_advisor

 _logger = logging.getLogger(__name__)

@@ -29,7 +32,7 @@ class BaseGraphData:
 class BaseExecutionEngine(AbstractExecutionEngine):
     """
     The execution engine with no optimization at all.
-    Resource management is yet to be implemented.
+    Resource management is implemented in this class.
     """

     def __init__(self) -> None:
@@ -50,6 +53,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         self._running_models: Dict[int, Model] = dict()

+        self.resources = 0

     def submit_models(self, *models: Model) -> None:
         for model in models:
             data = BaseGraphData(codegen.model_to_pytorch_script(model),
@@ -60,17 +65,14 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         self._listeners.append(listener)

     def _send_trial_callback(self, paramater: dict) -> None:
-        for listener in self._listeners:
-            _logger.warning('resources: %s', listener.resources)
-            if not listener.has_available_resource():
-                _logger.warning('There is no available resource, but trial is submitted.')
-            listener.on_resource_used(1)
-            _logger.warning('on_resource_used: %s', listener.resources)
+        if self.resources <= 0:
+            _logger.warning('There is no available resource, but trial is submitted.')
+        self.resources -= 1
+        _logger.info('on_resource_used: %d', self.resources)

     def _request_trial_jobs_callback(self, num_trials: int) -> None:
-        for listener in self._listeners:
-            listener.on_resource_available(1 * num_trials)
-            _logger.warning('on_resource_available: %s', listener.resources)
+        self.resources += num_trials
+        _logger.info('on_resource_available: %d', self.resources)

     def _trial_end_callback(self, trial_id: int, success: bool) -> None:
         model = self._running_models[trial_id]
@@ -93,8 +95,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         for listener in self._listeners:
             listener.on_metric(model, metrics)

-    def query_available_resource(self) -> List[WorkerInfo]:
-        raise NotImplementedError  # move the method from listener to here?
+    def query_available_resource(self) -> int:
+        return self.resources

     @classmethod
     def trial_execute_graph(cls) -> None:
@@ -102,9 +104,12 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         Initialize the model, hand it over to trainer.
         """
         graph_data = BaseGraphData.load(receive_trial_parameters())
-        with open('_generated_model.py', 'w') as f:
+        random_str = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
+        file_name = f'_generated_model_{random_str}.py'
+        with open(file_name, 'w') as f:
             f.write(graph_data.model_script)
         trainer_cls = utils.import_(graph_data.training_module)
-        model_cls = utils.import_('_generated_model._model')
+        model_cls = utils.import_(f'_generated_model_{random_str}._model')
         trainer_instance = trainer_cls(model=model_cls(), **graph_data.training_kwargs)
         trainer_instance.fit()
+        os.remove(file_name)
\ No newline at end of file
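The random suffix avoids filename and module-name collisions when several trials share a working directory. The same trick in isolation (hypothetical standalone sketch; NNI itself goes through utils.import_, and plain importlib works here only if the working directory is on sys.path):

import importlib
import os
import random
import string

# Generate a 6-character collision-resistant suffix, exactly as
# trial_execute_graph does above, then write / import / clean up.
suffix = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6))
file_name = f'_generated_model_{suffix}.py'
with open(file_name, 'w') as f:
    f.write('class _model:\n    pass\n')  # stand-in for the generated model script
model_cls = importlib.import_module(f'_generated_model_{suffix}')._model
print(model_cls)
os.remove(file_name)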
nni/retiarii/execution/cgo_engine.py

@@ -4,7 +4,7 @@ from typing import List, Dict, Tuple
 from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
 from .. import codegen, utils
 from ..graph import Model, ModelStatus, MetricData
-from ..integration import send_trial, receive_trial_parameters, get_advisor
+from ..integration_api import send_trial, receive_trial_parameters, get_advisor
 from .logical_optimizer.logical_plan import LogicalPlan, PhysicalDevice
 from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
nni/retiarii/execution/interface.py

 from abc import ABC, abstractmethod, abstractclassmethod
-from typing import Any, NewType, List
+from typing import Any, NewType, List, Union

 from ..graph import Model, MetricData

@@ -59,13 +59,6 @@ class AbstractGraphListener(ABC):
         """
         pass

-    @abstractmethod
-    def on_resource_available(self, resources: List[WorkerInfo]) -> None:
-        """
-        Reports when a worker becomes idle.
-        """
-        pass

 class AbstractExecutionEngine(ABC):
     """
@@ -109,7 +102,7 @@ class AbstractExecutionEngine(ABC):
         raise NotImplementedError

     @abstractmethod
-    def query_available_resource(self) -> List[WorkerInfo]:
+    def query_available_resource(self) -> Union[List[WorkerInfo], int]:
         """
         Returns information of all idle workers.
         If no details are available, this may return a list of "empty" objects, reporting the number of idle workers.
nni/retiarii/execution/listener.py

@@ -3,11 +3,6 @@ from .interface import MetricData, AbstractGraphListener

 class DefaultListener(AbstractGraphListener):

-    def __init__(self):
-        self.resources: int = 0  # simply resource count
-
-    def has_available_resource(self) -> bool:
-        return self.resources > 0

     def on_metric(self, model: Model, metric: MetricData) -> None:
         model.metric = metric
@@ -20,9 +15,3 @@ class DefaultListener(AbstractGraphListener):
             model.status = ModelStatus.Trained
         else:
             model.status = ModelStatus.Failed

-    def on_resource_available(self, resources: int) -> None:
-        self.resources += resources
-
-    def on_resource_used(self, resources: int) -> None:
-        self.resources -= resources
nni/retiarii/experiment.py

 import logging
-import time
 from dataclasses import dataclass
 from pathlib import Path

@@ -7,20 +6,24 @@ from subprocess import Popen
 from threading import Thread
 from typing import Any, Optional

-from ..experiment import Experiment, TrainingServiceConfig, launcher, rest
+from ..experiment import Experiment, TrainingServiceConfig
 from ..experiment.config.base import ConfigBase, PathLike
 from ..experiment.config import util
 from ..experiment.pipe import Pipe

 from .graph import Model
 from .utils import get_records
 from .integration import RetiariiAdvisor
 from .converter import convert_to_graph
 from .mutator import Mutator, LayerChoiceMutator, InputChoiceMutator
-from .trainer.interface import BaseTrainer
+from .trainer.interface import BaseTrainer, BaseOneShotTrainer
 from .strategies.strategy import BaseStrategy
+from .trainer.pytorch import DartsTrainer, EnasTrainer, ProxylessTrainer, RandomTrainer, SinglePathTrainer

 _logger = logging.getLogger(__name__)

+OneShotTrainers = (DartsTrainer, EnasTrainer, ProxylessTrainer, RandomTrainer, SinglePathTrainer)

 @dataclass(init=False)
 class RetiariiExeConfig(ConfigBase):
@@ -43,7 +46,7 @@ class RetiariiExeConfig(ConfigBase):
         super().__init__(**kwargs)
         if training_service_platform is not None:
             assert 'training_service' not in kwargs
-            self.training_service = util.training_service_config_factory(training_service_platform)
+            self.training_service = util.training_service_config_factory(platform=training_service_platform)

     def validate(self, initialized_tuner: bool = False) -> None:
         super().validate()
@@ -76,7 +79,7 @@ _validation_rules = {
 class RetiariiExperiment(Experiment):
     def __init__(self, base_model: Model, trainer: BaseTrainer,
-                 applied_mutators: Mutator, strategy: BaseStrategy):
+                 applied_mutators: Mutator = None, strategy: BaseStrategy = None):
         self.config: RetiariiExeConfig = None
         self.port: Optional[int] = None
@@ -87,6 +90,7 @@ class RetiariiExperiment(Experiment):
         self.recorded_module_args = get_records()
         self._dispatcher = RetiariiAdvisor()
+        self._dispatcher_thread: Optional[Thread] = None
         self._proc: Optional[Popen] = None
         self._pipe: Optional[Pipe] = None
@@ -103,7 +107,10 @@ class RetiariiExperiment(Experiment):
             mutator = LayerChoiceMutator(node.name, node.operation.parameters['choices'])
             applied_mutators.append(mutator)
         for node in ic_nodes:
-            mutator = InputChoiceMutator(node.name, node.operation.parameters['n_chosen'])
+            mutator = InputChoiceMutator(node.name,
+                                         node.operation.parameters['n_candidates'],
+                                         node.operation.parameters['n_chosen'],
+                                         node.operation.parameters['reduction'])
             applied_mutators.append(mutator)
         return applied_mutators
@@ -114,14 +121,17 @@ class RetiariiExperiment(Experiment):
         except Exception as e:
             _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
             raise e
-        base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args)
+        base_model_ir = convert_to_graph(script_module, self.base_model)

-        assert id(self.trainer) in self.recorded_module_args
-        trainer_config = self.recorded_module_args[id(self.trainer)]
-        base_model.apply_trainer(trainer_config['modulename'], trainer_config['args'])
+        recorded_module_args = get_records()
+        if id(self.trainer) not in recorded_module_args:
+            raise KeyError('Your trainer is not found in registered classes. You might have forgotten to \
+register your customized trainer with @register_trainer decorator.')
+        trainer_config = recorded_module_args[id(self.trainer)]
+        base_model_ir.apply_trainer(trainer_config['modulename'], trainer_config['args'])

         # handle inline mutations
-        mutators = self._process_inline_mutation(base_model)
+        mutators = self._process_inline_mutation(base_model_ir)
         if mutators is not None and self.applied_mutators:
             raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, \
do not use mutators when you use LayerChoice/InputChoice')
@@ -129,10 +139,10 @@ class RetiariiExperiment(Experiment):
             self.applied_mutators = mutators

         _logger.info('Starting strategy...')
-        Thread(target=self.strategy.run, args=(base_model, self.applied_mutators)).start()
+        Thread(target=self.strategy.run, args=(base_model_ir, self.applied_mutators)).start()
         _logger.info('Strategy started!')

-    def start(self, config: RetiariiExeConfig, port: int = 8080, debug: bool = False) -> None:
+    def start(self, port: int = 8080, debug: bool = False) -> None:
         """
         Start the experiment in background.
         This method will raise exception on failure.
@@ -144,54 +154,37 @@ class RetiariiExperiment(Experiment):
         debug
             Whether to start in debug mode.
         """
+        # FIXME:
         if debug:
             logging.getLogger('nni').setLevel(logging.DEBUG)
-        self._proc, self._pipe = launcher.start_experiment(config, port, debug)
-        assert self._proc is not None
-        assert self._pipe is not None
-        self.port = port  # port will be None if start up failed
-        # dispatcher must be created after pipe initialized
-        # the logic to launch dispatcher in background should be refactored into dispatcher api
-        Thread(target=self._dispatcher.run).start()
+        super().start(port, debug)
+        self._start_strategy()
+        # TODO: register experiment management metadata

+    def _create_dispatcher(self):
+        return self._dispatcher

-    def stop(self) -> None:
-        """
-        Stop background experiment.
-        """
-        self._proc.kill()
-        self._pipe.close()
-        self.port = None
-        self._proc = None
-        self._pipe = None

-    def run(self, config: RetiariiExeConfig, port: int = 8080, debug: bool = False) -> str:
+    def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str:
         """
         Run the experiment.
         This function will block until experiment finish or error.
         """
-        self.config = config
-        self.start(config, port, debug)
-        try:
-            while True:
-                time.sleep(10)
-                status = self.get_status()
-                # TODO: double check the status
-                if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
-                    return status
-        finally:
-            self.stop()
-
-    def get_status(self) -> str:
-        if self.port is None:
-            raise RuntimeError('Experiment is not running')
-        resp = rest.get(self.port, '/check-status')
-        return resp['status']
+        if isinstance(self.trainer, OneShotTrainers):
+            self.trainer.fit()
+        else:
+            assert config is not None, 'You are using classic search mode, config cannot be None!'
+            self.config = config
+            super().run(port, debug)

+    def export_top_models(self, top_n: int = 1):
+        """
+        export several top performing models
+        """
+        if top_n != 1:
+            _logger.warning('Only support top_n is 1 for now.')
+        if isinstance(self.trainer, BaseOneShotTrainer):
+            return self.trainer.export()
+        else:
+            _logger.info('For this experiment, you can find out the best one from WebUI.')

+    def retrain_model(self, model):
+        """
+        this function retrains the exported model, and test it to output test accuracy
+        """
+        raise NotImplementedError
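Taken together, the new interface lets one experiment class cover both one-shot and classic search. A hypothetical end-to-end sketch; Net, MyTrainer and TPEStrategy (class name assumed from strategies/tpe_strategy.py) are illustrative, not part of this diff:

from nni.retiarii.experiment import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.strategies.tpe_strategy import TPEStrategy  # assumed name

base_model = Net()                    # nn.Module using LayerChoice / InputChoice
trainer = MyTrainer(...)              # registered with @register_trainer
exp = RetiariiExperiment(base_model, trainer, strategy=TPEStrategy())

config = RetiariiExeConfig('local')   # training_service_platform
exp.run(config, port=8080)            # classic mode: config must not be None
# with a one-shot trainer, exp.run() alone fits the supernet and
# exp.export_top_models() returns trainer.export()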
nni/retiarii/graph.py

@@ -594,10 +594,10 @@ class Edge:

The only change is whitespace: the four-line example snippet in Edge's docstring is re-indented.

     Example forward code snippet:
     ```
     a, b, c = split(x)
     p = concat(a, c)
     q = sum(b, p)
     z = relu(q)
     ```

     Edges in above snippet:
nni/retiarii/integration.py
View file @
349ead41
import
logging
import
os
from
typing
import
Any
,
Callable
import
json_tricks
import
nni
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.protocol
import
CommandType
,
send
from
nni.utils
import
MetricType
from
.graph
import
MetricData
from
.execution.base
import
BaseExecutionEngine
from
.execution.cgo_engine
import
CGOExecutionEngine
from
.execution.api
import
set_execution_engine
from
.integration_api
import
register_advisor
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -55,6 +59,15 @@ class RetiariiAdvisor(MsgDispatcherBase):
        self.parameters_count = 0

+       engine = self._create_execution_engine()
+       set_execution_engine(engine)
+
+   def _create_execution_engine(self):
+       if os.environ.get('CGO') == 'true':
+           return CGOExecutionEngine()
+       else:
+           return BaseExecutionEngine()

    def handle_initialize(self, data):
        """callback for initializing the advisor
        Parameters
...
@@ -126,34 +139,3 @@ class RetiariiAdvisor(MsgDispatcherBase):
        else:
            return value

-_advisor: RetiariiAdvisor = None
-
-
-def get_advisor() -> RetiariiAdvisor:
-    global _advisor
-    assert _advisor is not None
-    return _advisor
-
-
-def register_advisor(advisor: RetiariiAdvisor):
-    global _advisor
-    assert _advisor is None
-    _advisor = advisor
-
-
-def send_trial(parameters: dict) -> int:
-    """
-    Send a new trial. Executed on tuner end.
-    Return a ID that is the unique identifier for this trial.
-    """
-    return get_advisor().send_trial(parameters)
-
-
-def receive_trial_parameters() -> dict:
-    """
-    Received a new trial. Executed on trial end.
-    """
-    params = nni.get_next_parameter()
-    return params
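A note on the new engine selection: `_create_execution_engine` reads the `CGO` environment variable once, when the advisor initializes, so the switch must be set before the experiment starts. The toggle itself is exactly the check in the code above:

```python
import os

os.environ['CGO'] = 'true'   # advisor will build CGOExecutionEngine()
os.environ.pop('CGO', None)  # unset (or any other value): BaseExecutionEngine()
```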
nni/retiarii/integration_api.py 0 → 100644 View file @ 349ead41
from typing import NewType, Any

import nni

# NOTE: this is only for passing flake8; we cannot import RetiariiAdvisor
# because it would induce a circular import
RetiariiAdvisor = NewType('RetiariiAdvisor', Any)

_advisor: 'RetiariiAdvisor' = None


def get_advisor() -> 'RetiariiAdvisor':
    global _advisor
    assert _advisor is not None
    return _advisor


def register_advisor(advisor: 'RetiariiAdvisor'):
    global _advisor
    assert _advisor is None
    _advisor = advisor


def send_trial(parameters: dict) -> int:
    """
    Send a new trial. Executed on tuner end.
    Return an ID that is the unique identifier for this trial.
    """
    return get_advisor().send_trial(parameters)


def receive_trial_parameters() -> dict:
    """
    Receive a new trial. Executed on trial end.
    """
    params = nni.get_next_parameter()
    return params
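The `NewType('RetiariiAdvisor', Any)` alias exists only to keep the annotations valid without importing `integration.py`, which itself imports this module; a real import would close the cycle. The pattern in isolation, with hypothetical module names `registry.py` and `advisor.py`:

```python
# registry.py -- low-level module; must not import advisor.py at runtime
from typing import Any, NewType

Advisor = NewType('Advisor', Any)  # typing stand-in, never instantiated

_advisor: 'Advisor' = None

def register_advisor(advisor: 'Advisor'):
    global _advisor
    assert _advisor is None
    _advisor = advisor

# advisor.py -- the higher-level module closes the loop safely:
#     from registry import register_advisor
#     class Advisor: ...
#     register_advisor(Advisor())
```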
nni/retiarii/mutator.py View file @ 349ead41
...
@@ -28,8 +28,10 @@ class Mutator:
    """
    Mutates graphs in model to generate new model.
    `Mutator` class will be used in two places:

    1. Inherit `Mutator` to implement graph mutation logic.
    2. Use `Mutator` subclass to implement NAS strategy.

    In scenario 1, the subclass should implement `Mutator.mutate()` interface with `Mutator.choice()`.
    In scenario 2, strategy should use constructor or `Mutator.bind_sampler()` to initialize subclass,
    and then use `Mutator.apply()` to mutate model.
...
@@ -104,6 +106,7 @@ class _RecorderSampler(Sampler):
        self.recorded_candidates.append(candidates)
        return candidates[0]


# the following is for inline mutation
...
@@ -122,14 +125,16 @@ class LayerChoiceMutator(Mutator):
class InputChoiceMutator(Mutator):
-   def __init__(self, node_name: str, n_chosen: int):
+   def __init__(self, node_name: str, n_candidates: int, n_chosen: int, reduction: str):
        super().__init__()
        self.node_name = node_name
+       self.n_candidates = n_candidates
        self.n_chosen = n_chosen
+       self.reduction = reduction

    def mutate(self, model):
        target = model.get_node_by_name(self.node_name)
-       candidates = [i for i in range(self.n_chosen)]
-       chosen = self.choice(candidates)
+       candidates = [i for i in range(self.n_candidates)]
+       chosen = [self.choice(candidates) for _ in range(self.n_chosen)]
        target.update_operation('__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs',
-                               {'chosen': chosen})
+                               {'chosen': chosen, 'reduction': self.reduction})
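After this change `InputChoiceMutator` draws `n_chosen` independent samples from `range(n_candidates)` instead of a single index, so repeated indices are possible. The sampling step in isolation, with `random.choice` standing in for the strategy-bound sampler:

```python
import random

n_candidates, n_chosen = 5, 2
candidates = list(range(n_candidates))

# Mirrors: chosen = [self.choice(candidates) for _ in range(self.n_chosen)]
chosen = [random.choice(candidates) for _ in range(n_chosen)]
print(chosen)  # e.g. [3, 1]; duplicates such as [2, 2] can occur
```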
nni/retiarii/nn/pytorch/nn.py View file @ 349ead41
import inspect
import logging
from typing import Any, List

import torch
import torch.nn as nn

-from ...utils import add_record
+from ...utils import add_record, blackbox_module, uid, version_larger_equal

_logger = logging.getLogger(__name__)

# NOTE: support pytorch version >= 1.5.0

__all__ = [
    'LayerChoice', 'InputChoice', 'Placeholder',
    'Module', 'Sequential', 'ModuleList',
    # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
...
@@ -29,18 +30,24 @@ __all__ = [
    'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
    'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
    'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
-   #'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
-   #'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
-   #'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
-   'Flatten', 'Hardsigmoid', 'Hardswish'
+   'Flatten', 'Hardsigmoid'
]

+if version_larger_equal(torch.__version__, '1.6.0'):
+   __all__.append('Hardswish')
+
+if version_larger_equal(torch.__version__, '1.7.0'):
+   __all__.extend(['Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss'])


class LayerChoice(nn.Module):
    def __init__(self, op_candidates, reduction=None, return_mask=False, key=None):
        super(LayerChoice, self).__init__()
-       self.candidate_ops = op_candidates
-       self.label = key
+       self.op_candidates = op_candidates
+       self.label = key if key is not None else f'layerchoice_{uid()}'
+       self.key = self.label  # deprecated, for backward compatibility
        for i, module in enumerate(op_candidates):
            self.add_module(str(i), module)
        if reduction or return_mask:
            _logger.warning('input arguments `reduction` and `return_mask` are deprecated!')
...
@@ -52,10 +59,12 @@ class InputChoice(nn.Module):
    def __init__(self, n_candidates=None, choose_from=None, n_chosen=1,
                 reduction="sum", return_mask=False, key=None):
        super(InputChoice, self).__init__()
        self.n_candidates = n_candidates
        self.n_chosen = n_chosen
        self.reduction = reduction
-       self.label = key
-       if n_candidates or choose_from or return_mask:
+       self.label = key if key is not None else f'inputchoice_{uid()}'
+       self.key = self.label  # deprecated, for backward compatibility
+       if choose_from or return_mask:
            _logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!')

    def forward(self, candidate_inputs: List[torch.Tensor]) -> torch.Tensor:
...
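For context, the mutable is consumed by passing the candidate tensors as a list in `forward`; a minimal sketch (the import path is assumed from this file's location in the tree):

```python
import torch.nn as nn
from nni.retiarii.nn.pytorch import InputChoice  # assumed import path

class SkipConnect(nn.Module):
    def __init__(self):
        super().__init__()
        # pick 1 of 2 inputs; default reduction="sum" is a no-op for one input
        self.input_switch = InputChoice(n_candidates=2, n_chosen=1)

    def forward(self, conv_out, skip_out):
        return self.input_switch([conv_out, skip_out])
```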
@@ -86,20 +95,37 @@ class Placeholder(nn.Module):
class ChosenInputs(nn.Module):
-   def __init__(self, chosen: int):
+   """
+   """
+   def __init__(self, chosen: List[int], reduction: str):
        super().__init__()
        self.chosen = chosen
+       self.reduction = reduction

    def forward(self, candidate_inputs):
-       # TODO: support multiple chosen inputs
-       return candidate_inputs[self.chosen]
+       return self._tensor_reduction(self.reduction, [candidate_inputs[i] for i in self.chosen])

+   def _tensor_reduction(self, reduction_type, tensor_list):
+       if reduction_type == "none":
+           return tensor_list
+       if not tensor_list:
+           return None  # empty. return None for now
+       if len(tensor_list) == 1:
+           return tensor_list[0]
+       if reduction_type == "sum":
+           return sum(tensor_list)
+       if reduction_type == "mean":
+           return sum(tensor_list) / len(tensor_list)
+       if reduction_type == "concat":
+           return torch.cat(tensor_list, dim=1)
+       raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
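The reduction semantics are easy to verify standalone; with two chosen tensors, for example:

```python
import torch

tensors = [torch.ones(2, 3), torch.full((2, 3), 3.0)]

assert torch.equal(sum(tensors), torch.full((2, 3), 4.0))                 # "sum"
assert torch.equal(sum(tensors) / len(tensors), torch.full((2, 3), 2.0))  # "mean"
assert torch.cat(tensors, dim=1).shape == (2, 6)                          # "concat" along dim 1
```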
# the following are pytorch modules

-class Module(nn.Module):
-   def __init__(self):
-       super(Module, self).__init__()
+Module = nn.Module


class Sequential(nn.Sequential):
...
@@ -114,139 +140,116 @@ class ModuleList(nn.ModuleList):
        super(ModuleList, self).__init__(*args)
-def wrap_module(original_class):
-    orig_init = original_class.__init__
-    argname_list = list(inspect.signature(original_class).parameters.keys())
-    # Make copy of original __init__, so we can call it without recursion
-
-    def __init__(self, *args, **kws):
-        full_args = {}
-        full_args.update(kws)
-        for i, arg in enumerate(args):
-            full_args[argname_list[i]] = arg
-        add_record(id(self), full_args)
-
-        orig_init(self, *args, **kws)  # Call the original __init__
-
-    original_class.__init__ = __init__  # Set the class' __init__ to the new one
-    return original_class
-
-# TODO: support different versions of pytorch
-Identity = wrap_module(nn.Identity)
-Linear = wrap_module(nn.Linear)
-Conv1d = wrap_module(nn.Conv1d)
-Conv2d = wrap_module(nn.Conv2d)
-Conv3d = wrap_module(nn.Conv3d)
-ConvTranspose1d = wrap_module(nn.ConvTranspose1d)
-ConvTranspose2d = wrap_module(nn.ConvTranspose2d)
-ConvTranspose3d = wrap_module(nn.ConvTranspose3d)
-Threshold = wrap_module(nn.Threshold)
-ReLU = wrap_module(nn.ReLU)
-Hardtanh = wrap_module(nn.Hardtanh)
-ReLU6 = wrap_module(nn.ReLU6)
-Sigmoid = wrap_module(nn.Sigmoid)
-Tanh = wrap_module(nn.Tanh)
-Softmax = wrap_module(nn.Softmax)
-Softmax2d = wrap_module(nn.Softmax2d)
-LogSoftmax = wrap_module(nn.LogSoftmax)
-ELU = wrap_module(nn.ELU)
-SELU = wrap_module(nn.SELU)
-CELU = wrap_module(nn.CELU)
-GLU = wrap_module(nn.GLU)
-GELU = wrap_module(nn.GELU)
-Hardshrink = wrap_module(nn.Hardshrink)
-LeakyReLU = wrap_module(nn.LeakyReLU)
-LogSigmoid = wrap_module(nn.LogSigmoid)
-Softplus = wrap_module(nn.Softplus)
-Softshrink = wrap_module(nn.Softshrink)
-MultiheadAttention = wrap_module(nn.MultiheadAttention)
-PReLU = wrap_module(nn.PReLU)
-Softsign = wrap_module(nn.Softsign)
-Softmin = wrap_module(nn.Softmin)
-Tanhshrink = wrap_module(nn.Tanhshrink)
-RReLU = wrap_module(nn.RReLU)
-AvgPool1d = wrap_module(nn.AvgPool1d)
-AvgPool2d = wrap_module(nn.AvgPool2d)
-AvgPool3d = wrap_module(nn.AvgPool3d)
-MaxPool1d = wrap_module(nn.MaxPool1d)
-MaxPool2d = wrap_module(nn.MaxPool2d)
-MaxPool3d = wrap_module(nn.MaxPool3d)
-MaxUnpool1d = wrap_module(nn.MaxUnpool1d)
-MaxUnpool2d = wrap_module(nn.MaxUnpool2d)
-MaxUnpool3d = wrap_module(nn.MaxUnpool3d)
-FractionalMaxPool2d = wrap_module(nn.FractionalMaxPool2d)
-FractionalMaxPool3d = wrap_module(nn.FractionalMaxPool3d)
-LPPool1d = wrap_module(nn.LPPool1d)
-LPPool2d = wrap_module(nn.LPPool2d)
-LocalResponseNorm = wrap_module(nn.LocalResponseNorm)
-BatchNorm1d = wrap_module(nn.BatchNorm1d)
-BatchNorm2d = wrap_module(nn.BatchNorm2d)
-BatchNorm3d = wrap_module(nn.BatchNorm3d)
-InstanceNorm1d = wrap_module(nn.InstanceNorm1d)
-InstanceNorm2d = wrap_module(nn.InstanceNorm2d)
-InstanceNorm3d = wrap_module(nn.InstanceNorm3d)
-LayerNorm = wrap_module(nn.LayerNorm)
-GroupNorm = wrap_module(nn.GroupNorm)
-SyncBatchNorm = wrap_module(nn.SyncBatchNorm)
-Dropout = wrap_module(nn.Dropout)
-Dropout2d = wrap_module(nn.Dropout2d)
-Dropout3d = wrap_module(nn.Dropout3d)
-AlphaDropout = wrap_module(nn.AlphaDropout)
-FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout)
-ReflectionPad1d = wrap_module(nn.ReflectionPad1d)
-ReflectionPad2d = wrap_module(nn.ReflectionPad2d)
-ReplicationPad2d = wrap_module(nn.ReplicationPad2d)
-ReplicationPad1d = wrap_module(nn.ReplicationPad1d)
-ReplicationPad3d = wrap_module(nn.ReplicationPad3d)
-CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d)
-Embedding = wrap_module(nn.Embedding)
-EmbeddingBag = wrap_module(nn.EmbeddingBag)
-RNNBase = wrap_module(nn.RNNBase)
-RNN = wrap_module(nn.RNN)
-LSTM = wrap_module(nn.LSTM)
-GRU = wrap_module(nn.GRU)
-RNNCellBase = wrap_module(nn.RNNCellBase)
-RNNCell = wrap_module(nn.RNNCell)
-LSTMCell = wrap_module(nn.LSTMCell)
-GRUCell = wrap_module(nn.GRUCell)
-PixelShuffle = wrap_module(nn.PixelShuffle)
-Upsample = wrap_module(nn.Upsample)
-UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d)
-UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d)
-PairwiseDistance = wrap_module(nn.PairwiseDistance)
-AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d)
-AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d)
-AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d)
-AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d)
-AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
-AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d)
-TripletMarginLoss = wrap_module(nn.TripletMarginLoss)
-ZeroPad2d = wrap_module(nn.ZeroPad2d)
-ConstantPad1d = wrap_module(nn.ConstantPad1d)
-ConstantPad2d = wrap_module(nn.ConstantPad2d)
-ConstantPad3d = wrap_module(nn.ConstantPad3d)
-Bilinear = wrap_module(nn.Bilinear)
-CosineSimilarity = wrap_module(nn.CosineSimilarity)
-Unfold = wrap_module(nn.Unfold)
-Fold = wrap_module(nn.Fold)
-AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
-TransformerEncoder = wrap_module(nn.TransformerEncoder)
-TransformerDecoder = wrap_module(nn.TransformerDecoder)
-TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
-TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
-Transformer = wrap_module(nn.Transformer)
-#LazyLinear = wrap_module(nn.LazyLinear)
-#LazyConv1d = wrap_module(nn.LazyConv1d)
-#LazyConv2d = wrap_module(nn.LazyConv2d)
-#LazyConv3d = wrap_module(nn.LazyConv3d)
-#LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d)
-#LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d)
-#LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d)
-Flatten = wrap_module(nn.Flatten)
-#Unflatten = wrap_module(nn.Unflatten)
-Hardsigmoid = wrap_module(nn.Hardsigmoid)
-Hardswish = wrap_module(nn.Hardswish)
-#SiLU = wrap_module(nn.SiLU)
-#TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
-#ChannelShuffle = wrap_module(nn.ChannelShuffle)

+Identity = blackbox_module(nn.Identity)
+Linear = blackbox_module(nn.Linear)
+Conv1d = blackbox_module(nn.Conv1d)
+Conv2d = blackbox_module(nn.Conv2d)
+Conv3d = blackbox_module(nn.Conv3d)
+ConvTranspose1d = blackbox_module(nn.ConvTranspose1d)
+ConvTranspose2d = blackbox_module(nn.ConvTranspose2d)
+ConvTranspose3d = blackbox_module(nn.ConvTranspose3d)
+Threshold = blackbox_module(nn.Threshold)
+ReLU = blackbox_module(nn.ReLU)
+Hardtanh = blackbox_module(nn.Hardtanh)
+ReLU6 = blackbox_module(nn.ReLU6)
+Sigmoid = blackbox_module(nn.Sigmoid)
+Tanh = blackbox_module(nn.Tanh)
+Softmax = blackbox_module(nn.Softmax)
+Softmax2d = blackbox_module(nn.Softmax2d)
+LogSoftmax = blackbox_module(nn.LogSoftmax)
+ELU = blackbox_module(nn.ELU)
+SELU = blackbox_module(nn.SELU)
+CELU = blackbox_module(nn.CELU)
+GLU = blackbox_module(nn.GLU)
+GELU = blackbox_module(nn.GELU)
+Hardshrink = blackbox_module(nn.Hardshrink)
+LeakyReLU = blackbox_module(nn.LeakyReLU)
+LogSigmoid = blackbox_module(nn.LogSigmoid)
+Softplus = blackbox_module(nn.Softplus)
+Softshrink = blackbox_module(nn.Softshrink)
+MultiheadAttention = blackbox_module(nn.MultiheadAttention)
+PReLU = blackbox_module(nn.PReLU)
+Softsign = blackbox_module(nn.Softsign)
+Softmin = blackbox_module(nn.Softmin)
+Tanhshrink = blackbox_module(nn.Tanhshrink)
+RReLU = blackbox_module(nn.RReLU)
+AvgPool1d = blackbox_module(nn.AvgPool1d)
+AvgPool2d = blackbox_module(nn.AvgPool2d)
+AvgPool3d = blackbox_module(nn.AvgPool3d)
+MaxPool1d = blackbox_module(nn.MaxPool1d)
+MaxPool2d = blackbox_module(nn.MaxPool2d)
+MaxPool3d = blackbox_module(nn.MaxPool3d)
+MaxUnpool1d = blackbox_module(nn.MaxUnpool1d)
+MaxUnpool2d = blackbox_module(nn.MaxUnpool2d)
+MaxUnpool3d = blackbox_module(nn.MaxUnpool3d)
+FractionalMaxPool2d = blackbox_module(nn.FractionalMaxPool2d)
+FractionalMaxPool3d = blackbox_module(nn.FractionalMaxPool3d)
+LPPool1d = blackbox_module(nn.LPPool1d)
+LPPool2d = blackbox_module(nn.LPPool2d)
+LocalResponseNorm = blackbox_module(nn.LocalResponseNorm)
+BatchNorm1d = blackbox_module(nn.BatchNorm1d)
+BatchNorm2d = blackbox_module(nn.BatchNorm2d)
+BatchNorm3d = blackbox_module(nn.BatchNorm3d)
+InstanceNorm1d = blackbox_module(nn.InstanceNorm1d)
+InstanceNorm2d = blackbox_module(nn.InstanceNorm2d)
+InstanceNorm3d = blackbox_module(nn.InstanceNorm3d)
+LayerNorm = blackbox_module(nn.LayerNorm)
+GroupNorm = blackbox_module(nn.GroupNorm)
+SyncBatchNorm = blackbox_module(nn.SyncBatchNorm)
+Dropout = blackbox_module(nn.Dropout)
+Dropout2d = blackbox_module(nn.Dropout2d)
+Dropout3d = blackbox_module(nn.Dropout3d)
+AlphaDropout = blackbox_module(nn.AlphaDropout)
+FeatureAlphaDropout = blackbox_module(nn.FeatureAlphaDropout)
+ReflectionPad1d = blackbox_module(nn.ReflectionPad1d)
+ReflectionPad2d = blackbox_module(nn.ReflectionPad2d)
+ReplicationPad2d = blackbox_module(nn.ReplicationPad2d)
+ReplicationPad1d = blackbox_module(nn.ReplicationPad1d)
+ReplicationPad3d = blackbox_module(nn.ReplicationPad3d)
+CrossMapLRN2d = blackbox_module(nn.CrossMapLRN2d)
+Embedding = blackbox_module(nn.Embedding)
+EmbeddingBag = blackbox_module(nn.EmbeddingBag)
+RNNBase = blackbox_module(nn.RNNBase)
+RNN = blackbox_module(nn.RNN)
+LSTM = blackbox_module(nn.LSTM)
+GRU = blackbox_module(nn.GRU)
+RNNCellBase = blackbox_module(nn.RNNCellBase)
+RNNCell = blackbox_module(nn.RNNCell)
+LSTMCell = blackbox_module(nn.LSTMCell)
+GRUCell = blackbox_module(nn.GRUCell)
+PixelShuffle = blackbox_module(nn.PixelShuffle)
+Upsample = blackbox_module(nn.Upsample)
+UpsamplingNearest2d = blackbox_module(nn.UpsamplingNearest2d)
+UpsamplingBilinear2d = blackbox_module(nn.UpsamplingBilinear2d)
+PairwiseDistance = blackbox_module(nn.PairwiseDistance)
+AdaptiveMaxPool1d = blackbox_module(nn.AdaptiveMaxPool1d)
+AdaptiveMaxPool2d = blackbox_module(nn.AdaptiveMaxPool2d)
+AdaptiveMaxPool3d = blackbox_module(nn.AdaptiveMaxPool3d)
+AdaptiveAvgPool1d = blackbox_module(nn.AdaptiveAvgPool1d)
+AdaptiveAvgPool2d = blackbox_module(nn.AdaptiveAvgPool2d)
+AdaptiveAvgPool3d = blackbox_module(nn.AdaptiveAvgPool3d)
+TripletMarginLoss = blackbox_module(nn.TripletMarginLoss)
+ZeroPad2d = blackbox_module(nn.ZeroPad2d)
+ConstantPad1d = blackbox_module(nn.ConstantPad1d)
+ConstantPad2d = blackbox_module(nn.ConstantPad2d)
+ConstantPad3d = blackbox_module(nn.ConstantPad3d)
+Bilinear = blackbox_module(nn.Bilinear)
+CosineSimilarity = blackbox_module(nn.CosineSimilarity)
+Unfold = blackbox_module(nn.Unfold)
+Fold = blackbox_module(nn.Fold)
+AdaptiveLogSoftmaxWithLoss = blackbox_module(nn.AdaptiveLogSoftmaxWithLoss)
+TransformerEncoder = blackbox_module(nn.TransformerEncoder)
+TransformerDecoder = blackbox_module(nn.TransformerDecoder)
+TransformerEncoderLayer = blackbox_module(nn.TransformerEncoderLayer)
+TransformerDecoderLayer = blackbox_module(nn.TransformerDecoderLayer)
+Transformer = blackbox_module(nn.Transformer)
+Flatten = blackbox_module(nn.Flatten)
+Hardsigmoid = blackbox_module(nn.Hardsigmoid)
+
+if version_larger_equal(torch.__version__, '1.6.0'):
+    Hardswish = blackbox_module(nn.Hardswish)
+
+if version_larger_equal(torch.__version__, '1.7.0'):
+    SiLU = blackbox_module(nn.SiLU)
+    Unflatten = blackbox_module(nn.Unflatten)
+    TripletMarginWithDistanceLoss = blackbox_module(nn.TripletMarginWithDistanceLoss)
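The removed `wrap_module` patched each class's `__init__` in place so that constructor arguments were recorded (via `add_record`) for the graph converter to re-instantiate modules later; `blackbox_module` plays the same recording role, but its implementation lives in `...utils` and is not part of this diff. A self-contained sketch of the recording idea, with a hypothetical `RECORDS` dict standing in for the real registry:

```python
import inspect
import torch.nn as nn

RECORDS = {}  # hypothetical stand-in for the add_record(...) registry

def record_init_args(cls):
    """Patch cls.__init__ (in place) to log its arguments keyed by instance id."""
    orig_init = cls.__init__
    argnames = list(inspect.signature(cls).parameters.keys())

    def __init__(self, *args, **kwargs):
        full_args = dict(kwargs)
        for i, arg in enumerate(args):
            full_args[argnames[i]] = arg
        RECORDS[id(self)] = full_args  # what add_record(id(self), full_args) did
        orig_init(self, *args, **kwargs)

    cls.__init__ = __init__
    return cls

Linear = record_init_args(nn.Linear)
layer = Linear(4, 2)
print(RECORDS[id(layer)])  # {'in_features': 4, 'out_features': 2}
```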
nni/retiarii/operation.py View file @ 349ead41
...
@@ -121,6 +121,8 @@ class PyTorchOperation(Operation):
            return f'{output} = {value}'
        elif self.type == 'prim::ListConstruct':
            return f'{output} = [{", ".join(inputs)}]'
+       elif self.type == 'prim::GetAttr':
+           return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
        elif self.type == 'aten::mean':
            return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})'
        elif self.type == 'aten::__getitem__':
...
@@ -133,8 +135,7 @@ class PyTorchOperation(Operation):
            assert len(inputs) == 2
            return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
        elif self.type == 'aten::add':
-           assert len(inputs) == 2
-           return f'{output} = {inputs[0]} + {inputs[1]}'
+           return f'{output} = ' + ' + '.join(inputs)
        elif self.type == OpTypeName.MergedSlice:
            assert (len(inputs) - 1) % 4 == 0
            slices = []
...
@@ -151,6 +152,8 @@ class PyTorchOperation(Operation):
            return f'{output} = {inputs[0]}.view({inputs[1]})'
        elif self.type == 'aten::slice':
            raise RuntimeError('not supposed to have aten::slice operation')
+       elif self.type == 'aten::Bool':
+           return f'{output} = bool({inputs[0]})'
        else:
            raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
...
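Each branch emits one line of generated forward code per graph node. For example, the new n-ary `aten::add` rendering, with `output = 'out'` and three inputs:

```python
output, inputs = 'out', ['x', 'y', 'z']
print(f'{output} = ' + ' + '.join(inputs))  # out = x + y + z
```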
nni/retiarii/strategies/__init__.py View file @ 349ead41
from .tpe_strategy import TPEStrategy
+from .random_strategy import RandomStrategy
nni/retiarii/strategies/random_strategy.py 0 → 100644 View file @ 349ead41
import logging
import random
import time

from .. import Sampler, submit_models, query_available_resources
from .strategy import BaseStrategy

_logger = logging.getLogger(__name__)


class RandomSampler(Sampler):
    def choice(self, candidates, mutator, model, index):
        return random.choice(candidates)


class RandomStrategy(BaseStrategy):
    def __init__(self):
        self.random_sampler = RandomSampler()

    def run(self, base_model, applied_mutators):
        _logger.info('strategy start...')
        while True:
            avail_resource = query_available_resources()
            if avail_resource > 0:
                model = base_model
                _logger.info('apply mutators...')
                _logger.info('mutators: %s', str(applied_mutators))
                for mutator in applied_mutators:
                    mutator.bind_sampler(self.random_sampler)
                    model = mutator.apply(model)
                # run models
                submit_models(model)
            else:
                time.sleep(2)
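`RandomStrategy` polls `query_available_resources()` and submits one freshly mutated model whenever a slot is free. Wiring it up might look like this (the experiment constructor is an assumption, as noted earlier):

```python
from nni.retiarii.strategies import RandomStrategy

strategy = RandomStrategy()
# exp = RetiariiExperiment(base_model, trainer, applied_mutators, strategy)
# exp.run(config, port=8080)
```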
nni/retiarii/strategies/tpe_strategy.py View file @ 349ead41
import logging
import time

from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner

-from .. import Sampler, submit_models, wait_models
+from .. import Sampler, submit_models, query_available_resources, is_stopped_exec
from .strategy import BaseStrategy

_logger = logging.getLogger(__name__)
...
@@ -39,6 +40,7 @@ class TPEStrategy(BaseStrategy):
    def __init__(self):
        self.tpe_sampler = TPESampler()
        self.model_id = 0
+       self.running_models = {}

    def run(self, base_model, applied_mutators):
        sample_space = []
...
@@ -48,9 +50,10 @@ class TPEStrategy(BaseStrategy):
            sample_space.extend(recorded_candidates)
        self.tpe_sampler.update_sample_space(sample_space)

-       try:
-           _logger.info('strategy start...')
-           while True:
+       _logger.info('strategy start...')
+       while True:
+           avail_resource = query_available_resources()
+           if avail_resource > 0:
                model = base_model
                _logger.info('apply mutators...')
                _logger.info('mutators: %s', str(applied_mutators))
...
@@ -61,9 +64,18 @@ class TPEStrategy(BaseStrategy):
                    model = mutator.apply(model)
                # run models
                submit_models(model)
-               wait_models(model)
-               self.tpe_sampler.receive_result(self.model_id, model.metric)
+               self.running_models[self.model_id] = model
                self.model_id += 1
-               _logger.info('Strategy says: %s', model.metric)
-       except Exception:
-           _logger.error(logging.exception('message'))
+           else:
+               time.sleep(2)
+           _logger.warning('num of running models: %d', len(self.running_models))
+           to_be_deleted = []
+           for _id, _model in self.running_models.items():
+               if is_stopped_exec(_model):
+                   if _model.metric is not None:
+                       self.tpe_sampler.receive_result(_id, _model.metric)
+                       _logger.warning('tpe receive results: %d, %s', _id, _model.metric)
+                   to_be_deleted.append(_id)
+           for _id in to_be_deleted:
+               del self.running_models[_id]
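The rewritten loop decouples submission from result collection: rather than `wait_models(model)` blocking after every submission, finished executions are harvested once per iteration via `is_stopped_exec`. The bookkeeping pattern in isolation, with stand-ins for the model objects and the sampler:

```python
from types import SimpleNamespace

def collect_finished(running, sampler, is_finished):
    """Report metrics of finished models to the sampler, then drop them --
    the same per-iteration bookkeeping TPEStrategy.run now performs."""
    done = [mid for mid, model in running.items() if is_finished(model)]
    for mid in done:
        if running[mid].metric is not None:
            sampler.receive_result(mid, running[mid].metric)
        del running[mid]

# toy demonstration
running = {0: SimpleNamespace(metric=0.93), 1: SimpleNamespace(metric=None)}
sampler = SimpleNamespace(receive_result=lambda mid, metric: print(mid, metric))
collect_finished(running, sampler, is_finished=lambda m: True)
print(running)  # {} -- both harvested; only id 0 had a metric to report
```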