Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
4784cc6c
Unverified
Commit
4784cc6c
authored
Jan 14, 2021
by
liuzhe-lz
Committed by
GitHub
Jan 14, 2021
Browse files
Merge pull request #3302 from microsoft/v2.0-merge
Merge branch v2.0 into master (no squash)
parents
25db55ca
349ead41
Changes
291
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
836 additions
and
794 deletions
+836
-794
nni/nas/pytorch/mutator.py
nni/nas/pytorch/mutator.py
+2
-2
nni/retiarii/__init__.py
nni/retiarii/__init__.py
+1
-1
nni/retiarii/codegen/pytorch.py
nni/retiarii/codegen/pytorch.py
+2
-2
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+456
-475
nni/retiarii/converter/op_types.py
nni/retiarii/converter/op_types.py
+4
-0
nni/retiarii/execution/api.py
nni/retiarii/execution/api.py
+18
-15
nni/retiarii/execution/base.py
nni/retiarii/execution/base.py
+21
-16
nni/retiarii/execution/cgo_engine.py
nni/retiarii/execution/cgo_engine.py
+1
-1
nni/retiarii/execution/interface.py
nni/retiarii/execution/interface.py
+2
-9
nni/retiarii/execution/listener.py
nni/retiarii/execution/listener.py
+0
-11
nni/retiarii/experiment.py
nni/retiarii/experiment.py
+50
-57
nni/retiarii/graph.py
nni/retiarii/graph.py
+4
-4
nni/retiarii/integration.py
nni/retiarii/integration.py
+14
-32
nni/retiarii/integration_api.py
nni/retiarii/integration_api.py
+36
-0
nni/retiarii/mutator.py
nni/retiarii/mutator.py
+11
-6
nni/retiarii/nn/pytorch/nn.py
nni/retiarii/nn/pytorch/nn.py
+155
-152
nni/retiarii/operation.py
nni/retiarii/operation.py
+5
-2
nni/retiarii/strategies/__init__.py
nni/retiarii/strategies/__init__.py
+1
-0
nni/retiarii/strategies/random_strategy.py
nni/retiarii/strategies/random_strategy.py
+32
-0
nni/retiarii/strategies/tpe_strategy.py
nni/retiarii/strategies/tpe_strategy.py
+21
-9
No files found.
nni/nas/pytorch/mutator.py
View file @
4784cc6c
...
...
@@ -147,7 +147,7 @@ class Mutator(BaseMutator):
Parameters
----------
mutable : LayerChoice
mutable :
nni.nas.pytorch.mutables.
LayerChoice
Layer choice module.
args : list of torch.Tensor
Inputs
...
...
@@ -180,7 +180,7 @@ class Mutator(BaseMutator):
Parameters
----------
mutable : InputChoice
mutable :
nni.nas.pytorch.mutables.
InputChoice
Input choice module.
tensor_list : list of torch.Tensor
Tensor list to apply the decision on.
...
...
nni/retiarii/__init__.py
View file @
4784cc6c
...
...
@@ -2,4 +2,4 @@ from .operation import Operation
from
.graph
import
*
from
.execution
import
*
from
.mutator
import
*
from
.utils
import
register_module
\ No newline at end of file
from
.utils
import
blackbox
,
blackbox_module
,
register_trainer
nni/retiarii/codegen/pytorch.py
View file @
4784cc6c
...
...
@@ -19,10 +19,10 @@ def model_to_pytorch_script(model: Model, placement=None) -> str:
def
_sorted_incoming_edges
(
node
:
Node
)
->
List
[
Edge
]:
edges
=
[
edge
for
edge
in
node
.
graph
.
edges
if
edge
.
tail
is
node
]
_logger
.
info
(
'sorted_incoming_edges: %s'
,
str
(
edges
))
_logger
.
debug
(
'sorted_incoming_edges: %s'
,
str
(
edges
))
if
not
edges
:
return
[]
_logger
.
info
(
'all tail_slots are None: %s'
,
str
([
edge
.
tail_slot
for
edge
in
edges
]))
_logger
.
debug
(
'all tail_slots are None: %s'
,
str
([
edge
.
tail_slot
for
edge
in
edges
]))
if
all
(
edge
.
tail_slot
is
None
for
edge
in
edges
):
return
edges
if
all
(
isinstance
(
edge
.
tail_slot
,
int
)
for
edge
in
edges
):
...
...
nni/retiarii/converter/graph_gen.py
View file @
4784cc6c
...
...
@@ -6,517 +6,501 @@ import torch
from
..graph
import
Graph
,
Model
,
Node
from
..nn.pytorch
import
InputChoice
,
LayerChoice
,
Placeholder
from
..operation
import
Cell
from
..utils
import
get_records
from
.op_types
import
MODULE_EXCEPT_LIST
,
BasicOpsPT
,
OpTypeName
from
.utils
import
_convert_name
,
build_full_name
_logger
=
logging
.
getLogger
(
__name__
)
global_seq
=
0
global_graph_id
=
0
modules_arg
=
None
class
GraphConverter
:
def
__init__
(
self
):
self
.
global_seq
=
0
self
.
global_graph_id
=
0
self
.
modules_arg
=
get_records
()
def
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
new_node
,
output_remap
,
ignore_first
=
False
):
"""
Parameters
----------
ir_graph : Graph
node : torch._C.Node
graph_inputs : List[torch._C.Value]
a list of a script graph's inputs
node_index : Dict
new_node : Node
newly created ir node corresponding to `node`
output_remap : Dict
ignore_first : bool
if it is true, skip the first input
"""
is_single_input
=
(
len
([
_input
for
_input
in
node
.
inputs
()])
-
(
1
if
ignore_first
else
0
))
==
1
new_node_input_idx
=
0
for
_input
in
node
.
inputs
():
if
ignore_first
:
ignore_first
=
False
continue
# handle source node
if
_input
in
graph_inputs
:
idx
=
graph_inputs
.
index
(
_input
)
src_node
=
ir_graph
.
input_node
src_node_idx
=
idx
elif
_input
in
output_remap
:
assert
output_remap
[
_input
].
kind
()
==
'aten::append'
predecessor_node
=
output_remap
[
_input
]
assert
predecessor_node
in
node_index
,
'predecessor node: {}'
.
format
(
predecessor_node
)
src_node_idx
=
None
src_node
=
node_index
[
predecessor_node
]
assert
isinstance
(
src_node
,
Node
)
else
:
predecessor_node
=
_input
.
node
()
assert
predecessor_node
in
node_index
,
'predecessor node: {}'
.
format
(
predecessor_node
)
# find out the index of _input in the outputs of predecessor_node
predecessor_outputs
=
[
_output
for
_output
in
predecessor_node
.
outputs
()]
if
len
(
predecessor_outputs
)
==
1
:
idx
=
None
def
_add_edge
(
self
,
ir_graph
,
node
,
graph_inputs
,
node_index
,
new_node
,
output_remap
,
ignore_first
=
False
):
"""
Parameters
----------
ir_graph : Graph
node : torch._C.Node
graph_inputs : List[torch._C.Value]
a list of a script graph's inputs
node_index : Dict
new_node : Node
newly created ir node corresponding to `node`
output_remap : Dict
ignore_first : bool
if it is true, skip the first input
"""
is_single_input
=
(
len
([
_input
for
_input
in
node
.
inputs
()])
-
(
1
if
ignore_first
else
0
))
==
1
new_node_input_idx
=
0
for
_input
in
node
.
inputs
():
if
ignore_first
:
ignore_first
=
False
continue
# handle source node
if
_input
in
graph_inputs
:
idx
=
graph_inputs
.
index
(
_input
)
src_node
=
ir_graph
.
input_node
src_node_idx
=
idx
elif
_input
in
output_remap
:
assert
output_remap
[
_input
].
kind
()
==
'aten::append'
predecessor_node
=
output_remap
[
_input
]
assert
predecessor_node
in
node_index
,
'predecessor node: {}'
.
format
(
predecessor_node
)
src_node_idx
=
None
src_node
=
node_index
[
predecessor_node
]
assert
isinstance
(
src_node
,
Node
)
else
:
idx
=
predecessor_outputs
.
index
(
_input
)
ir_predecessor_node
=
node_index
[
predecessor_node
]
src_node_idx
=
idx
assert
isinstance
(
ir_predecessor_node
,
Node
)
src_node
=
ir_predecessor_node
# handle destination node
dst_node
=
new_node
if
is_single_input
:
dst_node_idx
=
None
else
:
dst_node_idx
=
new_node_input_idx
# create edge
ir_graph
.
add_edge
(
head
=
(
src_node
,
src_node_idx
),
tail
=
(
dst_node
,
dst_node_idx
))
new_node_input_idx
+=
1
def
create_prim_constant_node
(
ir_graph
,
node
,
module_name
):
global
global_seq
attrs
=
{}
if
node
.
outputsAt
(
0
).
toIValue
()
is
not
None
:
attrs
=
{
'value'
:
node
.
outputsAt
(
0
).
toIValue
()}
global_seq
+=
1
new_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
OpTypeName
.
Constant
,
global_seq
),
node
.
kind
(),
attrs
)
return
new_node
def
handle_prim_attr_node
(
node
):
assert
node
.
hasAttribute
(
'name'
)
attrs
=
{
'name'
:
node
.
s
(
'name'
),
'input'
:
node
.
inputsAt
(
0
).
debugName
()}
return
node
.
kind
(),
attrs
def
_remove_mangle
(
module_type_str
):
return
re
.
sub
(
'
\\
.___torch_mangle_
\\
d+'
,
''
,
module_type_str
)
predecessor_node
=
_input
.
node
()
assert
predecessor_node
in
node_index
,
'predecessor node: {}'
.
format
(
predecessor_node
)
# find out the index of _input in the outputs of predecessor_node
predecessor_outputs
=
[
_output
for
_output
in
predecessor_node
.
outputs
()]
if
len
(
predecessor_outputs
)
==
1
:
idx
=
None
else
:
idx
=
predecessor_outputs
.
index
(
_input
)
ir_predecessor_node
=
node_index
[
predecessor_node
]
src_node_idx
=
idx
assert
isinstance
(
ir_predecessor_node
,
Node
)
src_node
=
ir_predecessor_node
# handle destination node
dst_node
=
new_node
if
is_single_input
:
dst_node_idx
=
None
else
:
dst_node_idx
=
new_node_input_idx
def
remove_unconnected_nodes
(
ir_graph
,
targeted_type
=
None
):
"""
Parameters
----------
ir_graph : Graph
our ir graph representation
targeted_type : str
nodes with ```targeted_type``` will be removed from graph if their fanout is 0.
```None``` means removing all the nodes whose fanout is 0.
"""
# build index of outputs of Node(s)
node_fanout
=
set
()
for
edge
in
ir_graph
.
edges
:
if
edge
.
head
.
id
not
in
node_fanout
:
node_fanout
.
add
(
edge
.
head
.
id
)
to_removes
=
[]
for
hidden_node
in
ir_graph
.
hidden_nodes
:
if
hidden_node
.
id
not
in
node_fanout
:
assert
isinstance
(
hidden_node
,
Node
)
if
targeted_type
is
None
:
to_removes
.
append
(
hidden_node
)
elif
hidden_node
.
operation
.
type
==
targeted_type
:
to_removes
.
append
(
hidden_node
)
for
hidden_node
in
to_removes
:
hidden_node
.
remove
()
def
handle_graph_nodes
(
script_module
,
sm_graph
,
module
,
module_name
,
ir_model
,
ir_graph
):
"""
Convert torch script node to our node ir, and build our graph ir
# create edge
ir_graph
.
add_edge
(
head
=
(
src_node
,
src_node_idx
),
tail
=
(
dst_node
,
dst_node_idx
))
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the torch script of ```module```
sm_graph : torch._C.Graph
the graph in torch script
module : nn.Module
the targeted pytorch module
module_name : str
```module```'s name
ir_model : Model
the whole graph ir
ir_graph : Graph
the graph ir of ```module```
new_node_input_idx
+=
1
Returns
-------
dict
the mapping from graph node to our graph ir node
"""
# handle inputs
graph_inputs
=
[]
for
_input
in
sm_graph
.
inputs
():
if
_input
.
debugName
()
==
'self'
:
assert
_input
.
unique
()
==
0
continue
graph_inputs
.
append
(
_input
)
# TODO: add scope name
ir_graph
.
_add_input
(
_convert_name
(
_input
.
debugName
()))
node_index
=
{}
# graph node to graph ir node
# some node does not have output but it modifies a variable, for example aten::append
# %17 : Tensor[] = aten::append(%out.1, %16)
# %out.1 is updated, and %17 is None
# we add output to this type of node and connect it to the following node which uses %out.1
# key: tensor (%out.1), value: node (this node)
output_remap
=
{}
def
handle_if_condition
(
cond_tensor
):
"""
to calculate the condition, we only deal with the following op types by tracing back
`prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`
def
create_prim_constant_node
(
self
,
ir_graph
,
node
,
module_name
):
attrs
=
{}
if
node
.
outputsAt
(
0
).
toIValue
()
is
not
None
:
attrs
=
{
'value'
:
node
.
outputsAt
(
0
).
toIValue
()}
self
.
global_seq
+=
1
new_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
OpTypeName
.
Constant
,
self
.
global_seq
),
node
.
kind
(),
attrs
)
return
new_node
generate the expression using recursive calls
def
handle_prim_attr_node
(
self
,
node
):
assert
node
.
hasAttribute
(
'name'
)
attrs
=
{
'name'
:
node
.
s
(
'name'
),
'input'
:
node
.
inputsAt
(
0
).
debugName
()}
return
node
.
kind
(),
attrs
NOTE: do not support dynamic graph
"""
def
_generate_expr
(
tensor
):
if
tensor
.
node
().
kind
()
==
'prim::GetAttr'
:
return
f
'(
{
getattr
(
module
,
tensor
.
node
().
s
(
"name"
))
}
)'
elif
tensor
.
node
().
kind
()
==
'aten::__getitem__'
:
t
=
_generate_expr
(
tensor
.
node
().
inputsAt
(
0
))
idx
=
_generate_expr
(
tensor
.
node
().
inputsAt
(
1
))
return
f
'(
{
t
}
[
{
idx
}
])'
elif
tensor
.
node
().
kind
()
==
'prim::Constant'
:
return
f
'
{
tensor
.
toIValue
()
}
'
elif
tensor
.
node
().
kind
()
==
'aten::eq'
:
left
=
_generate_expr
(
tensor
.
node
().
inputsAt
(
0
))
right
=
_generate_expr
(
tensor
.
node
().
inputsAt
(
1
))
return
f
'(
{
left
}
==
{
right
}
)'
else
:
raise
RuntimeError
(
f
'Unsupported op type
{
tensor
.
node
().
kind
()
}
in if condition'
)
expr
=
_generate_expr
(
cond_tensor
)
return
eval
(
expr
)
def
_remove_mangle
(
self
,
module_type_str
):
return
re
.
sub
(
'
\\
.___torch_mangle_
\\
d+'
,
''
,
module_type_str
)
def
handle_if_node
(
nod
e
):
def
remove_unconnected_nodes
(
self
,
ir_graph
,
targeted_type
=
Non
e
):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
ir_graph : Graph
our ir graph representation
targeted_type : str
nodes with ```targeted_type``` will be removed from graph if their fanout is 0.
```None``` means removing all the nodes whose fanout is 0.
"""
# only deal with input of prim::If is constant or attribute for now
# will support constant expression in future
inputs
=
[
i
for
i
in
node
.
inputs
()]
assert
len
(
inputs
)
==
1
cond
=
handle_if_condition
(
inputs
[
0
])
chosen_block
=
0
if
cond
else
1
blocks
=
[
block
for
block
in
node
.
blocks
()]
assert
len
(
blocks
)
==
2
last_block_node
=
None
for
node
in
blocks
[
chosen_block
].
nodes
():
last_block_node
=
handle_single_node
(
node
)
return
last_block_node
def
handle_single_node
(
node
):
# build index of outputs of Node(s)
node_fanout
=
set
()
for
edge
in
ir_graph
.
edges
:
if
edge
.
head
.
id
not
in
node_fanout
:
node_fanout
.
add
(
edge
.
head
.
id
)
to_removes
=
[]
for
hidden_node
in
ir_graph
.
hidden_nodes
:
if
hidden_node
.
id
not
in
node_fanout
:
assert
isinstance
(
hidden_node
,
Node
)
if
targeted_type
is
None
:
to_removes
.
append
(
hidden_node
)
elif
hidden_node
.
operation
.
type
==
targeted_type
:
to_removes
.
append
(
hidden_node
)
for
hidden_node
in
to_removes
:
hidden_node
.
remove
()
def
handle_graph_nodes
(
self
,
script_module
,
sm_graph
,
module
,
module_name
,
ir_model
,
ir_graph
):
"""
Convert torch script node to our node ir, and build our graph ir
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
script_module : torch.jit.RecursiveScriptModule
the torch script of ```module```
sm_graph : torch._C.Graph
the graph in torch script
module : nn.Module
the targeted pytorch module
module_name : str
```module```'s name
ir_model : Model
the whole graph ir
ir_graph : Graph
the graph ir of ```module```
Returns
-------
Node
the
created node ir
dict
the
mapping from graph node to our graph ir node
"""
global
global_seq
if
node
.
kind
()
==
'prim::CallMethod'
:
# get and handle the first input, which should be an nn.Module
assert
node
.
hasAttribute
(
'name'
)
if
node
.
s
(
'name'
)
==
'forward'
:
# node.inputsAt(0).type() is <class 'torch._C.ClassType'>
submodule_type_str
=
_remove_mangle
(
node
.
inputsAt
(
0
).
type
().
str
())
submodule
=
node
.
inputsAt
(
0
).
node
()
assert
submodule
.
kind
()
==
'prim::GetAttr'
assert
submodule
.
hasAttribute
(
'name'
)
submodule_name
=
submodule
.
s
(
'name'
)
if
submodule
.
inputsAt
(
0
).
debugName
()
==
'self'
:
# module is usually instantiated in __init__.
# when calling a module in forward,
# prim::GetAttr is used to obtain the module in torch script.
# therefore, we do this check for a module. example below:
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
assert
submodule_name
in
script_module
.
_modules
,
"submodule_name: {} not in script_module {}"
.
format
(
submodule_name
,
script_module
.
_modules
.
keys
())
submodule_full_name
=
build_full_name
(
module_name
,
submodule_name
)
submodule_obj
=
getattr
(
module
,
submodule_name
)
subgraph
,
sub_m_attrs
=
convert_module
(
script_module
.
_modules
[
submodule_name
],
submodule_obj
,
submodule_full_name
,
ir_model
)
# handle inputs
graph_inputs
=
[]
for
_input
in
sm_graph
.
inputs
():
if
_input
.
debugName
()
==
'self'
:
assert
_input
.
unique
()
==
0
continue
graph_inputs
.
append
(
_input
)
# TODO: add scope name
ir_graph
.
_add_input
(
_convert_name
(
_input
.
debugName
()))
node_index
=
{}
# graph node to graph ir node
# some node does not have output but it modifies a variable, for example aten::append
# %17 : Tensor[] = aten::append(%out.1, %16)
# %out.1 is updated, and %17 is None
# we add output to this type of node and connect it to the following node which uses %out.1
# key: tensor (%out.1), value: node (this node)
output_remap
=
{}
def
handle_if_condition
(
cond_tensor
):
"""
to calculate the condition, we only deal with the following op types by tracing back
`prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`
generate the expression using recursive calls
NOTE: do not support dynamic graph
"""
def
_generate_expr
(
tensor
):
if
tensor
.
node
().
kind
()
==
'prim::GetAttr'
:
return
f
'(
{
getattr
(
module
,
tensor
.
node
().
s
(
"name"
))
}
)'
elif
tensor
.
node
().
kind
()
==
'aten::__getitem__'
:
t
=
_generate_expr
(
tensor
.
node
().
inputsAt
(
0
))
idx
=
_generate_expr
(
tensor
.
node
().
inputsAt
(
1
))
return
f
'(
{
t
}
[
{
idx
}
])'
elif
tensor
.
node
().
kind
()
==
'prim::Constant'
:
return
f
'
{
tensor
.
toIValue
()
}
'
elif
tensor
.
node
().
kind
()
==
'aten::eq'
:
left
=
_generate_expr
(
tensor
.
node
().
inputsAt
(
0
))
right
=
_generate_expr
(
tensor
.
node
().
inputsAt
(
1
))
return
f
'(
{
left
}
==
{
right
}
)'
else
:
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
# %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4)
if
submodule
.
inputsAt
(
0
).
type
().
name
()
==
'ModuleList'
:
# handle ModuleList
predecessor
=
submodule
.
inputsAt
(
0
).
node
()
assert
predecessor
.
kind
()
==
'prim::GetAttr'
assert
predecessor
.
hasAttribute
(
'name'
)
assert
predecessor
.
inputsAt
(
0
).
debugName
()
==
'self'
predecessor_name
=
predecessor
.
s
(
'name'
)
# FIXME: exchange
submodule_full_name
=
build_full_name
(
module_name
,
[
submodule_name
,
predecessor_name
])
predecessor_obj
=
getattr
(
module
,
predecessor_name
)
submodule_obj
=
getattr
(
predecessor_obj
,
submodule_name
)
subgraph
,
sub_m_attrs
=
convert_module
(
script_module
.
_modules
[
predecessor_name
].
_modules
[
submodule_name
],
submodule_obj
,
submodule_full_name
,
ir_model
)
raise
RuntimeError
(
f
'Unsupported op type
{
tensor
.
node
().
kind
()
}
in if condition'
)
expr
=
_generate_expr
(
cond_tensor
)
return
eval
(
expr
)
def
handle_if_node
(
node
):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
# only deal with input of prim::If is constant or attribute for now
# will support constant expression in future
inputs
=
[
i
for
i
in
node
.
inputs
()]
assert
len
(
inputs
)
==
1
cond
=
handle_if_condition
(
inputs
[
0
])
chosen_block
=
0
if
cond
else
1
blocks
=
[
block
for
block
in
node
.
blocks
()]
assert
len
(
blocks
)
==
2
last_block_node
=
None
for
node
in
blocks
[
chosen_block
].
nodes
():
last_block_node
=
handle_single_node
(
node
)
return
last_block_node
def
handle_single_node
(
node
):
"""
Parameters
----------
node : torch._C.Node
the node from TorchScript graph
Returns
-------
Node
the created node ir
"""
if
node
.
kind
()
==
'prim::CallMethod'
:
# get and handle the first input, which should be an nn.Module
assert
node
.
hasAttribute
(
'name'
)
if
node
.
s
(
'name'
)
==
'forward'
:
# node.inputsAt(0).type() is <class 'torch._C.ClassType'>
submodule_type_str
=
self
.
_remove_mangle
(
node
.
inputsAt
(
0
).
type
().
str
())
submodule
=
node
.
inputsAt
(
0
).
node
()
assert
submodule
.
kind
()
==
'prim::GetAttr'
assert
submodule
.
hasAttribute
(
'name'
)
submodule_name
=
submodule
.
s
(
'name'
)
if
submodule
.
inputsAt
(
0
).
debugName
()
==
'self'
:
# module is usually instantiated in __init__.
# when calling a module in forward,
# prim::GetAttr is used to obtain the module in torch script.
# therefore, we do this check for a module. example below:
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
assert
submodule_name
in
script_module
.
_modules
,
"submodule_name: {} not in script_module {}"
.
format
(
submodule_name
,
script_module
.
_modules
.
keys
())
submodule_full_name
=
build_full_name
(
module_name
,
submodule_name
)
submodule_obj
=
getattr
(
module
,
submodule_name
)
subgraph
,
sub_m_attrs
=
self
.
convert_module
(
script_module
.
_modules
[
submodule_name
],
submodule_obj
,
submodule_full_name
,
ir_model
)
else
:
raise
RuntimeError
(
'Unsupported module case: {}'
.
format
(
submodule
.
inputsAt
(
0
).
type
().
str
()))
# TODO: match subgraph with maintained graphs
# build cell
if
subgraph
is
None
:
# if we do not parse this module's graph, we create Node for this module
subcell
=
ir_graph
.
add_node
(
submodule_full_name
,
submodule_type_str
,
sub_m_attrs
)
if
isinstance
(
submodule_obj
,
Placeholder
):
subcell
.
update_label
(
submodule_obj
.
label
)
elif
isinstance
(
submodule_obj
,
(
LayerChoice
,
InputChoice
)):
subcell
.
update_label
(
sub_m_attrs
[
'label'
])
# %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
# %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
# %s1.4 : Tensor = prim::CallMethod[name="forward"](%10, %4, %4)
if
submodule
.
inputsAt
(
0
).
type
().
name
()
==
'ModuleList'
:
# handle ModuleList
predecessor
=
submodule
.
inputsAt
(
0
).
node
()
assert
predecessor
.
kind
()
==
'prim::GetAttr'
assert
predecessor
.
hasAttribute
(
'name'
)
assert
predecessor
.
inputsAt
(
0
).
debugName
()
==
'self'
predecessor_name
=
predecessor
.
s
(
'name'
)
# FIXME: exchange
submodule_full_name
=
build_full_name
(
module_name
,
[
submodule_name
,
predecessor_name
])
predecessor_obj
=
getattr
(
module
,
predecessor_name
)
submodule_obj
=
getattr
(
predecessor_obj
,
submodule_name
)
subgraph
,
sub_m_attrs
=
self
.
convert_module
(
script_module
.
_modules
[
predecessor_name
].
_modules
[
submodule_name
],
submodule_obj
,
submodule_full_name
,
ir_model
)
else
:
raise
RuntimeError
(
'Unsupported module case: {}'
.
format
(
submodule
.
inputsAt
(
0
).
type
().
str
()))
# TODO: match subgraph with maintained graphs
# build cell
if
subgraph
is
None
:
# if we do not parse this module's graph, we create Node for this module
subcell
=
ir_graph
.
add_node
(
submodule_full_name
,
submodule_type_str
,
sub_m_attrs
)
if
isinstance
(
submodule_obj
,
Placeholder
):
subcell
.
update_label
(
submodule_obj
.
label
)
elif
isinstance
(
submodule_obj
,
(
LayerChoice
,
InputChoice
)):
subcell
.
update_label
(
sub_m_attrs
[
'label'
])
else
:
# Graph already created, create Cell for it
new_cell
=
Cell
(
cell_name
=
submodule_full_name
,
parameters
=
sub_m_attrs
)
subcell
=
ir_graph
.
add_node
(
submodule_full_name
,
new_cell
)
node_index
[
node
]
=
subcell
# connect the cell into graph
self
.
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
subcell
,
output_remap
,
ignore_first
=
True
)
else
:
# Graph already created, create Cell for it
new_cell
=
Cell
(
cell_name
=
submodule_full_name
,
parameters
=
sub_m_attrs
)
subcell
=
ir_graph
.
add_node
(
submodule_full_name
,
new_cell
)
node_index
[
node
]
=
subcell
# connect the cell into graph
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
subcell
,
output_remap
,
ignore_first
=
True
)
raise
RuntimeError
(
'unsupported CallMethod {}'
.
format
(
node
.
s
(
'name'
)))
elif
node
.
kind
()
==
'prim::CallFunction'
:
func_type_str
=
self
.
_remove_mangle
(
node
.
inputsAt
(
0
).
type
().
str
())
func
=
node
.
inputsAt
(
0
).
node
()
assert
func
.
kind
()
==
'prim::Constant'
assert
func
.
hasAttribute
(
'name'
)
func_name
=
func
.
s
(
'name'
)
# create node for func
self
.
global_seq
+=
1
func_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
func_name
,
self
.
global_seq
),
'{}.{}'
.
format
(
func_type_str
,
func_name
))
node_index
[
node
]
=
func_node
self
.
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
func_node
,
output_remap
,
ignore_first
=
True
)
elif
node
.
kind
()
==
'prim::Constant'
:
new_node
=
self
.
create_prim_constant_node
(
ir_graph
,
node
,
module_name
)
node_index
[
node
]
=
new_node
elif
node
.
kind
()
==
'prim::ListConstruct'
:
self
.
global_seq
+=
1
new_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
OpTypeName
.
ListConstruct
,
self
.
global_seq
),
node
.
kind
())
node_index
[
node
]
=
new_node
self
.
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
new_node
,
output_remap
)
elif
node
.
kind
()
==
'aten::append'
:
self
.
global_seq
+=
1
aten_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
BasicOpsPT
[
node
.
kind
()],
self
.
global_seq
),
node
.
kind
())
node_index
[
node
]
=
aten_node
self
.
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
aten_node
,
output_remap
)
output_remap
[
node
.
inputsAt
(
0
)]
=
node
elif
node
.
kind
().
startswith
(
'aten::'
):
# handle aten::XXX
self
.
global_seq
+=
1
aten_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
BasicOpsPT
[
node
.
kind
()],
self
.
global_seq
),
node
.
kind
())
node_index
[
node
]
=
aten_node
self
.
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
aten_node
,
output_remap
)
elif
node
.
kind
()
==
'prim::GetAttr'
:
node_type
,
attrs
=
self
.
handle_prim_attr_node
(
node
)
self
.
global_seq
+=
1
new_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
OpTypeName
.
Attr
,
self
.
global_seq
),
node_type
,
attrs
)
node_index
[
node
]
=
new_node
elif
node
.
kind
()
==
'prim::If'
:
last_block_node
=
handle_if_node
(
node
)
# last_block_node is None means no node in the branch block
node_index
[
node
]
=
last_block_node
elif
node
.
kind
()
==
'prim::Loop'
:
# refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
raise
RuntimeError
(
'Loop has not been supported yet!'
)
else
:
raise
RuntimeError
(
'unsupported CallMethod {}'
.
format
(
node
.
s
(
'name'
)))
elif
node
.
kind
()
==
'prim::CallFunction'
:
func_type_str
=
_remove_mangle
(
node
.
inputsAt
(
0
).
type
().
str
())
func
=
node
.
inputsAt
(
0
).
node
()
assert
func
.
kind
()
==
'prim::Constant'
assert
func
.
hasAttribute
(
'name'
)
func_name
=
func
.
s
(
'name'
)
# create node for func
global_seq
+=
1
func_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
func_name
,
global_seq
),
'{}.{}'
.
format
(
func_type_str
,
func_name
))
node_index
[
node
]
=
func_node
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
func_node
,
output_remap
,
ignore_first
=
True
)
elif
node
.
kind
()
==
'prim::Constant'
:
new_node
=
create_prim_constant_node
(
ir_graph
,
node
,
module_name
)
node_index
[
node
]
=
new_node
elif
node
.
kind
()
==
'prim::ListConstruct'
:
global_seq
+=
1
new_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
OpTypeName
.
ListConstruct
,
global_seq
),
node
.
kind
())
node_index
[
node
]
=
new_node
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
new_node
,
output_remap
)
elif
node
.
kind
()
==
'aten::append'
:
global_seq
+=
1
aten_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
BasicOpsPT
[
node
.
kind
()],
global_seq
),
node
.
kind
())
node_index
[
node
]
=
aten_node
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
aten_node
,
output_remap
)
output_remap
[
node
.
inputsAt
(
0
)]
=
node
elif
node
.
kind
().
startswith
(
'aten::'
):
# handle aten::XXX
global_seq
+=
1
aten_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
BasicOpsPT
[
node
.
kind
()],
global_seq
),
node
.
kind
())
node_index
[
node
]
=
aten_node
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
aten_node
,
output_remap
)
elif
node
.
kind
()
==
'prim::GetAttr'
:
node_type
,
attrs
=
handle_prim_attr_node
(
node
)
global_seq
+=
1
new_node
=
ir_graph
.
add_node
(
build_full_name
(
module_name
,
OpTypeName
.
Attr
,
global_seq
),
node_type
,
attrs
)
node_index
[
node
]
=
new_node
elif
node
.
kind
()
==
'prim::If'
:
last_block_node
=
handle_if_node
(
node
)
# last_block_node is None means no node in the branch block
node_index
[
node
]
=
last_block_node
elif
node
.
kind
()
==
'prim::Loop'
:
# refer to https://gist.github.com/liuzhe-lz/90c35d9dd6fd7f3f32544940151ab186
raise
RuntimeError
(
'Loop has not been supported yet!'
)
else
:
raise
RuntimeError
(
'Unsupported kind: {}'
.
format
(
node
.
kind
()))
return
node_index
[
node
]
for
node
in
sm_graph
.
nodes
():
handle_single_node
(
node
)
return
node_index
def
merge_aten_slices
(
ir_graph
):
"""
if there is aten::slice node, merge the consecutive ones together.
```x[:, :, 1:, 1:]``` in python code will be converted into 4 node in torch script,
each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
"""
head_slice_nodes
=
[]
has_slice_node
=
False
for
node
in
ir_graph
.
hidden_nodes
:
if
node
.
operation
.
type
==
'aten::slice'
:
has_slice_node
=
True
for
pred
in
node
.
predecessors
:
if
pred
.
operation
.
type
not
in
[
'aten::slice'
,
'prim::Constant'
]:
head_slice_nodes
.
append
(
node
)
break
if
has_slice_node
:
assert
head_slice_nodes
for
head_node
in
head_slice_nodes
:
slot
=
0
new_slice_node
=
ir_graph
.
add_node
(
build_full_name
(
head_node
.
name
,
'merged'
),
OpTypeName
.
MergedSlice
)
if
len
(
head_node
.
incoming_edges
)
==
4
:
# when slice is for one dimension list, there are only 4 inputs, thus merge is not needed
break
assert
len
(
head_node
.
incoming_edges
)
==
5
for
edge
in
head_node
.
incoming_edges
:
edge
.
tail
=
new_slice_node
slot
+=
5
node
=
head_node
while
len
(
node
.
successors
)
==
1
and
node
.
successors
[
0
].
operation
.
type
==
'aten::slice'
:
suc_node
=
node
.
successors
[
0
]
assert
len
(
suc_node
.
incoming_edges
)
==
5
for
edge
in
suc_node
.
incoming_edges
:
if
edge
.
tail_slot
==
0
:
edge
.
remove
()
else
:
edge
.
tail
=
new_slice_node
edge
.
tail_slot
=
slot
+
edge
.
tail_slot
-
1
slot
+=
4
ir_graph
.
hidden_nodes
.
remove
(
node
)
node
=
suc_node
raise
RuntimeError
(
'Unsupported kind: {}'
.
format
(
node
.
kind
()))
for
edge
in
node
.
outgoing_edges
:
edge
.
head
=
new_slice_node
ir_graph
.
hidden_nodes
.
remove
(
node
)
return
node_index
[
node
]
for
node
in
sm_graph
.
nodes
():
handle_single_node
(
node
)
def
refine_graph
(
ir_graph
):
"""
Do the following process to simplify graph:
1. remove unconnected constant node
2. remove unconnected getattr node
"""
# some constant is not used, for example, function name as prim::Constant
remove_unconnected_nodes
(
ir_graph
,
targeted_type
=
'prim::Constant'
)
remove_unconnected_nodes
(
ir_graph
,
targeted_type
=
'prim::GetAttr'
)
merge_aten_slices
(
ir_graph
)
return
node_index
def
merge_aten_slices
(
self
,
ir_graph
):
"""
if there is aten::slice node, merge the consecutive ones together.
```x[:, :, 1:, 1:]``` in python code will be converted into 4 node in torch script,
each node has 5 inputs: tensor, dim, x, y, z (i.e., x:y:z)
"""
head_slice_nodes
=
[]
has_slice_node
=
False
for
node
in
ir_graph
.
hidden_nodes
:
if
node
.
operation
.
type
==
'aten::slice'
:
has_slice_node
=
True
for
pred
in
node
.
predecessors
:
if
pred
.
operation
.
type
not
in
[
'aten::slice'
,
'prim::Constant'
]:
head_slice_nodes
.
append
(
node
)
break
if
has_slice_node
:
assert
head_slice_nodes
for
head_node
in
head_slice_nodes
:
slot
=
0
new_slice_node
=
ir_graph
.
add_node
(
build_full_name
(
head_node
.
name
,
'merged'
),
OpTypeName
.
MergedSlice
)
if
len
(
head_node
.
incoming_edges
)
==
4
:
# when slice is for one dimension list, there are only 4 inputs, thus merge is not needed
break
assert
len
(
head_node
.
incoming_edges
)
==
5
for
edge
in
head_node
.
incoming_edges
:
edge
.
tail
=
new_slice_node
slot
+=
5
node
=
head_node
while
len
(
node
.
successors
)
==
1
and
node
.
successors
[
0
].
operation
.
type
==
'aten::slice'
:
suc_node
=
node
.
successors
[
0
]
assert
len
(
suc_node
.
incoming_edges
)
==
5
for
edge
in
suc_node
.
incoming_edges
:
if
edge
.
tail_slot
==
0
:
edge
.
remove
()
else
:
edge
.
tail
=
new_slice_node
edge
.
tail_slot
=
slot
+
edge
.
tail_slot
-
1
slot
+=
4
ir_graph
.
hidden_nodes
.
remove
(
node
)
node
=
suc_node
for
edge
in
node
.
outgoing_edges
:
edge
.
head
=
new_slice_node
ir_graph
.
hidden_nodes
.
remove
(
node
)
def
refine_graph
(
self
,
ir_graph
):
"""
Do the following process to simplify graph:
1. remove unconnected constant node
2. remove unconnected getattr node
"""
# some constant is not used, for example, function name as prim::Constant
self
.
remove_unconnected_nodes
(
ir_graph
,
targeted_type
=
'prim::Constant'
)
self
.
remove_unconnected_nodes
(
ir_graph
,
targeted_type
=
'prim::GetAttr'
)
self
.
merge_aten_slices
(
ir_graph
)
def
_handle_layerchoice
(
self
,
module
):
m_attrs
=
{}
candidates
=
module
.
op_candidates
choices
=
[]
for
cand
in
candidates
:
assert
id
(
cand
)
in
self
.
modules_arg
,
'id not exist: {}'
.
format
(
id
(
cand
))
assert
isinstance
(
self
.
modules_arg
[
id
(
cand
)],
dict
)
cand_type
=
'__torch__.'
+
cand
.
__class__
.
__module__
+
'.'
+
cand
.
__class__
.
__name__
choices
.
append
({
'type'
:
cand_type
,
'parameters'
:
self
.
modules_arg
[
id
(
cand
)]})
m_attrs
[
f
'choices'
]
=
choices
m_attrs
[
'label'
]
=
module
.
label
return
m_attrs
def
_handle_inputchoice
(
self
,
module
):
m_attrs
=
{}
m_attrs
[
'n_candidates'
]
=
module
.
n_candidates
m_attrs
[
'n_chosen'
]
=
module
.
n_chosen
m_attrs
[
'reduction'
]
=
module
.
reduction
m_attrs
[
'label'
]
=
module
.
label
return
m_attrs
def
convert_module
(
self
,
script_module
,
module
,
module_name
,
ir_model
):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
def
_handle_layerchoice
(
module
):
global
modules_arg
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module of ```module``` obtained with torch.jit.script
module : nn.Module
the targeted module instance
module_name : str
the constructed name space of ```module```
ir_model : Model
the whole graph ir
m_attrs
=
{}
candidates
=
module
.
candidate_ops
choices
=
[]
for
cand
in
candidates
:
assert
id
(
cand
)
in
modules_arg
,
'id not exist: {}'
.
format
(
id
(
cand
))
assert
isinstance
(
modules_arg
[
id
(
cand
)],
dict
)
cand_type
=
'__torch__.'
+
cand
.
__class__
.
__module__
+
'.'
+
cand
.
__class__
.
__name__
choices
.
append
({
'type'
:
cand_type
,
'parameters'
:
modules_arg
[
id
(
cand
)]})
m_attrs
[
f
'choices'
]
=
choices
m_attrs
[
'label'
]
=
module
.
label
return
m_attrs
Returns
-------
Graph
the built graph ir from module, ```None``` means do not further parse the module
dict
the input arguments of this module
"""
# NOTE: have not supported nested LayerChoice, i.e., a candidate module
# also has LayerChoice or InputChoice or ValueChoice
original_type_name
=
script_module
.
original_name
m_attrs
=
None
if
original_type_name
in
MODULE_EXCEPT_LIST
:
pass
# do nothing
elif
original_type_name
==
OpTypeName
.
LayerChoice
:
m_attrs
=
self
.
_handle_layerchoice
(
module
)
elif
original_type_name
==
OpTypeName
.
InputChoice
:
m_attrs
=
self
.
_handle_inputchoice
(
module
)
elif
original_type_name
==
OpTypeName
.
Placeholder
:
m_attrs
=
self
.
modules_arg
[
id
(
module
)]
elif
original_type_name
in
torch
.
nn
.
__dict__
:
# this is a basic module from pytorch, no need to parse its graph
assert
id
(
module
)
in
self
.
modules_arg
,
f
'
{
original_type_name
}
arguments are not recorded'
m_attrs
=
self
.
modules_arg
[
id
(
module
)]
elif
id
(
module
)
in
self
.
modules_arg
:
# this module is marked as blackbox, won't continue to parse
m_attrs
=
self
.
modules_arg
[
id
(
module
)]
if
m_attrs
is
not
None
:
return
None
,
m_attrs
# handle TorchScript graph
sm_graph
=
script_module
.
graph
self
.
global_graph_id
+=
1
ir_graph
=
Graph
(
model
=
ir_model
,
graph_id
=
self
.
global_graph_id
,
name
=
module_name
,
_internal
=
True
)
# handle graph nodes
node_index
=
self
.
handle_graph_nodes
(
script_module
,
sm_graph
,
module
,
module_name
,
ir_model
,
ir_graph
)
# handle graph outputs
for
_output
in
sm_graph
.
outputs
():
ir_graph
.
_add_output
(
_convert_name
(
_output
.
debugName
()))
predecessor_node_outputs
=
[
o
for
o
in
_output
.
node
().
outputs
()]
if
len
(
predecessor_node_outputs
)
==
1
:
src_node_idx
=
None
else
:
src_node_idx
=
predecessor_node_outputs
.
index
(
_output
)
ir_graph
.
add_edge
(
head
=
(
node_index
[
_output
.
node
()],
src_node_idx
),
tail
=
(
ir_graph
.
output_node
,
None
))
def
_handle_inputchoice
(
module
):
m_attrs
=
{}
m_attrs
[
'n_chosen'
]
=
module
.
n_chosen
m_attrs
[
'reduction'
]
=
module
.
reduction
m_attrs
[
'label'
]
=
module
.
label
return
m_attrs
self
.
refine_graph
(
ir_graph
)
ir_graph
.
_register
()
def
convert_module
(
script_module
,
module
,
module_name
,
ir_model
):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
return
ir_graph
,
{}
Parameters
----------
script_module : torch.jit.RecursiveScriptModule
the script module of ```module``` obtained with torch.jit.script
module : nn.Module
the targeted module instance
module_name : str
the constructed name space of ```module```
ir_model : Model
the whole graph ir
Returns
-------
Graph
the built graph ir from module, ```None``` means do not further parse the module
dict
the input arguments of this module
"""
global
global_graph_id
global
modules_arg
# NOTE: have not supported nested LayerChoice, i.e., a candidate module
# also has LayerChoice or InputChoice or ValueChoice
original_type_name
=
script_module
.
original_name
if
original_type_name
==
OpTypeName
.
LayerChoice
:
m_attrs
=
_handle_layerchoice
(
module
)
return
None
,
m_attrs
if
original_type_name
==
OpTypeName
.
InputChoice
:
m_attrs
=
_handle_inputchoice
(
module
)
return
None
,
m_attrs
if
original_type_name
==
OpTypeName
.
Placeholder
:
m_attrs
=
modules_arg
[
id
(
module
)]
return
None
,
m_attrs
if
original_type_name
in
torch
.
nn
.
__dict__
and
original_type_name
not
in
MODULE_EXCEPT_LIST
:
# this is a basic module from pytorch, no need to parse its graph
assert
id
(
module
)
in
modules_arg
,
f
'
{
original_type_name
}
arguments are not recorded'
m_attrs
=
modules_arg
[
id
(
module
)]
return
None
,
m_attrs
# handle TorchScript graph
sm_graph
=
script_module
.
graph
global_graph_id
+=
1
ir_graph
=
Graph
(
model
=
ir_model
,
graph_id
=
global_graph_id
,
name
=
module_name
,
_internal
=
True
)
# handle graph nodes
node_index
=
handle_graph_nodes
(
script_module
,
sm_graph
,
module
,
module_name
,
ir_model
,
ir_graph
)
# handle graph outputs
for
_output
in
sm_graph
.
outputs
():
ir_graph
.
_add_output
(
_convert_name
(
_output
.
debugName
()))
predecessor_node_outputs
=
[
o
for
o
in
_output
.
node
().
outputs
()]
if
len
(
predecessor_node_outputs
)
==
1
:
src_node_idx
=
None
else
:
src_node_idx
=
predecessor_node_outputs
.
index
(
_output
)
ir_graph
.
add_edge
(
head
=
(
node_index
[
_output
.
node
()],
src_node_idx
),
tail
=
(
ir_graph
.
output_node
,
None
))
refine_graph
(
ir_graph
)
ir_graph
.
_register
()
if
id
(
module
)
not
in
modules_arg
:
raise
RuntimeError
(
f
'
{
original_type_name
}
arguments are not recorded,
\
you might have forgotten to decorate this class with @register_module()'
)
# TODO: if we parse this module, it means we will create a graph (module class)
# for this module. Then it is not necessary to record this module's arguments
# return ir_graph, modules_arg[id(module)].
# That is, we can refactor this part, to allow users to annotate which module
# should not be parsed further.
return
ir_graph
,
{}
def
convert_to_graph
(
script_module
,
module
,
recorded_modules_arg
):
def
convert_to_graph
(
script_module
,
module
):
"""
Convert module to our graph ir, i.e., build a ```Model``` type
...
...
@@ -526,18 +510,15 @@ def convert_to_graph(script_module, module, recorded_modules_arg):
the script module obtained with torch.jit.script
module : nn.Module
the targeted module instance
recorded_modules_arg : dict
the recorded args of each module in the module
Returns
-------
Model
the constructed IR model
"""
global
modules_arg
modules_arg
=
recorded_modules_arg
model
=
Model
(
_internal
=
True
)
module_name
=
'_model'
convert_module
(
script_module
,
module
,
module_name
,
model
)
GraphConverter
().
convert_module
(
script_module
,
module
,
module_name
,
model
)
return
model
nni/retiarii/converter/op_types.py
View file @
4784cc6c
...
...
@@ -30,6 +30,10 @@ BasicOpsPT = {
'aten::size'
:
'Size'
,
'aten::view'
:
'View'
,
'aten::eq'
:
'Eq'
,
'aten::Bool'
:
'Bool'
,
'aten::empty'
:
'Empty'
,
'aten::zeros'
:
'Zeros'
,
'aten::chunk'
:
'Chunk'
,
'aten::add_'
:
'Add_'
# %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}
...
...
nni/retiarii/execution/api.py
View file @
4784cc6c
import
time
import
os
from
typing
import
List
from
..graph
import
Model
,
ModelStatus
from
.base
import
BaseExecutionEngine
from
.cgo_engine
import
CGOExecutionEngine
from
.interface
import
AbstractExecutionEngine
,
WorkerInfo
from
.interface
import
AbstractExecutionEngine
from
.listener
import
DefaultListener
_execution_engine
=
None
_default_listener
=
None
__all__
=
[
'get_execution_engine'
,
'get_and_register_default_listener'
,
'submit_models'
,
'wait_models'
,
'query_available_resources'
]
'submit_models'
,
'wait_models'
,
'query_available_resources'
,
'set_execution_engine'
,
'is_stopped_exec'
]
def
set_execution_engine
(
engine
)
->
None
:
global
_execution_engine
if
_execution_engine
is
None
:
_execution_engine
=
engine
else
:
raise
RuntimeError
(
'execution engine is already set'
)
def
get_execution_engine
()
->
Base
ExecutionEngine
:
def
get_execution_engine
()
->
Abstract
ExecutionEngine
:
"""
Currently we assume the default execution engine is BaseExecutionEngine.
"""
global
_execution_engine
if
_execution_engine
is
None
:
if
os
.
environ
.
get
(
'CGO'
)
==
'true'
:
_execution_engine
=
CGOExecutionEngine
()
else
:
_execution_engine
=
BaseExecutionEngine
()
return
_execution_engine
...
...
@@ -51,6 +49,11 @@ def wait_models(*models: Model) -> None:
break
def
query_available_resources
()
->
List
[
WorkerInfo
]:
listener
=
get_and_register_default_listener
(
get_execution_engine
())
return
listener
.
resources
def
query_available_resources
()
->
int
:
engine
=
get_execution_engine
()
resources
=
engine
.
query_available_resource
()
return
resources
if
isinstance
(
resources
,
int
)
else
len
(
resources
)
def
is_stopped_exec
(
model
:
Model
)
->
bool
:
return
model
.
status
in
(
ModelStatus
.
Trained
,
ModelStatus
.
Failed
)
nni/retiarii/execution/base.py
View file @
4784cc6c
import
logging
import
os
import
random
import
string
from
typing
import
Dict
,
Any
,
List
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
,
WorkerInfo
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
from
..
import
codegen
,
utils
from
..graph
import
Model
,
ModelStatus
,
MetricData
from
..integration
import
send_trial
,
receive_trial_parameters
,
get_advisor
from
..integration
_api
import
send_trial
,
receive_trial_parameters
,
get_advisor
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -29,7 +32,7 @@ class BaseGraphData:
class
BaseExecutionEngine
(
AbstractExecutionEngine
):
"""
The execution engine with no optimization at all.
Resource management is
yet to be
implemented.
Resource management is implemented
in this class
.
"""
def
__init__
(
self
)
->
None
:
...
...
@@ -50,6 +53,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
self
.
_running_models
:
Dict
[
int
,
Model
]
=
dict
()
self
.
resources
=
0
def
submit_models
(
self
,
*
models
:
Model
)
->
None
:
for
model
in
models
:
data
=
BaseGraphData
(
codegen
.
model_to_pytorch_script
(
model
),
...
...
@@ -60,17 +65,14 @@ class BaseExecutionEngine(AbstractExecutionEngine):
self
.
_listeners
.
append
(
listener
)
def
_send_trial_callback
(
self
,
paramater
:
dict
)
->
None
:
for
listener
in
self
.
_listeners
:
_logger
.
warning
(
'resources: %s'
,
listener
.
resources
)
if
not
listener
.
has_available_resource
():
_logger
.
warning
(
'There is no available resource, but trial is submitted.'
)
listener
.
on_resource_used
(
1
)
_logger
.
warning
(
'on_resource_used: %s'
,
listener
.
resources
)
if
self
.
resources
<=
0
:
_logger
.
warning
(
'There is no available resource, but trial is submitted.'
)
self
.
resources
-=
1
_logger
.
info
(
'on_resource_used: %d'
,
self
.
resources
)
def
_request_trial_jobs_callback
(
self
,
num_trials
:
int
)
->
None
:
for
listener
in
self
.
_listeners
:
listener
.
on_resource_available
(
1
*
num_trials
)
_logger
.
warning
(
'on_resource_available: %s'
,
listener
.
resources
)
self
.
resources
+=
num_trials
_logger
.
info
(
'on_resource_available: %d'
,
self
.
resources
)
def
_trial_end_callback
(
self
,
trial_id
:
int
,
success
:
bool
)
->
None
:
model
=
self
.
_running_models
[
trial_id
]
...
...
@@ -93,8 +95,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
for
listener
in
self
.
_listeners
:
listener
.
on_metric
(
model
,
metrics
)
def
query_available_resource
(
self
)
->
List
[
WorkerInfo
]
:
r
aise
NotImplementedError
# move the method from listener to here?
def
query_available_resource
(
self
)
->
int
:
r
eturn
self
.
resources
@
classmethod
def
trial_execute_graph
(
cls
)
->
None
:
...
...
@@ -102,9 +104,12 @@ class BaseExecutionEngine(AbstractExecutionEngine):
Initialize the model, hand it over to trainer.
"""
graph_data
=
BaseGraphData
.
load
(
receive_trial_parameters
())
with
open
(
'_generated_model.py'
,
'w'
)
as
f
:
random_str
=
''
.
join
(
random
.
choice
(
string
.
ascii_uppercase
+
string
.
digits
)
for
_
in
range
(
6
))
file_name
=
f
'_generated_model_
{
random_str
}
.py'
with
open
(
file_name
,
'w'
)
as
f
:
f
.
write
(
graph_data
.
model_script
)
trainer_cls
=
utils
.
import_
(
graph_data
.
training_module
)
model_cls
=
utils
.
import_
(
'_generated_model._model'
)
model_cls
=
utils
.
import_
(
f
'_generated_model
_
{
random_str
}
._model'
)
trainer_instance
=
trainer_cls
(
model
=
model_cls
(),
**
graph_data
.
training_kwargs
)
trainer_instance
.
fit
()
os
.
remove
(
file_name
)
\ No newline at end of file
nni/retiarii/execution/cgo_engine.py
View file @
4784cc6c
...
...
@@ -4,7 +4,7 @@ from typing import List, Dict, Tuple
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
,
WorkerInfo
from
..
import
codegen
,
utils
from
..graph
import
Model
,
ModelStatus
,
MetricData
from
..integration
import
send_trial
,
receive_trial_parameters
,
get_advisor
from
..integration
_api
import
send_trial
,
receive_trial_parameters
,
get_advisor
from
.logical_optimizer.logical_plan
import
LogicalPlan
,
PhysicalDevice
from
.logical_optimizer.opt_dedup_input
import
DedupInputOptimizer
...
...
nni/retiarii/execution/interface.py
View file @
4784cc6c
from
abc
import
ABC
,
abstractmethod
,
abstractclassmethod
from
typing
import
Any
,
NewType
,
List
from
typing
import
Any
,
NewType
,
List
,
Union
from
..graph
import
Model
,
MetricData
...
...
@@ -59,13 +59,6 @@ class AbstractGraphListener(ABC):
"""
pass
@
abstractmethod
def
on_resource_available
(
self
,
resources
:
List
[
WorkerInfo
])
->
None
:
"""
Reports when a worker becomes idle.
"""
pass
class
AbstractExecutionEngine
(
ABC
):
"""
...
...
@@ -109,7 +102,7 @@ class AbstractExecutionEngine(ABC):
raise
NotImplementedError
@
abstractmethod
def
query_available_resource
(
self
)
->
List
[
WorkerInfo
]:
def
query_available_resource
(
self
)
->
Union
[
List
[
WorkerInfo
]
,
int
]
:
"""
Returns information of all idle workers.
If no details are available, this may returns a list of "empty" objects, reporting the number of idle workers.
...
...
nni/retiarii/execution/listener.py
View file @
4784cc6c
...
...
@@ -3,11 +3,6 @@ from .interface import MetricData, AbstractGraphListener
class
DefaultListener
(
AbstractGraphListener
):
def
__init__
(
self
):
self
.
resources
:
int
=
0
# simply resource count
def
has_available_resource
(
self
)
->
bool
:
return
self
.
resources
>
0
def
on_metric
(
self
,
model
:
Model
,
metric
:
MetricData
)
->
None
:
model
.
metric
=
metric
...
...
@@ -20,9 +15,3 @@ class DefaultListener(AbstractGraphListener):
model
.
status
=
ModelStatus
.
Trained
else
:
model
.
status
=
ModelStatus
.
Failed
def
on_resource_available
(
self
,
resources
:
int
)
->
None
:
self
.
resources
+=
resources
def
on_resource_used
(
self
,
resources
:
int
)
->
None
:
self
.
resources
-=
resources
nni/retiarii/experiment.py
View file @
4784cc6c
import
logging
import
time
from
dataclasses
import
dataclass
from
pathlib
import
Path
...
...
@@ -7,20 +6,24 @@ from subprocess import Popen
from
threading
import
Thread
from
typing
import
Any
,
Optional
from
..experiment
import
Experiment
,
TrainingServiceConfig
,
launcher
,
rest
from
..experiment
import
Experiment
,
TrainingServiceConfig
from
..experiment.config.base
import
ConfigBase
,
PathLike
from
..experiment.config
import
util
from
..experiment.pipe
import
Pipe
from
.graph
import
Model
from
.utils
import
get_records
from
.integration
import
RetiariiAdvisor
from
.converter
import
convert_to_graph
from
.mutator
import
Mutator
,
LayerChoiceMutator
,
InputChoiceMutator
from
.trainer.interface
import
BaseTrainer
from
.trainer.interface
import
BaseTrainer
,
BaseOneShotTrainer
from
.strategies.strategy
import
BaseStrategy
from
.trainer.pytorch
import
DartsTrainer
,
EnasTrainer
,
ProxylessTrainer
,
RandomTrainer
,
SinglePathTrainer
_logger
=
logging
.
getLogger
(
__name__
)
OneShotTrainers
=
(
DartsTrainer
,
EnasTrainer
,
ProxylessTrainer
,
RandomTrainer
,
SinglePathTrainer
)
@
dataclass
(
init
=
False
)
class
RetiariiExeConfig
(
ConfigBase
):
...
...
@@ -43,7 +46,7 @@ class RetiariiExeConfig(ConfigBase):
super
().
__init__
(
**
kwargs
)
if
training_service_platform
is
not
None
:
assert
'training_service'
not
in
kwargs
self
.
training_service
=
util
.
training_service_config_factory
(
training_service_platform
)
self
.
training_service
=
util
.
training_service_config_factory
(
platform
=
training_service_platform
)
def
validate
(
self
,
initialized_tuner
:
bool
=
False
)
->
None
:
super
().
validate
()
...
...
@@ -76,7 +79,7 @@ _validation_rules = {
class
RetiariiExperiment
(
Experiment
):
def
__init__
(
self
,
base_model
:
Model
,
trainer
:
BaseTrainer
,
applied_mutators
:
Mutator
,
strategy
:
BaseStrategy
):
applied_mutators
:
Mutator
=
None
,
strategy
:
BaseStrategy
=
None
):
self
.
config
:
RetiariiExeConfig
=
None
self
.
port
:
Optional
[
int
]
=
None
...
...
@@ -87,6 +90,7 @@ class RetiariiExperiment(Experiment):
self
.
recorded_module_args
=
get_records
()
self
.
_dispatcher
=
RetiariiAdvisor
()
self
.
_dispatcher_thread
:
Optional
[
Thread
]
=
None
self
.
_proc
:
Optional
[
Popen
]
=
None
self
.
_pipe
:
Optional
[
Pipe
]
=
None
...
...
@@ -103,7 +107,10 @@ class RetiariiExperiment(Experiment):
mutator
=
LayerChoiceMutator
(
node
.
name
,
node
.
operation
.
parameters
[
'choices'
])
applied_mutators
.
append
(
mutator
)
for
node
in
ic_nodes
:
mutator
=
InputChoiceMutator
(
node
.
name
,
node
.
operation
.
parameters
[
'n_chosen'
])
mutator
=
InputChoiceMutator
(
node
.
name
,
node
.
operation
.
parameters
[
'n_candidates'
],
node
.
operation
.
parameters
[
'n_chosen'
],
node
.
operation
.
parameters
[
'reduction'
])
applied_mutators
.
append
(
mutator
)
return
applied_mutators
...
...
@@ -114,14 +121,17 @@ class RetiariiExperiment(Experiment):
except
Exception
as
e
:
_logger
.
error
(
'Your base model cannot be parsed by torch.jit.script, please fix the following error:'
)
raise
e
base_model
=
convert_to_graph
(
script_module
,
self
.
base_model
,
self
.
recorded_module_args
)
base_model
_ir
=
convert_to_graph
(
script_module
,
self
.
base_model
)
assert
id
(
self
.
trainer
)
in
self
.
recorded_module_args
trainer_config
=
self
.
recorded_module_args
[
id
(
self
.
trainer
)]
base_model
.
apply_trainer
(
trainer_config
[
'modulename'
],
trainer_config
[
'args'
])
recorded_module_args
=
get_records
()
if
id
(
self
.
trainer
)
not
in
recorded_module_args
:
raise
KeyError
(
'Your trainer is not found in registered classes. You might have forgotten to
\
register your customized trainer with @register_trainer decorator.'
)
trainer_config
=
recorded_module_args
[
id
(
self
.
trainer
)]
base_model_ir
.
apply_trainer
(
trainer_config
[
'modulename'
],
trainer_config
[
'args'
])
# handle inline mutations
mutators
=
self
.
_process_inline_mutation
(
base_model
)
mutators
=
self
.
_process_inline_mutation
(
base_model
_ir
)
if
mutators
is
not
None
and
self
.
applied_mutators
:
raise
RuntimeError
(
'Have not supported mixed usage of LayerChoice/InputChoice and mutators,
\
do not use mutators when you use LayerChoice/InputChoice'
)
...
...
@@ -129,10 +139,10 @@ class RetiariiExperiment(Experiment):
self
.
applied_mutators
=
mutators
_logger
.
info
(
'Starting strategy...'
)
Thread
(
target
=
self
.
strategy
.
run
,
args
=
(
base_model
,
self
.
applied_mutators
)).
start
()
Thread
(
target
=
self
.
strategy
.
run
,
args
=
(
base_model
_ir
,
self
.
applied_mutators
)).
start
()
_logger
.
info
(
'Strategy started!'
)
def
start
(
self
,
config
:
RetiariiExeConfig
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
def
start
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
None
:
"""
Start the experiment in background.
This method will raise exception on failure.
...
...
@@ -144,54 +154,37 @@ class RetiariiExperiment(Experiment):
debug
Whether to start in debug mode.
"""
# FIXME:
if
debug
:
logging
.
getLogger
(
'nni'
).
setLevel
(
logging
.
DEBUG
)
self
.
_proc
,
self
.
_pipe
=
launcher
.
start_experiment
(
config
,
port
,
debug
)
assert
self
.
_proc
is
not
None
assert
self
.
_pipe
is
not
None
self
.
port
=
port
# port will be None if start up failed
# dispatcher must be created after pipe initialized
# the logic to launch dispatcher in background should be refactored into dispatcher api
Thread
(
target
=
self
.
_dispatcher
.
run
).
start
()
super
().
start
(
port
,
debug
)
self
.
_start_strategy
()
# TODO: register experiment management metadata
def
_create_dispatcher
(
self
):
return
self
.
_dispatcher
def
stop
(
self
)
->
None
:
"""
Stop background experiment.
"""
self
.
_proc
.
kill
()
self
.
_pipe
.
close
()
self
.
port
=
None
self
.
_proc
=
None
self
.
_pipe
=
None
def
run
(
self
,
config
:
RetiariiExeConfig
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
str
:
def
run
(
self
,
config
:
RetiariiExeConfig
=
None
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
str
:
"""
Run the experiment.
This function will block until experiment finish or error.
"""
self
.
config
=
config
self
.
start
(
config
,
port
,
debug
)
try
:
while
True
:
time
.
sleep
(
10
)
status
=
self
.
get_status
()
# TODO: double check the status
if
status
in
[
'ERROR'
,
'STOPPED'
,
'NO_MORE_TRIAL'
]:
return
status
finally
:
self
.
stop
()
def
get_status
(
self
)
->
str
:
if
self
.
port
is
None
:
raise
RuntimeError
(
'Experiment is not running'
)
resp
=
rest
.
get
(
self
.
port
,
'/check-status'
)
return
resp
[
'status'
]
if
isinstance
(
self
.
trainer
,
OneShotTrainers
):
self
.
trainer
.
fit
()
else
:
assert
config
is
not
None
,
'You are using classic search mode, config cannot be None!'
self
.
config
=
config
super
().
run
(
port
,
debug
)
def
export_top_models
(
self
,
top_n
:
int
=
1
):
"""
export several top performing models
"""
if
top_n
!=
1
:
_logger
.
warning
(
'Only support top_n is 1 for now.'
)
if
isinstance
(
self
.
trainer
,
BaseOneShotTrainer
):
return
self
.
trainer
.
export
()
else
:
_logger
.
info
(
'For this experiment, you can find out the best one from WebUI.'
)
def
retrain_model
(
self
,
model
):
"""
this function retrains the exported model, and test it to output test accuracy
"""
raise
NotImplementedError
nni/retiarii/graph.py
View file @
4784cc6c
...
...
@@ -594,10 +594,10 @@ class Edge:
Example forward code snippet:
```
a, b, c = split(x)
p = concat(a, c)
q = sum(b, p)
z = relu(q)
a, b, c = split(x)
p = concat(a, c)
q = sum(b, p)
z = relu(q)
```
Edges in above snippet:
...
...
nni/retiarii/integration.py
View file @
4784cc6c
import
logging
import
os
from
typing
import
Any
,
Callable
import
json_tricks
import
nni
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.protocol
import
CommandType
,
send
from
nni.utils
import
MetricType
from
.graph
import
MetricData
from
.execution.base
import
BaseExecutionEngine
from
.execution.cgo_engine
import
CGOExecutionEngine
from
.execution.api
import
set_execution_engine
from
.integration_api
import
register_advisor
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -55,6 +59,15 @@ class RetiariiAdvisor(MsgDispatcherBase):
self
.
parameters_count
=
0
engine
=
self
.
_create_execution_engine
()
set_execution_engine
(
engine
)
def
_create_execution_engine
(
self
):
if
os
.
environ
.
get
(
'CGO'
)
==
'true'
:
return
CGOExecutionEngine
()
else
:
return
BaseExecutionEngine
()
def
handle_initialize
(
self
,
data
):
"""callback for initializing the advisor
Parameters
...
...
@@ -126,34 +139,3 @@ class RetiariiAdvisor(MsgDispatcherBase):
else
:
return
value
return
value
_advisor
:
RetiariiAdvisor
=
None
def
get_advisor
()
->
RetiariiAdvisor
:
global
_advisor
assert
_advisor
is
not
None
return
_advisor
def
register_advisor
(
advisor
:
RetiariiAdvisor
):
global
_advisor
assert
_advisor
is
None
_advisor
=
advisor
def
send_trial
(
parameters
:
dict
)
->
int
:
"""
Send a new trial. Executed on tuner end.
Return a ID that is the unique identifier for this trial.
"""
return
get_advisor
().
send_trial
(
parameters
)
def
receive_trial_parameters
()
->
dict
:
"""
Received a new trial. Executed on trial end.
"""
params
=
nni
.
get_next_parameter
()
return
params
nni/retiarii/integration_api.py
0 → 100644
View file @
4784cc6c
from
typing
import
NewType
,
Any
import
nni
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import
RetiariiAdvisor
=
NewType
(
'RetiariiAdvisor'
,
Any
)
_advisor
:
'RetiariiAdvisor'
=
None
def
get_advisor
()
->
'RetiariiAdvisor'
:
global
_advisor
assert
_advisor
is
not
None
return
_advisor
def
register_advisor
(
advisor
:
'RetiariiAdvisor'
):
global
_advisor
assert
_advisor
is
None
_advisor
=
advisor
def
send_trial
(
parameters
:
dict
)
->
int
:
"""
Send a new trial. Executed on tuner end.
Return a ID that is the unique identifier for this trial.
"""
return
get_advisor
().
send_trial
(
parameters
)
def
receive_trial_parameters
()
->
dict
:
"""
Received a new trial. Executed on trial end.
"""
params
=
nni
.
get_next_parameter
()
return
params
nni/retiarii/mutator.py
View file @
4784cc6c
...
...
@@ -28,8 +28,10 @@ class Mutator:
"""
Mutates graphs in model to generate new model.
`Mutator` class will be used in two places:
1. Inherit `Mutator` to implement graph mutation logic.
2. Use `Mutator` subclass to implement NAS strategy.
1. Inherit `Mutator` to implement graph mutation logic.
2. Use `Mutator` subclass to implement NAS strategy.
In scenario 1, the subclass should implement `Mutator.mutate()` interface with `Mutator.choice()`.
In scenario 2, strategy should use constructor or `Mutator.bind_sampler()` to initialize subclass,
and then use `Mutator.apply()` to mutate model.
...
...
@@ -104,6 +106,7 @@ class _RecorderSampler(Sampler):
self
.
recorded_candidates
.
append
(
candidates
)
return
candidates
[
0
]
# the following is for inline mutation
...
...
@@ -122,14 +125,16 @@ class LayerChoiceMutator(Mutator):
class
InputChoiceMutator
(
Mutator
):
def
__init__
(
self
,
node_name
:
str
,
n_c
hosen
:
int
):
def
__init__
(
self
,
node_name
:
str
,
n_c
andidates
:
int
,
n_chosen
:
int
,
reduction
:
str
):
super
().
__init__
()
self
.
node_name
=
node_name
self
.
n_candidates
=
n_candidates
self
.
n_chosen
=
n_chosen
self
.
reduction
=
reduction
def
mutate
(
self
,
model
):
target
=
model
.
get_node_by_name
(
self
.
node_name
)
candidates
=
[
i
for
i
in
range
(
self
.
n_c
hosen
)]
chosen
=
self
.
choice
(
candidates
)
candidates
=
[
i
for
i
in
range
(
self
.
n_c
andidates
)]
chosen
=
[
self
.
choice
(
candidates
)
for
_
in
range
(
self
.
n_chosen
)]
target
.
update_operation
(
'__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs'
,
{
'chosen'
:
chosen
})
{
'chosen'
:
chosen
,
'reduction'
:
self
.
reduction
})
nni/retiarii/nn/pytorch/nn.py
View file @
4784cc6c
import
inspect
import
logging
from
typing
import
Any
,
List
import
torch
import
torch.nn
as
nn
from
...utils
import
add_record
from
...utils
import
add_record
,
blackbox_module
,
uid
,
version_larger_equal
_logger
=
logging
.
getLogger
(
__name__
)
# NOTE: support pytorch version >= 1.5.0
__all__
=
[
'LayerChoice'
,
'InputChoice'
,
'Placeholder'
,
'Module'
,
'Sequential'
,
'ModuleList'
,
# TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
...
...
@@ -29,18 +30,24 @@ __all__ = [
'ConstantPad3d'
,
'Bilinear'
,
'CosineSimilarity'
,
'Unfold'
,
'Fold'
,
'AdaptiveLogSoftmaxWithLoss'
,
'TransformerEncoder'
,
'TransformerDecoder'
,
'TransformerEncoderLayer'
,
'TransformerDecoderLayer'
,
'Transformer'
,
#'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
#'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
#'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
'Flatten'
,
'Hardsigmoid'
,
'Hardswish'
'Flatten'
,
'Hardsigmoid'
]
if
version_larger_equal
(
torch
.
__version__
,
'1.6.0'
):
__all__
.
append
(
'Hardswish'
)
if
version_larger_equal
(
torch
.
__version__
,
'1.7.0'
):
__all__
.
extend
([
'Unflatten'
,
'SiLU'
,
'TripletMarginWithDistanceLoss'
])
class
LayerChoice
(
nn
.
Module
):
def
__init__
(
self
,
op_candidates
,
reduction
=
None
,
return_mask
=
False
,
key
=
None
):
super
(
LayerChoice
,
self
).
__init__
()
self
.
candidate_ops
=
op_candidates
self
.
label
=
key
self
.
op_candidates
=
op_candidates
self
.
label
=
key
if
key
is
not
None
else
f
'layerchoice_
{
uid
()
}
'
self
.
key
=
self
.
label
# deprecated, for backward compatibility
for
i
,
module
in
enumerate
(
op_candidates
):
# deprecated, for backward compatibility
self
.
add_module
(
str
(
i
),
module
)
if
reduction
or
return_mask
:
_logger
.
warning
(
'input arguments `reduction` and `return_mask` are deprecated!'
)
...
...
@@ -52,10 +59,12 @@ class InputChoice(nn.Module):
def
__init__
(
self
,
n_candidates
=
None
,
choose_from
=
None
,
n_chosen
=
1
,
reduction
=
"sum"
,
return_mask
=
False
,
key
=
None
):
super
(
InputChoice
,
self
).
__init__
()
self
.
n_candidates
=
n_candidates
self
.
n_chosen
=
n_chosen
self
.
reduction
=
reduction
self
.
label
=
key
if
n_candidates
or
choose_from
or
return_mask
:
self
.
label
=
key
if
key
is
not
None
else
f
'inputchoice_
{
uid
()
}
'
self
.
key
=
self
.
label
# deprecated, for backward compatibility
if
choose_from
or
return_mask
:
_logger
.
warning
(
'input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!'
)
def
forward
(
self
,
candidate_inputs
:
List
[
torch
.
Tensor
])
->
torch
.
Tensor
:
...
...
@@ -86,20 +95,37 @@ class Placeholder(nn.Module):
class
ChosenInputs
(
nn
.
Module
):
def
__init__
(
self
,
chosen
:
int
):
"""
"""
def
__init__
(
self
,
chosen
:
List
[
int
],
reduction
:
str
):
super
().
__init__
()
self
.
chosen
=
chosen
self
.
reduction
=
reduction
def
forward
(
self
,
candidate_inputs
):
# TODO: support multiple chosen inputs
return
candidate_inputs
[
self
.
chosen
]
return
self
.
_tensor_reduction
(
self
.
reduction
,
[
candidate_inputs
[
i
]
for
i
in
self
.
chosen
])
def
_tensor_reduction
(
self
,
reduction_type
,
tensor_list
):
if
reduction_type
==
"none"
:
return
tensor_list
if
not
tensor_list
:
return
None
# empty. return None for now
if
len
(
tensor_list
)
==
1
:
return
tensor_list
[
0
]
if
reduction_type
==
"sum"
:
return
sum
(
tensor_list
)
if
reduction_type
==
"mean"
:
return
sum
(
tensor_list
)
/
len
(
tensor_list
)
if
reduction_type
==
"concat"
:
return
torch
.
cat
(
tensor_list
,
dim
=
1
)
raise
ValueError
(
"Unrecognized reduction policy:
\"
{}
\"
"
.
format
(
reduction_type
))
# the following are pytorch modules
class
Module
(
nn
.
Module
):
def
__init__
(
self
):
super
(
Module
,
self
).
__init__
()
Module
=
nn
.
Module
class
Sequential
(
nn
.
Sequential
):
...
...
@@ -114,139 +140,116 @@ class ModuleList(nn.ModuleList):
super
(
ModuleList
,
self
).
__init__
(
*
args
)
def wrap_module(original_class):
    """Patch ``original_class`` so every construction records its arguments.

    The class' ``__init__`` is replaced in place with a wrapper that maps the
    call's positional/keyword arguments onto parameter names, logs them via
    ``add_record``, and then delegates to the original initializer.
    Returns the same (mutated) class object.
    """
    # Capture these before patching so the wrapper can delegate without recursion.
    wrapped_init = original_class.__init__
    param_names = list(inspect.signature(original_class).parameters.keys())

    def recording_init(self, *args, **kwargs):
        # Normalize the call into a name -> value mapping; positional values
        # take their names from the signature and override duplicate keywords.
        call_args = dict(kwargs)
        for position, value in enumerate(args):
            call_args[param_names[position]] = value
        add_record(id(self), call_args)
        wrapped_init(self, *args, **kwargs)

    # Install the recording wrapper as the class initializer.
    original_class.__init__ = recording_init
    return original_class
# TODO: support different versions of pytorch
# Re-export every supported ``torch.nn`` layer wrapped by ``wrap_module`` so
# that constructor arguments are recorded at instantiation time. Names mirror
# ``torch.nn``; commented entries exist only in newer torch releases.
Identity = wrap_module(nn.Identity)
Linear = wrap_module(nn.Linear)
Conv1d = wrap_module(nn.Conv1d)
Conv2d = wrap_module(nn.Conv2d)
Conv3d = wrap_module(nn.Conv3d)
ConvTranspose1d = wrap_module(nn.ConvTranspose1d)
ConvTranspose2d = wrap_module(nn.ConvTranspose2d)
ConvTranspose3d = wrap_module(nn.ConvTranspose3d)
Threshold = wrap_module(nn.Threshold)
ReLU = wrap_module(nn.ReLU)
Hardtanh = wrap_module(nn.Hardtanh)
ReLU6 = wrap_module(nn.ReLU6)
Sigmoid = wrap_module(nn.Sigmoid)
Tanh = wrap_module(nn.Tanh)
Softmax = wrap_module(nn.Softmax)
Softmax2d = wrap_module(nn.Softmax2d)
LogSoftmax = wrap_module(nn.LogSoftmax)
ELU = wrap_module(nn.ELU)
SELU = wrap_module(nn.SELU)
CELU = wrap_module(nn.CELU)
GLU = wrap_module(nn.GLU)
GELU = wrap_module(nn.GELU)
Hardshrink = wrap_module(nn.Hardshrink)
LeakyReLU = wrap_module(nn.LeakyReLU)
LogSigmoid = wrap_module(nn.LogSigmoid)
Softplus = wrap_module(nn.Softplus)
Softshrink = wrap_module(nn.Softshrink)
MultiheadAttention = wrap_module(nn.MultiheadAttention)
PReLU = wrap_module(nn.PReLU)
Softsign = wrap_module(nn.Softsign)
Softmin = wrap_module(nn.Softmin)
Tanhshrink = wrap_module(nn.Tanhshrink)
RReLU = wrap_module(nn.RReLU)
AvgPool1d = wrap_module(nn.AvgPool1d)
AvgPool2d = wrap_module(nn.AvgPool2d)
AvgPool3d = wrap_module(nn.AvgPool3d)
MaxPool1d = wrap_module(nn.MaxPool1d)
MaxPool2d = wrap_module(nn.MaxPool2d)
MaxPool3d = wrap_module(nn.MaxPool3d)
MaxUnpool1d = wrap_module(nn.MaxUnpool1d)
MaxUnpool2d = wrap_module(nn.MaxUnpool2d)
MaxUnpool3d = wrap_module(nn.MaxUnpool3d)
FractionalMaxPool2d = wrap_module(nn.FractionalMaxPool2d)
FractionalMaxPool3d = wrap_module(nn.FractionalMaxPool3d)
LPPool1d = wrap_module(nn.LPPool1d)
LPPool2d = wrap_module(nn.LPPool2d)
LocalResponseNorm = wrap_module(nn.LocalResponseNorm)
BatchNorm1d = wrap_module(nn.BatchNorm1d)
BatchNorm2d = wrap_module(nn.BatchNorm2d)
BatchNorm3d = wrap_module(nn.BatchNorm3d)
InstanceNorm1d = wrap_module(nn.InstanceNorm1d)
InstanceNorm2d = wrap_module(nn.InstanceNorm2d)
InstanceNorm3d = wrap_module(nn.InstanceNorm3d)
LayerNorm = wrap_module(nn.LayerNorm)
GroupNorm = wrap_module(nn.GroupNorm)
SyncBatchNorm = wrap_module(nn.SyncBatchNorm)
Dropout = wrap_module(nn.Dropout)
Dropout2d = wrap_module(nn.Dropout2d)
Dropout3d = wrap_module(nn.Dropout3d)
AlphaDropout = wrap_module(nn.AlphaDropout)
FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout)
ReflectionPad1d = wrap_module(nn.ReflectionPad1d)
ReflectionPad2d = wrap_module(nn.ReflectionPad2d)
ReplicationPad2d = wrap_module(nn.ReplicationPad2d)
ReplicationPad1d = wrap_module(nn.ReplicationPad1d)
ReplicationPad3d = wrap_module(nn.ReplicationPad3d)
CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d)
Embedding = wrap_module(nn.Embedding)
EmbeddingBag = wrap_module(nn.EmbeddingBag)
RNNBase = wrap_module(nn.RNNBase)
RNN = wrap_module(nn.RNN)
LSTM = wrap_module(nn.LSTM)
GRU = wrap_module(nn.GRU)
RNNCellBase = wrap_module(nn.RNNCellBase)
RNNCell = wrap_module(nn.RNNCell)
LSTMCell = wrap_module(nn.LSTMCell)
GRUCell = wrap_module(nn.GRUCell)
PixelShuffle = wrap_module(nn.PixelShuffle)
Upsample = wrap_module(nn.Upsample)
UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d)
PairwiseDistance = wrap_module(nn.PairwiseDistance)
AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = wrap_module(nn.TripletMarginLoss)
ZeroPad2d = wrap_module(nn.ZeroPad2d)
ConstantPad1d = wrap_module(nn.ConstantPad1d)
ConstantPad2d = wrap_module(nn.ConstantPad2d)
ConstantPad3d = wrap_module(nn.ConstantPad3d)
Bilinear = wrap_module(nn.Bilinear)
CosineSimilarity = wrap_module(nn.CosineSimilarity)
Unfold = wrap_module(nn.Unfold)
Fold = wrap_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = wrap_module(nn.TransformerEncoder)
TransformerDecoder = wrap_module(nn.TransformerDecoder)
TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
Transformer = wrap_module(nn.Transformer)
#LazyLinear = wrap_module(nn.LazyLinear)
#LazyConv1d = wrap_module(nn.LazyConv1d)
#LazyConv2d = wrap_module(nn.LazyConv2d)
#LazyConv3d = wrap_module(nn.LazyConv3d)
#LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d)
#LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d)
#LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d)
Flatten = wrap_module(nn.Flatten)
#Unflatten = wrap_module(nn.Unflatten)
Hardsigmoid = wrap_module(nn.Hardsigmoid)
Hardswish = wrap_module(nn.Hardswish)
#SiLU = wrap_module(nn.SiLU)
#TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
#ChannelShuffle = wrap_module(nn.ChannelShuffle)
# Re-export every supported ``torch.nn`` layer wrapped by ``blackbox_module``
# (defined elsewhere in this package — presumably captures init arguments for
# graph conversion; confirm against its definition). Names mirror ``torch.nn``.
# Layers introduced in newer torch releases are gated on the installed version
# at the bottom of this list.
Identity = blackbox_module(nn.Identity)
Linear = blackbox_module(nn.Linear)
Conv1d = blackbox_module(nn.Conv1d)
Conv2d = blackbox_module(nn.Conv2d)
Conv3d = blackbox_module(nn.Conv3d)
ConvTranspose1d = blackbox_module(nn.ConvTranspose1d)
ConvTranspose2d = blackbox_module(nn.ConvTranspose2d)
ConvTranspose3d = blackbox_module(nn.ConvTranspose3d)
Threshold = blackbox_module(nn.Threshold)
ReLU = blackbox_module(nn.ReLU)
Hardtanh = blackbox_module(nn.Hardtanh)
ReLU6 = blackbox_module(nn.ReLU6)
Sigmoid = blackbox_module(nn.Sigmoid)
Tanh = blackbox_module(nn.Tanh)
Softmax = blackbox_module(nn.Softmax)
Softmax2d = blackbox_module(nn.Softmax2d)
LogSoftmax = blackbox_module(nn.LogSoftmax)
ELU = blackbox_module(nn.ELU)
SELU = blackbox_module(nn.SELU)
CELU = blackbox_module(nn.CELU)
GLU = blackbox_module(nn.GLU)
GELU = blackbox_module(nn.GELU)
Hardshrink = blackbox_module(nn.Hardshrink)
LeakyReLU = blackbox_module(nn.LeakyReLU)
LogSigmoid = blackbox_module(nn.LogSigmoid)
Softplus = blackbox_module(nn.Softplus)
Softshrink = blackbox_module(nn.Softshrink)
MultiheadAttention = blackbox_module(nn.MultiheadAttention)
PReLU = blackbox_module(nn.PReLU)
Softsign = blackbox_module(nn.Softsign)
Softmin = blackbox_module(nn.Softmin)
Tanhshrink = blackbox_module(nn.Tanhshrink)
RReLU = blackbox_module(nn.RReLU)
AvgPool1d = blackbox_module(nn.AvgPool1d)
AvgPool2d = blackbox_module(nn.AvgPool2d)
AvgPool3d = blackbox_module(nn.AvgPool3d)
MaxPool1d = blackbox_module(nn.MaxPool1d)
MaxPool2d = blackbox_module(nn.MaxPool2d)
MaxPool3d = blackbox_module(nn.MaxPool3d)
MaxUnpool1d = blackbox_module(nn.MaxUnpool1d)
MaxUnpool2d = blackbox_module(nn.MaxUnpool2d)
MaxUnpool3d = blackbox_module(nn.MaxUnpool3d)
FractionalMaxPool2d = blackbox_module(nn.FractionalMaxPool2d)
FractionalMaxPool3d = blackbox_module(nn.FractionalMaxPool3d)
LPPool1d = blackbox_module(nn.LPPool1d)
LPPool2d = blackbox_module(nn.LPPool2d)
LocalResponseNorm = blackbox_module(nn.LocalResponseNorm)
BatchNorm1d = blackbox_module(nn.BatchNorm1d)
BatchNorm2d = blackbox_module(nn.BatchNorm2d)
BatchNorm3d = blackbox_module(nn.BatchNorm3d)
InstanceNorm1d = blackbox_module(nn.InstanceNorm1d)
InstanceNorm2d = blackbox_module(nn.InstanceNorm2d)
InstanceNorm3d = blackbox_module(nn.InstanceNorm3d)
LayerNorm = blackbox_module(nn.LayerNorm)
GroupNorm = blackbox_module(nn.GroupNorm)
SyncBatchNorm = blackbox_module(nn.SyncBatchNorm)
Dropout = blackbox_module(nn.Dropout)
Dropout2d = blackbox_module(nn.Dropout2d)
Dropout3d = blackbox_module(nn.Dropout3d)
AlphaDropout = blackbox_module(nn.AlphaDropout)
FeatureAlphaDropout = blackbox_module(nn.FeatureAlphaDropout)
ReflectionPad1d = blackbox_module(nn.ReflectionPad1d)
ReflectionPad2d = blackbox_module(nn.ReflectionPad2d)
ReplicationPad2d = blackbox_module(nn.ReplicationPad2d)
ReplicationPad1d = blackbox_module(nn.ReplicationPad1d)
ReplicationPad3d = blackbox_module(nn.ReplicationPad3d)
CrossMapLRN2d = blackbox_module(nn.CrossMapLRN2d)
Embedding = blackbox_module(nn.Embedding)
EmbeddingBag = blackbox_module(nn.EmbeddingBag)
RNNBase = blackbox_module(nn.RNNBase)
RNN = blackbox_module(nn.RNN)
LSTM = blackbox_module(nn.LSTM)
GRU = blackbox_module(nn.GRU)
RNNCellBase = blackbox_module(nn.RNNCellBase)
RNNCell = blackbox_module(nn.RNNCell)
LSTMCell = blackbox_module(nn.LSTMCell)
GRUCell = blackbox_module(nn.GRUCell)
PixelShuffle = blackbox_module(nn.PixelShuffle)
Upsample = blackbox_module(nn.Upsample)
UpsamplingNearest2d = blackbox_module(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = blackbox_module(nn.UpsamplingBilinear2d)
PairwiseDistance = blackbox_module(nn.PairwiseDistance)
AdaptiveMaxPool1d = blackbox_module(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = blackbox_module(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = blackbox_module(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = blackbox_module(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = blackbox_module(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = blackbox_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = blackbox_module(nn.TripletMarginLoss)
ZeroPad2d = blackbox_module(nn.ZeroPad2d)
ConstantPad1d = blackbox_module(nn.ConstantPad1d)
ConstantPad2d = blackbox_module(nn.ConstantPad2d)
ConstantPad3d = blackbox_module(nn.ConstantPad3d)
Bilinear = blackbox_module(nn.Bilinear)
CosineSimilarity = blackbox_module(nn.CosineSimilarity)
Unfold = blackbox_module(nn.Unfold)
Fold = blackbox_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = blackbox_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = blackbox_module(nn.TransformerEncoder)
TransformerDecoder = blackbox_module(nn.TransformerDecoder)
TransformerEncoderLayer = blackbox_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = blackbox_module(nn.TransformerDecoderLayer)
Transformer = blackbox_module(nn.Transformer)
Flatten = blackbox_module(nn.Flatten)
Hardsigmoid = blackbox_module(nn.Hardsigmoid)

# nn.Hardswish first appeared in torch 1.6.0.
if version_larger_equal(torch.__version__, '1.6.0'):
    Hardswish = blackbox_module(nn.Hardswish)

# These layers first appeared in torch 1.7.0.
if version_larger_equal(torch.__version__, '1.7.0'):
    SiLU = blackbox_module(nn.SiLU)
    Unflatten = blackbox_module(nn.Unflatten)
    TripletMarginWithDistanceLoss = blackbox_module(nn.TripletMarginWithDistanceLoss)
nni/retiarii/operation.py
View file @
4784cc6c
...
...
@@ -121,6 +121,8 @@ class PyTorchOperation(Operation):
return
f
'
{
output
}
=
{
value
}
'
elif
self
.
type
==
'prim::ListConstruct'
:
return
f
'
{
output
}
= [
{
", "
.
join
(
inputs
)
}
]'
elif
self
.
type
==
'prim::GetAttr'
:
return
f
"
{
output
}
=
{
self
.
parameters
[
'input'
]
}
.
{
self
.
parameters
[
'name'
]
}
"
elif
self
.
type
==
'aten::mean'
:
return
f
'
{
output
}
= torch.mean(
{
inputs
[
0
]
}
,
{
", "
.
join
(
inputs
[
1
:
-
1
])
}
, out=
{
inputs
[
-
1
]
}
)'
elif
self
.
type
==
'aten::__getitem__'
:
...
...
@@ -133,8 +135,7 @@ class PyTorchOperation(Operation):
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
= torch.cat(
{
inputs
[
0
]
}
, dim=
{
inputs
[
1
]
}
)'
elif
self
.
type
==
'aten::add'
:
assert
len
(
inputs
)
==
2
return
f
'
{
output
}
=
{
inputs
[
0
]
}
+
{
inputs
[
1
]
}
'
return
f
'
{
output
}
= '
+
' + '
.
join
(
inputs
)
elif
self
.
type
==
OpTypeName
.
MergedSlice
:
assert
(
len
(
inputs
)
-
1
)
%
4
==
0
slices
=
[]
...
...
@@ -151,6 +152,8 @@ class PyTorchOperation(Operation):
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.view(
{
inputs
[
1
]
}
)'
elif
self
.
type
==
'aten::slice'
:
raise
RuntimeError
(
'not supposed to have aten::slice operation'
)
elif
self
.
type
==
'aten::Bool'
:
return
f
'
{
output
}
= bool(
{
inputs
[
0
]
}
)'
else
:
raise
RuntimeError
(
f
'unsupported operation type:
{
self
.
type
}
?
{
self
.
_to_class_name
()
}
'
)
...
...
nni/retiarii/strategies/__init__.py
View file @
4784cc6c
from
.tpe_strategy
import
TPEStrategy
from
.random_strategy
import
RandomStrategy
nni/retiarii/strategies/random_strategy.py
0 → 100644
View file @
4784cc6c
import
logging
import
random
import
time
from
..
import
Sampler
,
submit_models
,
query_available_resources
from
.strategy
import
BaseStrategy
_logger
=
logging
.
getLogger
(
__name__
)
class RandomSampler(Sampler):
    """Sampler that resolves every mutation decision uniformly at random."""

    def choice(self, candidates, mutator, model, index):
        # mutator/model/index belong to the Sampler interface but have no
        # bearing on a uniform random draw.
        picked = random.choice(candidates)
        return picked
class RandomStrategy(BaseStrategy):
    """Search strategy that explores the search space by random sampling.

    ``run`` loops forever: whenever a training resource is free, it applies
    every mutator to the base model with uniformly random choices and submits
    the mutated model for evaluation; otherwise it backs off briefly.
    """

    def __init__(self):
        self.random_sampler = RandomSampler()

    def run(self, base_model, applied_mutators):
        # Fixed typo in the log message (was 'stargety start...').
        _logger.info('strategy start...')
        while True:
            avail_resource = query_available_resources()
            if avail_resource > 0:
                model = base_model
                _logger.info('apply mutators...')
                _logger.info('mutators: %s', str(applied_mutators))
                for mutator in applied_mutators:
                    mutator.bind_sampler(self.random_sampler)
                    model = mutator.apply(model)
                # run models
                submit_models(model)
            else:
                # No free resource: sleep before polling again to avoid busy-waiting.
                time.sleep(2)
nni/retiarii/strategies/tpe_strategy.py
View file @
4784cc6c
import
logging
import
time
from
nni.algorithms.hpo.hyperopt_tuner
import
HyperoptTuner
from
..
import
Sampler
,
submit_models
,
wait_models
from
..
import
Sampler
,
submit_models
,
query_available_resources
,
is_stopped_exec
from
.strategy
import
BaseStrategy
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -39,6 +40,7 @@ class TPEStrategy(BaseStrategy):
def __init__(self):
    # TPE sampler (backed by nni's Hyperopt tuner) that proposes choices.
    self.tpe_sampler = TPESampler()
    # Monotonically increasing id assigned to each submitted model
    # (incremented in run() — see usage below this hunk).
    self.model_id = 0
    # model id -> submitted model still running; entries are removed once
    # the execution is observed to have stopped.
    self.running_models = {}
def
run
(
self
,
base_model
,
applied_mutators
):
sample_space
=
[]
...
...
@@ -48,9 +50,10 @@ class TPEStrategy(BaseStrategy):
sample_space
.
extend
(
recorded_candidates
)
self
.
tpe_sampler
.
update_sample_space
(
sample_space
)
try
:
_logger
.
info
(
'stargety start...'
)
while
True
:
_logger
.
info
(
'stargety start...'
)
while
True
:
avail_resource
=
query_available_resources
()
if
avail_resource
>
0
:
model
=
base_model
_logger
.
info
(
'apply mutators...'
)
_logger
.
info
(
'mutators: %s'
,
str
(
applied_mutators
))
...
...
@@ -61,9 +64,18 @@ class TPEStrategy(BaseStrategy):
model
=
mutator
.
apply
(
model
)
# run models
submit_models
(
model
)
wait_models
(
model
)
self
.
tpe_sampler
.
receive_result
(
self
.
model_id
,
model
.
metric
)
self
.
running_models
[
self
.
model_id
]
=
model
self
.
model_id
+=
1
_logger
.
info
(
'Strategy says: %s'
,
model
.
metric
)
except
Exception
:
_logger
.
error
(
logging
.
exception
(
'message'
))
else
:
time
.
sleep
(
2
)
_logger
.
warning
(
'num of running models: %d'
,
len
(
self
.
running_models
))
to_be_deleted
=
[]
for
_id
,
_model
in
self
.
running_models
.
items
():
if
is_stopped_exec
(
_model
):
if
_model
.
metric
is
not
None
:
self
.
tpe_sampler
.
receive_result
(
_id
,
_model
.
metric
)
_logger
.
warning
(
'tpe receive results: %d, %s'
,
_id
,
_model
.
metric
)
to_be_deleted
.
append
(
_id
)
for
_id
in
to_be_deleted
:
del
self
.
running_models
[
_id
]
Prev
1
…
7
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment