OpenDAS / nni / Commits / 192a807b

Unverified commit 192a807b, authored Dec 14, 2020 by QuanluZhang, committed by GitHub on Dec 14, 2020

[Retiarii] refactor based on the new launch approach (#3185)

parent 80394047

Showing 20 changed files with 920 additions and 164 deletions (+920 -164)
Changed files:

  nni/experiment/pipe.py                      +1    -1
  nni/retiarii/__init__.py                    +1    -0
  nni/retiarii/codegen/pytorch.py             +1    -6
  nni/retiarii/converter/graph_gen.py         +62   -20
  nni/retiarii/converter/op_types.py          +21   -17
  nni/retiarii/execution/api.py               +0    -7
  nni/retiarii/execution/base.py              +0    -5
  nni/retiarii/experiment.py                  +192  -0
  nni/retiarii/graph.py                       +24   -2
  nni/retiarii/mutator.py                     +29   -1
  nni/retiarii/nn/pytorch/nn.py               +172  -60
  nni/retiarii/operation.py                   +2    -2
  nni/retiarii/strategies/strategy.py         +1    -1
  nni/retiarii/strategies/tpe_strategy.py     +1    -4
  nni/retiarii/trainer/interface.py           +0    -18
  nni/retiarii/trainer/pytorch/base.py        +2    -4
  nni/retiarii/utils.py                       +77   -16
  test/retiarii_test/darts/darts_model.py     +168  -0
  test/retiarii_test/darts/ops.py             +133  -0
  test/retiarii_test/darts/test.py            +33   -0
nni/experiment/pipe.py

@@ -52,7 +52,7 @@ else:
     def connect(self) -> BufferedIOBase:
         conn, _ = self._socket.accept()
-        self.file = conn.makefile('w+b')
+        self.file = conn.makefile('rwb')
         return self.file

     def close(self) -> None:

nni/retiarii/__init__.py

@@ -2,3 +2,4 @@ from .operation import Operation
 from .graph import *
 from .execution import *
 from .mutator import *
+from .utils import register_module
\ No newline at end of file

nni/retiarii/codegen/pytorch.py

@@ -15,7 +15,6 @@ def model_to_pytorch_script(model: Model, placement = None) -> str:
         import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement=placement)
         graphs.append(graph_code)
         total_pkgs.update(import_pkgs)
-    # FIXME: set correct PATH for the packages (after launch refactor)
     pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
     return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()

@@ -71,7 +70,7 @@ def _remove_prefix(names, graph_name):
     return names[len(graph_name):] if names.startswith(graph_name) else names

 def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> str:
-    nodes = graph.topo_sort()  # FIXME: topological sort is needed here
+    nodes = graph.topo_sort()
     # handle module node and function node differently
     # only need to generate code for module here

@@ -130,10 +129,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim

-# FIXME: remove these two lines
-import sys
-sys.path.append("test/convert_test")
-
 {}

 {}

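As a side note on the retained pkgs_code line: it simply turns the collected package names into a block of import statements for the generated script. A minimal sketch of that behavior, with made-up package names:

    total_pkgs = {'torch', 'ops'}  # hypothetical packages collected from the graphs
    pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
    # pkgs_code is now e.g. "import torch\nimport ops" (set iteration order is not guaranteed)
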
nni/retiarii/converter/graph_gen.py

 import json_tricks
+import logging
 import re
 import torch

@@ -6,9 +7,10 @@ from ..graph import Graph, Node, Edge, Model
 from ..operation import Cell, Operation
 from ..nn.pytorch import Placeholder, LayerChoice, InputChoice
-from .op_types import MODULE_EXCEPT_LIST, Type
+from .op_types import MODULE_EXCEPT_LIST, OpTypeName, BasicOpsPT
 from .utils import build_full_name, _convert_name

+_logger = logging.getLogger(__name__)
 global_seq = 0
 global_graph_id = 0

@@ -80,7 +82,7 @@ def create_prim_constant_node(ir_graph, node, module_name):
     if node.outputsAt(0).toIValue() is not None:
         attrs = {'value': node.outputsAt(0).toIValue()}
     global_seq += 1
-    new_node = ir_graph.add_node(build_full_name(module_name, Type.Constant, global_seq),
+    new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Constant, global_seq),
                                  node.kind(), attrs)
     return new_node

@@ -163,6 +165,33 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
     # key: tensor (%out.1), value: node (this node)
     output_remap = {}

+    def handle_if_condition(cond_tensor):
+        """
+        to calculate the condition, we only deal with the following op types by tracing back
+        `prim::GetAttr`, `aten::__getitem__`, `prim::Constant`, `aten::eq`
+
+        generate the expression using recursive calls
+
+        NOTE: do not support dynamic graph
+        """
+        def _generate_expr(tensor):
+            if tensor.node().kind() == 'prim::GetAttr':
+                return f'({getattr(module, tensor.node().s("name"))})'
+            elif tensor.node().kind() == 'aten::__getitem__':
+                t = _generate_expr(tensor.node().inputsAt(0))
+                idx = _generate_expr(tensor.node().inputsAt(1))
+                return f'({t}[{idx}])'
+            elif tensor.node().kind() == 'prim::Constant':
+                return f'{tensor.toIValue()}'
+            elif tensor.node().kind() == 'aten::eq':
+                left = _generate_expr(tensor.node().inputsAt(0))
+                right = _generate_expr(tensor.node().inputsAt(1))
+                return f'({left} == {right})'
+            else:
+                raise RuntimeError(f'Unsupported op type {tensor.node().kind()} in if condition')
+        expr = _generate_expr(cond_tensor)
+        return eval(expr)
+
     def handle_if_node(node):
         """
         Parameters

@@ -179,19 +208,13 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
         # will support constant expression in future
         inputs = [i for i in node.inputs()]
         assert len(inputs) == 1
-        if not inputs[0].node().kind() in ['prim::Constant', 'prim::GetAttr']:
-            raise RuntimeError('"if" whose condition is not constant or attribute has not been supported yet!')
-        chosen_block = None
-        if inputs[0].node().kind() == 'prim::Constant':
-            chosen_block = 0 if inputs[0].toIValue() else 1
-        if inputs[0].node().kind() == 'prim::GetAttr':
-            chosen_block = 0 if getattr(module, inputs[0].node().s('name')) else 1
+        cond = handle_if_condition(inputs[0])
+        chosen_block = 0 if cond else 1
         blocks = [block for block in node.blocks()]
         assert len(blocks) == 2
         last_block_node = None
         for node in blocks[chosen_block].nodes():
             last_block_node = handle_single_node(node)
+        assert last_block_node is not None
         return last_block_node

     def handle_single_node(node):

@@ -287,29 +310,33 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
             node_index[node] = new_node
         elif node.kind() == 'prim::ListConstruct':
             global_seq += 1
-            new_node = ir_graph.add_node(build_full_name(module_name, Type.ListConstruct, global_seq), node.kind())
+            new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.ListConstruct, global_seq), node.kind())
             node_index[node] = new_node
             _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap)
         elif node.kind() == 'aten::append':
             global_seq += 1
-            aten_node = ir_graph.add_node(build_full_name(module_name, Type.BasicOpsPT[node.kind()], global_seq), node.kind())
+            aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
             node_index[node] = aten_node
             _add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
             output_remap[node.inputsAt(0)] = node
         elif node.kind().startswith('aten::'):
             # handle aten::XXX
             global_seq += 1
-            aten_node = ir_graph.add_node(build_full_name(module_name, Type.BasicOpsPT[node.kind()], global_seq), node.kind())
+            aten_node = ir_graph.add_node(build_full_name(module_name, BasicOpsPT[node.kind()], global_seq), node.kind())
             node_index[node] = aten_node
             _add_edge(ir_graph, node, graph_inputs, node_index, aten_node, output_remap)
         elif node.kind() == 'prim::GetAttr':
             node_type, attrs = handle_prim_attr_node(node)
             global_seq += 1
-            new_node = ir_graph.add_node(build_full_name(module_name, Type.Attr, global_seq),
+            new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, global_seq),
                                          node_type, attrs)
             node_index[node] = new_node
+        elif node.kind() == 'prim::min':
+            print('zql: ', sm_graph)
+            exit(1)
         elif node.kind() == 'prim::If':
             last_block_node = handle_if_node(node)
+            # last_block_node is None means no node in the branch block
             node_index[node] = last_block_node
         elif node.kind() == 'prim::Loop':
             raise RuntimeError('Loop has not been supported yet!')

@@ -343,7 +370,10 @@ def merge_aten_slices(ir_graph):
     for head_node in head_slice_nodes:
         slot = 0
-        new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), Type.MergedSlice)
+        new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
+        if len(head_node.incoming_edges) == 4:
+            # when slice is for one dimension list, there are only 4 inputs, thus merge is not needed
+            break
         assert len(head_node.incoming_edges) == 5
         for edge in head_node.incoming_edges:
             edge.tail = new_slice_node

@@ -383,10 +413,13 @@ def _handle_layerchoice(module):
     m_attrs = {}
     candidates = module.candidate_ops
+    choices = []
     for i, cand in enumerate(candidates):
         assert id(cand) in modules_arg, 'id not exist: {}'.format(id(cand))
         assert isinstance(modules_arg[id(cand)], dict)
-        m_attrs[f'choice_{i}'] = modules_arg[id(cand)]
+        cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
+        choices.append({'type': cand_type, 'parameters': modules_arg[id(cand)]})
+    m_attrs[f'choices'] = choices
     m_attrs['label'] = module.label
     return m_attrs

@@ -425,17 +458,18 @@ def convert_module(script_module, module, module_name, ir_model):
     # NOTE: have not supported nested LayerChoice, i.e., a candidate module
     # also has LayerChoice or InputChoice or ValueChoice
     original_type_name = script_module.original_name
-    if original_type_name == Type.LayerChoice:
+    if original_type_name == OpTypeName.LayerChoice:
         m_attrs = _handle_layerchoice(module)
         return None, m_attrs
-    if original_type_name == Type.InputChoice:
+    if original_type_name == OpTypeName.InputChoice:
         m_attrs = _handle_inputchoice(module)
         return None, m_attrs
-    if original_type_name in Type.Placeholder:
+    if original_type_name == OpTypeName.Placeholder:
         m_attrs = modules_arg[id(module)]
         return None, m_attrs
     if original_type_name in torch.nn.__dict__ and original_type_name not in MODULE_EXCEPT_LIST:
         # this is a basic module from pytorch, no need to parse its graph
+        assert id(module) in modules_arg, f'{original_type_name} arguments are not recorded'
         m_attrs = modules_arg[id(module)]
         return None, m_attrs

@@ -463,7 +497,15 @@ def convert_module(script_module, module, module_name, ir_model):
     ir_graph._register()

-    return ir_graph, modules_arg[id(module)]
+    if id(module) not in modules_arg:
+        raise RuntimeError(f'{original_type_name} arguments are not recorded, \
+            you might have forgotten to decorate this class with @register_module()')
+    # TODO: if we parse this module, it means we will create a graph (module class)
+    # for this module. Then it is not necessary to record this module's arguments
+    # return ir_graph, modules_arg[id(module)].
+    # That is, we can refactor this part, to allow users to annotate which module
+    # should not be parsed further.
+    return ir_graph, {}

 def convert_to_graph(script_module, module, recorded_modules_arg):
     """

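To make the new handle_if_condition concrete: it walks back from the condition tensor, rewrites the supported op kinds (prim::GetAttr, aten::__getitem__, prim::Constant, aten::eq) into a Python expression string, and eval()s that string, so a branch that TorchScript kept symbolic is fixed to one block at conversion time. A hypothetical illustration of the strings it builds (the module attribute and its value are made up):

    # Suppose the scripted forward contains `if self.use_aux == True:` and the
    # instance being converted has module.use_aux == False.
    # prim::GetAttr is resolved to the live attribute value, prim::Constant to its
    # IValue, and aten::eq to an equality, giving roughly:
    expr = '((False) == True)'
    chosen_block = 0 if eval(expr) else 1   # -> 1, i.e. the else branch is converted
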
nni/retiarii/converter/op_types.py

+from enum import Enum
+
 MODULE_EXCEPT_LIST = ['Sequential']

-class Type:
-    """Node Type class
+class OpTypeName(str, Enum):
+    """
+    op type to its type name str
     """
     Attr = 'Attr'
     Constant = 'Constant'

@@ -11,11 +14,10 @@ class Type:
     InputChoice = 'InputChoice'
     ValueChoice = 'ValueChoice'
     Placeholder = 'Placeholder'
     MergedSlice = 'MergedSlice'

 # deal with aten op
 BasicOpsPT = {
     'aten::mean': 'Mean',
     'aten::relu': 'Relu',
     'aten::add': 'Add',

@@ -25,7 +27,9 @@ class Type:
     'aten::slice': 'Slice',
     'aten::cat': 'Cat',
     'aten::size': 'Size',
-    'aten::view': 'View'
-}
+    'aten::view': 'View',
+    'aten::eq': 'Eq',
+    'aten::add_': 'Add_'  # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
+}

 BasicOpsTF = {}
\ No newline at end of file

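Because OpTypeName now subclasses both str and Enum, its members compare equal to the plain strings that TorchScript reports, which is what lets convert_module match script_module.original_name directly against the enum members. A minimal, standalone sketch of that behavior (not the real class body):

    from enum import Enum

    class OpTypeName(str, Enum):
        LayerChoice = 'LayerChoice'
        Placeholder = 'Placeholder'

    original_type_name = 'LayerChoice'            # e.g. script_module.original_name
    assert original_type_name == OpTypeName.LayerChoice
    assert OpTypeName.Placeholder == 'Placeholder'
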
nni/retiarii/execution/api.py

@@ -36,13 +36,6 @@ def get_and_register_default_listener(engine: AbstractExecutionEngine) -> Defaul
         engine.register_graph_listener(_default_listener)
     return _default_listener

-def _get_search_space() -> 'Dict':
-    engine = get_execution_engine()
-    while True:
-        time.sleep(1)
-        if engine.get_search_space() is not None:
-            break
-    return engine.get_search_space()
-
 def submit_models(*models: Model) -> None:
     engine = get_execution_engine()

nni/retiarii/execution/base.py

@@ -50,10 +50,6 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         self._running_models: Dict[int, Model] = dict()

-    def get_search_space(self) -> 'JSON':
-        advisor = get_advisor()
-        return advisor.search_space
-
     def submit_models(self, *models: Model) -> None:
         for model in models:
             data = BaseGraphData(codegen.model_to_pytorch_script(model),

@@ -106,7 +102,6 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         Initialize the model, hand it over to trainer.
         """
         graph_data = BaseGraphData.load(receive_trial_parameters())
-        # FIXME: update this part to dump code to a correct path!!!
         with open('_generated_model.py', 'w') as f:
             f.write(graph_data.model_script)
         trainer_cls = utils.import_(graph_data.training_module)

nni/retiarii/experiment.py (new file, 0 → 100644)

import dataclasses
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from threading import Thread
from typing import Any, List, Optional

from ..experiment import Experiment, TrainingServiceConfig
from ..experiment import launcher, rest
from ..experiment.config.base import ConfigBase, PathLike
from ..experiment.config import util

from .utils import get_records
from .integration import RetiariiAdvisor
from .converter.graph_gen import convert_to_graph
from .mutator import LayerChoiceMutator, InputChoiceMutator

_logger = logging.getLogger(__name__)


@dataclass(init=False)
class RetiariiExeConfig(ConfigBase):
    experiment_name: Optional[str] = None
    search_space: Any = ''  # TODO: remove
    trial_command: str = 'python3 -m nni.retiarii.trial_entry'
    trial_code_directory: PathLike = '.'
    trial_concurrency: int
    trial_gpu_number: int = 0
    max_experiment_duration: Optional[str] = None
    max_trial_number: Optional[int] = None
    nni_manager_ip: Optional[str] = None
    debug: bool = False
    log_level: Optional[str] = None
    experiment_working_directory: Optional[PathLike] = None
    # remove configuration of tuner/assessor/advisor
    training_service: TrainingServiceConfig

    def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        if training_service_platform is not None:
            assert 'training_service' not in kwargs
            self.training_service = util.training_service_config_factory(training_service_platform)

    def validate(self, initialized_tuner: bool = False) -> None:
        super().validate()

    @property
    def _canonical_rules(self):
        return _canonical_rules

    @property
    def _validation_rules(self):
        return _validation_rules


_canonical_rules = {
    'trial_code_directory': util.canonical_path,
    'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
    'experiment_working_directory': util.canonical_path
}

_validation_rules = {
    'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not directory'),
    'trial_concurrency': lambda value: value > 0,
    'trial_gpu_number': lambda value: value >= 0,
    'max_experiment_duration': lambda value: util.parse_time(value) > 0,
    'max_trial_number': lambda value: value > 0,
    'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
    'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}


class RetiariiExperiment(Experiment):
    def __init__(self, base_model: 'nn.Module', trainer: 'BaseTrainer',
                 applied_mutators: List['Mutator'], strategy: 'BaseStrategy'):
        self.config: RetiariiExeConfig = None
        self.port: Optional[int] = None

        self.base_model = base_model
        self.trainer = trainer
        self.applied_mutators = applied_mutators
        self.strategy = strategy
        self.recorded_module_args = get_records()

        self._dispatcher = RetiariiAdvisor()
        self._proc: Optional[Popen] = None
        self._pipe: Optional[Pipe] = None

    def _process_inline_mutation(self, base_model):
        """
        the mutators are order independent
        """
        lc_nodes = base_model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.nn.LayerChoice')
        ic_nodes = base_model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.nn.InputChoice')
        if not lc_nodes and not ic_nodes:
            return None
        applied_mutators = []
        for node in lc_nodes:
            mutator = LayerChoiceMutator(node.name, node.operation.parameters['choices'])
            applied_mutators.append(mutator)
        for node in ic_nodes:
            mutator = InputChoiceMutator(node.name, node.operation.parameters['n_chosen'])
            applied_mutators.append(mutator)
        return applied_mutators

    def _start_strategy(self):
        import torch
        try:
            script_module = torch.jit.script(self.base_model)
        except Exception as e:
            _logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
            raise e
        base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args)
        assert id(self.trainer) in self.recorded_module_args
        trainer_config = self.recorded_module_args[id(self.trainer)]
        base_model.apply_trainer(trainer_config['modulename'], trainer_config['args'])

        # handle inline mutations
        mutators = self._process_inline_mutation(base_model)
        if mutators is not None and self.applied_mutators:
            raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, \
                do not use mutators when you use LayerChoice/InputChoice')
        if mutators is not None:
            self.applied_mutators = mutators

        _logger.info('Starting strategy...')
        Thread(target=self.strategy.run, args=(base_model, self.applied_mutators)).start()
        _logger.info('Strategy started!')

    def start(self, config: RetiariiExeConfig, port: int = 8080, debug: bool = False) -> None:
        """
        Start the experiment in background.
        This method will raise exception on failure.
        If it returns, the experiment should have been successfully started.

        Parameters
        ----------
        port
            The port of web UI.
        debug
            Whether to start in debug mode.
        """
        # FIXME:
        if debug:
            logging.getLogger('nni').setLevel(logging.DEBUG)

        self._proc, self._pipe = launcher.start_experiment(config, port, debug)
        assert self._proc is not None
        assert self._pipe is not None

        self.port = port  # port will be None if start up failed

        # dispatcher must be created after pipe initialized
        # the logic to launch dispatcher in background should be refactored into dispatcher api
        Thread(target=self._dispatcher.run).start()

        self._start_strategy()
        # TODO: register experiment management metadata

    def stop(self) -> None:
        """
        Stop background experiment.
        """
        self._proc.kill()
        self._pipe.close()
        self.port = None
        self._proc = None
        self._pipe = None

    def run(self, config: RetiariiExeConfig, port: int = 8080, debug: bool = False) -> str:
        """
        Run the experiment.
        This function will block until experiment finish or error.
        """
        self.config = config
        self.start(config, port, debug)
        try:
            while True:
                time.sleep(10)
                status = self.get_status()
                # TODO: double check the status
                if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
                    return status
        finally:
            self.stop()

    def get_status(self) -> str:
        if self.port is None:
            raise RuntimeError('Experiment is not running')
        resp = rest.get(self.port, '/check-status')
        return resp['status']
\ No newline at end of file

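Putting the new class to use, a launch script would look roughly like the sketch below. The trainer arguments and the 'local' training service are illustrative assumptions, not taken from this commit, while CNN and PyTorchImageClassificationTrainer refer to classes that do appear elsewhere in it:

    from nni.retiarii.experiment import RetiariiExperiment, RetiariiExeConfig
    from nni.retiarii.strategies.tpe_strategy import TPEStrategy

    base_model = CNN(32, 3, 16, 10, 8)                        # defined with nni.retiarii.nn.pytorch
    trainer = PyTorchImageClassificationTrainer(base_model)   # recorded via @register_trainer()
    strategy = TPEStrategy()

    exp = RetiariiExperiment(base_model, trainer, applied_mutators=[], strategy=strategy)

    config = RetiariiExeConfig(training_service_platform='local')
    config.experiment_name = 'darts_search'
    config.trial_concurrency = 2
    config.max_trial_number = 10

    exp.run(config, port=8081)   # blocks until the experiment finishes or errors
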
nni/retiarii/graph.py

@@ -169,10 +169,29 @@ class Model:
             matched_nodes.extend(nodes)
         return matched_nodes

-    def get_by_name(self, name: str) -> Union['Graph', 'Node']:
+    def get_nodes_by_type(self, type_name: str) -> List['Node']:
         """
-        Find the graph or node that have the given name space name.
+        Traverse all the nodes to find the matched node(s) with the given type.
         """
+        matched_nodes = []
+        for graph in self.graphs.values():
+            nodes = graph.get_nodes_by_type(type_name)
+            matched_nodes.extend(nodes)
+        return matched_nodes
+
+    def get_node_by_name(self, node_name: str) -> 'Node':
+        """
+        Traverse all the nodes to find the matched node with the given name.
+        """
+        matched_nodes = []
+        for graph in self.graphs.values():
+            nodes = graph.get_nodes_by_name(node_name)
+            matched_nodes.extend(nodes)
+        assert len(matched_nodes) <= 1
+        if matched_nodes:
+            return matched_nodes[0]
+        else:
+            return None

 class ModelStatus(Enum):

@@ -326,6 +345,9 @@ class Graph:
     def get_nodes_by_label(self, label: str) -> List['Node']:
         return [node for node in self.hidden_nodes if node.label == label]

+    def get_nodes_by_name(self, name: str) -> List['Node']:
+        return [node for node in self.hidden_nodes if node.name == name]
+
     def topo_sort(self) -> List['Node']:
         node_to_fanin = {}
         curr_nodes = []

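A quick sketch of how the two new Model helpers are meant to be used; the node name in the second call is made up, and the type string is the one the converter records for inline LayerChoice nodes:

    # every inline LayerChoice node in the converted model
    lc_nodes = model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.nn.LayerChoice')
    # a single node looked up by its full name, or None if it does not exist
    node = model.get_node_by_name('cells__0__layerchoice_1')   # hypothetical name
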
nni/retiarii/mutator.py

@@ -102,3 +102,31 @@ class _RecorderSampler(Sampler):
     def choice(self, candidates: List[Choice], *args) -> Choice:
         self.recorded_candidates.append(candidates)
         return candidates[0]
+
+
+# the following is for inline mutation
+
+class LayerChoiceMutator(Mutator):
+    def __init__(self, node_name: str, candidates: List):
+        super().__init__()
+        self.node_name = node_name
+        self.candidates = candidates
+
+    def mutate(self, model):
+        target = model.get_node_by_name(self.node_name)
+        indexes = [i for i in range(len(self.candidates))]
+        chosen_index = self.choice(indexes)
+        chosen_cand = self.candidates[chosen_index]
+        target.update_operation(chosen_cand['type'], chosen_cand['parameters'])
+
+
+class InputChoiceMutator(Mutator):
+    def __init__(self, node_name: str, n_chosen: int):
+        super().__init__()
+        self.node_name = node_name
+        self.n_chosen = n_chosen
+
+    def mutate(self, model):
+        target = model.get_node_by_name(self.node_name)
+        candidates = [i for i in range(self.n_chosen)]
+        chosen = self.choice(candidates)
+        target.update_operation('__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs', {'chosen': chosen})

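End to end, each inline LayerChoice node found by _process_inline_mutation (in experiment.py above) yields one LayerChoiceMutator whose mutate() rewrites that node's operation to the sampled candidate. A hedged sketch of the data it consumes; the candidate dicts mirror what _handle_layerchoice in graph_gen.py records, but the concrete values and node name here are made up:

    candidates = [
        {'type': '__torch__.ops.PoolBN', 'parameters': {'pool_type': 'max', 'C': 16}},
        {'type': '__torch__.ops.SepConv', 'parameters': {'C_in': 16, 'C_out': 16}},
    ]
    mutator = LayerChoiceMutator('normal_n2__op_0', candidates)   # hypothetical node name
    # Inside Mutator.apply(model), self.choice(indexes) asks the bound sampler for an
    # index, then the chosen candidate replaces the LayerChoice operation:
    #     target.update_operation(chosen_cand['type'], chosen_cand['parameters'])
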
nni/retiarii/nn/pytorch/nn.py

@@ -4,47 +4,58 @@ import torch
 import torch.nn as nn
 from typing import (Any, Tuple, List, Optional)

+from ...utils import add_record
+
 _logger = logging.getLogger(__name__)
-_logger.setLevel(logging.INFO)
-
-_records = None
-
-def enable_record_args():
-    global _records
-    _records = {}
-    _logger.info('args recording enabled')
-
-def disable_record_args():
-    global _records
-    _records = None
-    _logger.info('args recording disabled')
-
-def get_records():
-    global _records
-    return _records
-
-def add_record(name, value):
-    global _records
-    if _records is not None:
-        assert name not in _records, '{} already in _records'.format(name)
-        _records[name] = value
+
+__all__ = [
+    'LayerChoice', 'InputChoice', 'Placeholder',
+    'Module', 'Sequential', 'ModuleList',  # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
+    'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
+    'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6',
+    'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink',
+    'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin',
+    'Tanhshrink', 'RReLU', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
+    'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
+    'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
+    'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
+    'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
+    'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
+    'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
+    'LSTMCell', 'GRUCell', 'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
+    'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
+    'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
+    'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
+    'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
+    'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
+    # 'LazyLinear', 'LazyConv1d', 'LazyConv2d', 'LazyConv3d',
+    # 'LazyConvTranspose1d', 'LazyConvTranspose2d', 'LazyConvTranspose3d',
+    # 'Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss', 'ChannelShuffle',
+    'Flatten', 'Hardsigmoid', 'Hardswish'
+]

 class LayerChoice(nn.Module):
-    def __init__(self, candidate_ops: List, label: str = None):
+    def __init__(self, op_candidates, reduction=None, return_mask=False, key=None):
         super(LayerChoice, self).__init__()
-        self.candidate_ops = candidate_ops
-        self.label = label
+        self.candidate_ops = op_candidates
+        self.label = key
+        if reduction or return_mask:
+            _logger.warning('input arguments `reduction` and `return_mask` are deprecated!')

     def forward(self, x):
         return x

 class InputChoice(nn.Module):
-    def __init__(self, n_chosen: int = 1, reduction: str = 'sum', label: str = None):
+    def __init__(self, n_candidates=None, choose_from=None, n_chosen=1,
+                 reduction="sum", return_mask=False, key=None):
         super(InputChoice, self).__init__()
         self.n_chosen = n_chosen
         self.reduction = reduction
-        self.label = label
+        self.label = key
+        if n_candidates or choose_from or return_mask:
+            _logger.warning('input arguments `n_candidates`, `choose_from` and `return_mask` are deprecated!')

     def forward(self, candidate_inputs: List['Tensor']) -> 'Tensor':
         # fake return

@@ -62,9 +73,7 @@ class ValueChoice:
 class Placeholder(nn.Module):
     def __init__(self, label, related_info):
-        global _records
-        if _records is not None:
-            _records[id(self)] = related_info
+        add_record(id(self), related_info)
         self.label = label
         self.related_info = related_info
         super(Placeholder, self).__init__()

@@ -72,34 +81,29 @@ class Placeholder(nn.Module):
     def forward(self, x):
         return x

+class ChosenInputs(nn.Module):
+    def __init__(self, chosen: int):
+        super().__init__()
+        self.chosen = chosen
+
+    def forward(self, candidate_inputs):
+        # TODO: support multiple chosen inputs
+        return candidate_inputs[self.chosen]
+
+# the following are pytorch modules

 class Module(nn.Module):
-    def __init__(self, *args, **kwargs):
-        global _records
-        if _records is not None:
-            assert not kwargs
-            argname_list = list(inspect.signature(self.__class__).parameters.keys())
-            assert len(argname_list) == len(args), 'Error: {} not put input arguments in its super().__init__ function'.format(self.__class__)
-            full_args = {}
-            for i, arg_value in enumerate(args):
-                full_args[argname_list[i]] = args[i]
-            _records[id(self)] = full_args
-        #print('my module: ', id(self), args, kwargs)
+    def __init__(self):
+        # TODO: users have to pass init's arguments to super init's arguments
         super(Module, self).__init__()

 class Sequential(nn.Sequential):
     def __init__(self, *args):
-        global _records
-        if _records is not None:
-            _records[id(self)] = {}
+        add_record(id(self), {})  # no args need to be recorded
         super(Sequential, self).__init__(*args)

 class ModuleList(nn.ModuleList):
     def __init__(self, *args):
-        global _records
-        if _records is not None:
-            _records[id(self)] = {}
+        add_record(id(self), {})  # no args need to be recorded
         super(ModuleList, self).__init__(*args)

 def wrap_module(original_class):

@@ -108,24 +112,132 @@ def wrap_module(original_class):
     # Make copy of original __init__, so we can call it without recursion
     def __init__(self, *args, **kws):
-        global _records
-        if _records is not None:
-            full_args = {}
-            full_args.update(kws)
-            for i, arg in enumerate(args):
-                full_args[argname_list[i]] = args[i]
-            _records[id(self)] = full_args
+        full_args = {}
+        full_args.update(kws)
+        for i, arg in enumerate(args):
+            full_args[argname_list[i]] = args[i]
+        add_record(id(self), full_args)
         orig_init(self, *args, **kws)  # Call the original __init__

     original_class.__init__ = __init__  # Set the class' __init__ to the new one
     return original_class

+# TODO: support different versions of pytorch
+Identity = wrap_module(nn.Identity)
+Linear = wrap_module(nn.Linear)
+Conv1d = wrap_module(nn.Conv1d)
 Conv2d = wrap_module(nn.Conv2d)
-BatchNorm2d = wrap_module(nn.BatchNorm2d)
+Conv3d = wrap_module(nn.Conv3d)
+ConvTranspose1d = wrap_module(nn.ConvTranspose1d)
+ConvTranspose2d = wrap_module(nn.ConvTranspose2d)
+ConvTranspose3d = wrap_module(nn.ConvTranspose3d)
+Threshold = wrap_module(nn.Threshold)
 ReLU = wrap_module(nn.ReLU)
-Dropout = wrap_module(nn.Dropout)
+Hardtanh = wrap_module(nn.Hardtanh)
-Linear = wrap_module(nn.Linear)
+ReLU6 = wrap_module(nn.ReLU6)
-MaxPool2d = wrap_module(nn.MaxPool2d)
+Sigmoid = wrap_module(nn.Sigmoid)
+Tanh = wrap_module(nn.Tanh)
+Softmax = wrap_module(nn.Softmax)
+Softmax2d = wrap_module(nn.Softmax2d)
+LogSoftmax = wrap_module(nn.LogSoftmax)
+ELU = wrap_module(nn.ELU)
+SELU = wrap_module(nn.SELU)
+CELU = wrap_module(nn.CELU)
+GLU = wrap_module(nn.GLU)
+GELU = wrap_module(nn.GELU)
+Hardshrink = wrap_module(nn.Hardshrink)
+LeakyReLU = wrap_module(nn.LeakyReLU)
+LogSigmoid = wrap_module(nn.LogSigmoid)
+Softplus = wrap_module(nn.Softplus)
+Softshrink = wrap_module(nn.Softshrink)
+MultiheadAttention = wrap_module(nn.MultiheadAttention)
+PReLU = wrap_module(nn.PReLU)
+Softsign = wrap_module(nn.Softsign)
+Softmin = wrap_module(nn.Softmin)
+Tanhshrink = wrap_module(nn.Tanhshrink)
+RReLU = wrap_module(nn.RReLU)
+AvgPool1d = wrap_module(nn.AvgPool1d)
 AvgPool2d = wrap_module(nn.AvgPool2d)
-Identity = wrap_module(nn.Identity)
+AvgPool3d = wrap_module(nn.AvgPool3d)
+MaxPool1d = wrap_module(nn.MaxPool1d)
+MaxPool2d = wrap_module(nn.MaxPool2d)
+MaxPool3d = wrap_module(nn.MaxPool3d)
+MaxUnpool1d = wrap_module(nn.MaxUnpool1d)
+MaxUnpool2d = wrap_module(nn.MaxUnpool2d)
+MaxUnpool3d = wrap_module(nn.MaxUnpool3d)
+FractionalMaxPool2d = wrap_module(nn.FractionalMaxPool2d)
+FractionalMaxPool3d = wrap_module(nn.FractionalMaxPool3d)
+LPPool1d = wrap_module(nn.LPPool1d)
+LPPool2d = wrap_module(nn.LPPool2d)
+LocalResponseNorm = wrap_module(nn.LocalResponseNorm)
+BatchNorm1d = wrap_module(nn.BatchNorm1d)
+BatchNorm2d = wrap_module(nn.BatchNorm2d)
+BatchNorm3d = wrap_module(nn.BatchNorm3d)
+InstanceNorm1d = wrap_module(nn.InstanceNorm1d)
+InstanceNorm2d = wrap_module(nn.InstanceNorm2d)
+InstanceNorm3d = wrap_module(nn.InstanceNorm3d)
+LayerNorm = wrap_module(nn.LayerNorm)
+GroupNorm = wrap_module(nn.GroupNorm)
+SyncBatchNorm = wrap_module(nn.SyncBatchNorm)
+Dropout = wrap_module(nn.Dropout)
+Dropout2d = wrap_module(nn.Dropout2d)
+Dropout3d = wrap_module(nn.Dropout3d)
+AlphaDropout = wrap_module(nn.AlphaDropout)
+FeatureAlphaDropout = wrap_module(nn.FeatureAlphaDropout)
+ReflectionPad1d = wrap_module(nn.ReflectionPad1d)
+ReflectionPad2d = wrap_module(nn.ReflectionPad2d)
+ReplicationPad2d = wrap_module(nn.ReplicationPad2d)
+ReplicationPad1d = wrap_module(nn.ReplicationPad1d)
+ReplicationPad3d = wrap_module(nn.ReplicationPad3d)
+CrossMapLRN2d = wrap_module(nn.CrossMapLRN2d)
+Embedding = wrap_module(nn.Embedding)
+EmbeddingBag = wrap_module(nn.EmbeddingBag)
+RNNBase = wrap_module(nn.RNNBase)
+RNN = wrap_module(nn.RNN)
+LSTM = wrap_module(nn.LSTM)
+GRU = wrap_module(nn.GRU)
+RNNCellBase = wrap_module(nn.RNNCellBase)
+RNNCell = wrap_module(nn.RNNCell)
+LSTMCell = wrap_module(nn.LSTMCell)
+GRUCell = wrap_module(nn.GRUCell)
+PixelShuffle = wrap_module(nn.PixelShuffle)
+Upsample = wrap_module(nn.Upsample)
+UpsamplingNearest2d = wrap_module(nn.UpsamplingNearest2d)
+UpsamplingBilinear2d = wrap_module(nn.UpsamplingBilinear2d)
+PairwiseDistance = wrap_module(nn.PairwiseDistance)
+AdaptiveMaxPool1d = wrap_module(nn.AdaptiveMaxPool1d)
+AdaptiveMaxPool2d = wrap_module(nn.AdaptiveMaxPool2d)
+AdaptiveMaxPool3d = wrap_module(nn.AdaptiveMaxPool3d)
+AdaptiveAvgPool1d = wrap_module(nn.AdaptiveAvgPool1d)
 AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
+AdaptiveAvgPool3d = wrap_module(nn.AdaptiveAvgPool3d)
+TripletMarginLoss = wrap_module(nn.TripletMarginLoss)
+ZeroPad2d = wrap_module(nn.ZeroPad2d)
+ConstantPad1d = wrap_module(nn.ConstantPad1d)
+ConstantPad2d = wrap_module(nn.ConstantPad2d)
+ConstantPad3d = wrap_module(nn.ConstantPad3d)
+Bilinear = wrap_module(nn.Bilinear)
+CosineSimilarity = wrap_module(nn.CosineSimilarity)
+Unfold = wrap_module(nn.Unfold)
+Fold = wrap_module(nn.Fold)
+AdaptiveLogSoftmaxWithLoss = wrap_module(nn.AdaptiveLogSoftmaxWithLoss)
+TransformerEncoder = wrap_module(nn.TransformerEncoder)
+TransformerDecoder = wrap_module(nn.TransformerDecoder)
+TransformerEncoderLayer = wrap_module(nn.TransformerEncoderLayer)
+TransformerDecoderLayer = wrap_module(nn.TransformerDecoderLayer)
+Transformer = wrap_module(nn.Transformer)
+#LazyLinear = wrap_module(nn.LazyLinear)
+#LazyConv1d = wrap_module(nn.LazyConv1d)
+#LazyConv2d = wrap_module(nn.LazyConv2d)
+#LazyConv3d = wrap_module(nn.LazyConv3d)
+#LazyConvTranspose1d = wrap_module(nn.LazyConvTranspose1d)
+#LazyConvTranspose2d = wrap_module(nn.LazyConvTranspose2d)
+#LazyConvTranspose3d = wrap_module(nn.LazyConvTranspose3d)
+Flatten = wrap_module(nn.Flatten)
+#Unflatten = wrap_module(nn.Unflatten)
+Hardsigmoid = wrap_module(nn.Hardsigmoid)
+Hardswish = wrap_module(nn.Hardswish)
+#SiLU = wrap_module(nn.SiLU)
+#TripletMarginWithDistanceLoss = wrap_module(nn.TripletMarginWithDistanceLoss)
+#ChannelShuffle = wrap_module(nn.ChannelShuffle)

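With the reworked signatures, retiarii's LayerChoice/InputChoice accept the same keyword style as the classic NNI NAS mutables, while only the candidates, n_chosen and key (stored as label) are actually kept. A brief usage sketch under those assumptions; Block is a made-up module, and @register_module() comes from the new utils.py:

    import nni.retiarii.nn.pytorch as nn
    from nni.retiarii import register_module

    @register_module()
    class Block(nn.Module):
        def __init__(self):
            super().__init__()
            # candidate modules are recorded for the converter; `key` becomes the node label
            self.op = nn.LayerChoice([nn.Conv2d(16, 16, 3, padding=1),
                                      nn.MaxPool2d(3, stride=1, padding=1)], key='block_op')
            self.skip = nn.InputChoice(n_chosen=1, key='block_skip')
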
nni/retiarii/operation.py

@@ -105,7 +105,7 @@ class PyTorchOperation(Operation):
         return None

     def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
-        from .converter.op_types import Type
+        from .converter.op_types import OpTypeName
         if self._to_class_name() is not None:
             return f'{output} = self.{field}({", ".join(inputs)})'
         elif self.type.startswith('Function.'):

@@ -133,7 +133,7 @@ class PyTorchOperation(Operation):
         elif self.type == 'aten::add':
             assert len(inputs) == 2
             return f'{output} = {inputs[0]} + {inputs[1]}'
-        elif self.type == Type.MergedSlice:
+        elif self.type == OpTypeName.MergedSlice:
             assert (len(inputs) - 1) % 4 == 0
             slices = []
             dim = int((len(inputs) - 1) / 4)

nni/retiarii/strategies/strategy.py

@@ -4,5 +4,5 @@ from typing import List

 class BaseStrategy(abc.ABC):

     @abc.abstractmethod
-    def run(self, base_model: 'Model', applied_mutators: List['Mutator'], trainer: 'BaseTrainer') -> None:
+    def run(self, base_model: 'Model', applied_mutators: List['Mutator']) -> None:
         pass

nni/retiarii/strategies/tpe_strategy.py

@@ -42,7 +42,7 @@ class TPEStrategy(BaseStrategy):
         self.tpe_sampler = TPESampler()
         self.model_id = 0

-    def run(self, base_model, applied_mutators, trainer):
+    def run(self, base_model, applied_mutators):
         sample_space = []
         new_model = base_model
         for mutator in applied_mutators:

@@ -61,9 +61,6 @@ class TPEStrategy(BaseStrategy):
                 _logger.info('mutate model...')
                 mutator.bind_sampler(self.tpe_sampler)
                 model = mutator.apply(model)
-            # get and apply training approach
-            _logger.info('apply training approach...')
-            model.apply_trainer(trainer['modulename'], trainer['args'])
             # run models
             submit_models(model)
             wait_models(model)

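Under the new contract, a strategy only receives the base model and the mutators; the trainer is already attached to the model via apply_trainer before the strategy thread starts. A minimal custom strategy honoring that signature might look like the sketch below (random sampling and the import paths are assumptions based on this commit's file layout, not part of the commit):

    import random

    from nni.retiarii.mutator import Sampler
    from nni.retiarii.strategies.strategy import BaseStrategy
    from nni.retiarii.execution import submit_models, wait_models

    class _RandomSampler(Sampler):
        def choice(self, candidates, *args):
            return random.choice(candidates)

    class RandomStrategy(BaseStrategy):
        def run(self, base_model, applied_mutators):
            sampler = _RandomSampler()
            for _ in range(10):                    # fixed trial budget, purely illustrative
                model = base_model
                for mutator in applied_mutators:
                    mutator.bind_sampler(sampler)
                    model = mutator.apply(model)
                submit_models(model)
                wait_models(model)
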
nni/retiarii/trainer/interface.py

 import abc
 import inspect
-from ..nn.pytorch import add_record
 from typing import *

@@ -19,23 +18,6 @@ class BaseTrainer(abc.ABC):
     Trainer has a ``fit`` function with no return value. Intermediate results and final results should be
     directly sent via ``nni.report_intermediate_result()`` and ``nni.report_final_result()`` functions.
     """
-    def __init__(self, *args, **kwargs):
-        module = self.__class__.__module__
-        if module is None or module == str.__class__.__module__:
-            full_class_name = self.__class__.__name__
-        else:
-            full_class_name = module + '.' + self.__class__.__name__
-        assert not kwargs
-        argname_list = list(inspect.signature(self.__class__).parameters.keys())
-        assert len(argname_list) == len(args), 'Error: {} not put input arguments in its super().__init__ function'.format(self.__class__)
-        full_args = {}
-        for i, arg_value in enumerate(args):
-            if argname_list[i] == 'model':
-                assert i == 0
-                continue
-            full_args[argname_list[i]] = args[i]
-        add_record(id(self), {'modulename': full_class_name, 'args': full_args})
-
     @abc.abstractmethod
     def fit(self) -> None:

nni/retiarii/trainer/pytorch/base.py

@@ -10,6 +10,7 @@ from torchvision import datasets, transforms
 import nni

 from ..interface import BaseTrainer
+from ...utils import register_trainer

 def get_default_transform(dataset: str) -> Any:

@@ -41,7 +42,7 @@ def get_default_transform(dataset: str) -> Any:
     # unsupported dataset, return None
     return None

+@register_trainer()
 class PyTorchImageClassificationTrainer(BaseTrainer):
     """
     Image classification trainer for PyTorch.

@@ -78,9 +79,6 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
             Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
             only the key ``max_epochs`` is useful.
         """
-        super(PyTorchImageClassificationTrainer, self).__init__(model, dataset_cls, dataset_kwargs, dataloader_kwargs,
-                                                                optimizer_cls, optimizer_kwargs, trainer_kwargs)
         self._use_cuda = torch.cuda.is_available()
         self.model = model
         if self._use_cuda:

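The argument recording that used to live in BaseTrainer.__init__ (deleted in interface.py above) now happens in the @register_trainer() class decorator, so concrete trainers no longer have to forward every argument to super().__init__. A toy demonstration of the mechanics with a made-up trainer class (import paths follow this commit's layout):

    import nni.retiarii.nn.pytorch as nn
    from nni.retiarii.utils import register_trainer, get_records

    class Net(nn.Module):
        def __init__(self):
            super().__init__()

    @register_trainer()
    class MyTrainer:
        def __init__(self, model, learning_rate=0.1):
            self.model = model
            self.learning_rate = learning_rate

    t = MyTrainer(Net(), learning_rate=0.05)
    # the model argument is skipped, keyword arguments are recorded:
    print(get_records()[id(t)])
    # -> {'modulename': '__main__.MyTrainer', 'args': {'learning_rate': 0.05}}
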
nni/retiarii/utils.py

-import traceback
+import inspect

-from .nn.pytorch import enable_record_args, get_records, disable_record_args

 def import_(target: str, allow_none: bool = False) -> 'Any':
     if target is None:

@@ -8,17 +7,79 @@ def import_(target: str, allow_none: bool = False) -> 'Any':
     module = __import__(path, globals(), locals(), [identifier])
     return getattr(module, identifier)

-class TraceClassArguments:
-    def __init__(self):
-        self.recorded_arguments = None
-
-    def __enter__(self):
-        enable_record_args()
-        return self
-
-    def __exit__(self, exc_type, exc_value, tb):
-        if exc_type is not None:
-            traceback.print_exception(exc_type, exc_value, tb)
-            # return False # uncomment to pass exception through
-        self.recorded_arguments = get_records()
-        disable_record_args()
+_records = {}
+
+def get_records():
+    global _records
+    return _records
+
+def add_record(key, value):
+    """
+    """
+    global _records
+    if _records is not None:
+        assert key not in _records, '{} already in _records'.format(key)
+        _records[key] = value
+
+def _register_module(original_class):
+    orig_init = original_class.__init__
+    argname_list = list(inspect.signature(original_class).parameters.keys())
+    # Make copy of original __init__, so we can call it without recursion
+
+    def __init__(self, *args, **kws):
+        full_args = {}
+        full_args.update(kws)
+        for i, arg in enumerate(args):
+            full_args[argname_list[i]] = args[i]
+        add_record(id(self), full_args)
+        orig_init(self, *args, **kws)  # Call the original __init__
+
+    original_class.__init__ = __init__  # Set the class' __init__ to the new one
+    return original_class
+
+def register_module():
+    """
+    Register a module.
+    """
+    # use it as a decorator: @register_module()
+    def _register(cls):
+        m = _register_module(original_class=cls)
+        return m
+    return _register
+
+def _register_trainer(original_class):
+    orig_init = original_class.__init__
+    argname_list = list(inspect.signature(original_class).parameters.keys())
+    # Make copy of original __init__, so we can call it without recursion
+    full_class_name = original_class.__module__ + '.' + original_class.__name__
+
+    def __init__(self, *args, **kws):
+        full_args = {}
+        full_args.update(kws)
+        for i, arg in enumerate(args):
+            # TODO: support both pytorch and tensorflow
+            from .nn.pytorch import Module
+            if isinstance(args[i], Module):
+                # ignore the base model object
+                continue
+            full_args[argname_list[i]] = args[i]
+        add_record(id(self), {'modulename': full_class_name, 'args': full_args})
+        orig_init(self, *args, **kws)  # Call the original __init__
+
+    original_class.__init__ = __init__  # Set the class' __init__ to the new one
+    return original_class
+
+def register_trainer():
+    def _register(cls):
+        m = _register_trainer(original_class=cls)
+        return m
+    return _register
\ No newline at end of file

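register_module() works the same way for model classes: it swaps in an __init__ that records the constructor arguments under the instance's id, which convert_module later looks up. A minimal sketch mirroring how the DARTS test code below uses it (MyCell and its arguments are made up):

    import nni.retiarii.nn.pytorch as nn
    from nni.retiarii import register_module
    from nni.retiarii.utils import get_records

    @register_module()
    class MyCell(nn.Module):
        def __init__(self, channels, stride=1):
            super().__init__()
            self.conv = nn.Conv2d(channels, channels, 3, stride=stride, padding=1)

    cell = MyCell(16, stride=2)
    print(get_records()[id(cell)])   # contains both channels=16 and stride=2
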
test/retiarii_test/darts/darts_model.py
0 → 100644
View file @
192a807b
from collections import OrderedDict
from typing import (List, Optional)

import torch
import torch.nn as torch_nn

#sys.path.append(str(Path(__file__).resolve().parents[2]))
import ops
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module


class AuxiliaryHead(nn.Module):
    """ Auxiliary head in 2/3 place of network to let the gradient flow well """

    def __init__(self, input_size, C, n_classes):
        """ assuming input size 7x7 or 8x8 """
        assert input_size in [7, 8]
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False),  # 2x2 out
            nn.Conv2d(C, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, kernel_size=2, bias=False),  # 1x1 out
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.linear = nn.Linear(768, n_classes)

    def forward(self, x):
        out = self.net(x)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)
        return logits


@register_module()
class Node(nn.Module):
    def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
        super().__init__()
        self.ops = nn.ModuleList()
        choice_keys = []
        for i in range(num_prev_nodes):
            stride = 2 if i < num_downsample_connect else 1
            choice_keys.append("{}_p{}".format(node_id, i))
            self.ops.append(
                nn.LayerChoice([
                    ops.PoolBN('max', channels, 3, stride, 1, affine=False),
                    ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
                    nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),
                    ops.SepConv(channels, channels, 3, stride, 1, affine=False),
                    ops.SepConv(channels, channels, 5, stride, 2, affine=False),
                    ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
                    ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
                ]))
        self.drop_path = ops.DropPath()
        self.input_switch = nn.InputChoice(n_chosen=2)

    def forward(self, prev_nodes: List['Tensor']) -> 'Tensor':
        #assert self.ops.__len__() == len(prev_nodes)
        #out = [op(node) for op, node in zip(self.ops, prev_nodes)]
        out = []
        for i, op in enumerate(self.ops):
            out.append(op(prev_nodes[i]))
        #out = [self.drop_path(o) if o is not None else None for o in out]
        return self.input_switch(out)


@register_module()
class Cell(nn.Module):
    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes

        # If previous cell is reduction cell, current input size does not match with
        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
        if reduction_p:
            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        self.mutable_ops = nn.ModuleList()
        for depth in range(2, self.n_nodes + 2):
            self.mutable_ops.append(
                Node("{}_n{}".format("reduce" if reduction else "normal", depth),
                     depth, channels, 2 if reduction else 0))

    def forward(self, s0, s1):
        # s0, s1 are the outputs of previous previous cell and previous cell, respectively.
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        new_tensors = []
        for node in self.mutable_ops:
            tmp = tensors + new_tensors
            cur_tensor = node(tmp)
            new_tensors.append(cur_tensor)
        output = torch.cat(new_tensors, dim=1)
        return output


@register_module()
class CNN(nn.Module):
    def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
                 stem_multiplier=3, auxiliary=False):
        super().__init__()
        self.in_channels = in_channels
        self.channels = channels
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.aux_pos = 2 * n_layers // 3 if auxiliary else -1

        c_cur = stem_multiplier * self.channels
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c_cur)
        )

        # for the first cell, stem is used for both s0 and s1
        # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size.
        channels_pp, channels_p, c_cur = c_cur, c_cur, channels

        self.cells = nn.ModuleList()
        reduction_p, reduction = False, False
        for i in range(n_layers):
            reduction_p, reduction = reduction, False
            # Reduce featuremap size and double channels in 1/3 and 2/3 layer.
            if i in [n_layers // 3, 2 * n_layers // 3]:
                c_cur *= 2
                reduction = True

            cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
            self.cells.append(cell)
            c_cur_out = c_cur * n_nodes
            channels_pp, channels_p = channels_p, c_cur_out

            #if i == self.aux_pos:
            #    self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes)

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(channels_p, n_classes)

    def forward(self, x):
        s0 = s1 = self.stem(x)

        #aux_logits = None
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1)
            #if i == self.aux_pos and self.training:
            #    aux_logits = self.aux_head(s1)

        out = self.gap(s1)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)
        #if aux_logits is not None:
        #    return logits, aux_logits
        return logits

    def drop_path_prob(self, p):
        for module in self.modules():
            if isinstance(module, ops.DropPath):
                module.p = p


if __name__ == '__main__':
    base_model = CNN(32, 3, 16, 10, 8)
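The channel bookkeeping in CNN.__init__ above is easy to misread, so the following standalone sketch (plain Python, no NNI required; the defaults mirror the CNN(32, 3, 16, 10, 8) call in the __main__ block) traces how the per-node and per-cell channel counts evolve:

    # Standalone sketch of the channel arithmetic in CNN.__init__ (defaults mirror CNN(32, 3, 16, 10, 8)).
    def trace_channels(channels=16, n_layers=8, n_nodes=4, stem_multiplier=3):
        c_cur = stem_multiplier * channels
        channels_pp, channels_p, c_cur = c_cur, c_cur, channels
        for i in range(n_layers):
            if i in [n_layers // 3, 2 * n_layers // 3]:
                c_cur *= 2                    # reduction cells double the per-node channels
            c_cur_out = c_cur * n_nodes       # a cell concatenates its n_nodes node outputs
            channels_pp, channels_p = channels_p, c_cur_out
            print(f"cell {i}: node channels={c_cur}, cell output channels={c_cur_out}")
        return channels_p                     # this value feeds the final nn.Linear

    trace_channels()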
test/retiarii_test/darts/ops.py
0 → 100644
View file @
192a807b
import torch

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module


@register_module()
class DropPath(nn.Module):
    def __init__(self, p=0.):
        """
        Drop path with probability.

        Parameters
        ----------
        p : float
            Probability of a path being zeroed.
        """
        super(DropPath, self).__init__()
        self.p = p

    def forward(self, x):
        if self.training and self.p > 0.:
            keep_prob = 1. - self.p
            # per data point mask
            mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob)
            return x / keep_prob * mask

        return x


@register_module()
class PoolBN(nn.Module):
    """
    AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
    """
    def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
        super(PoolBN, self).__init__()
        if pool_type.lower() == 'max':
            self.pool = nn.MaxPool2d(kernel_size, stride, padding)
        elif pool_type.lower() == 'avg':
            self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
        else:
            raise ValueError()
        self.bn = nn.BatchNorm2d(C, affine=affine)

    def forward(self, x):
        out = self.pool(x)
        out = self.bn(out)
        return out


@register_module()
class StdConv(nn.Module):
    """
    Standard conv: ReLU - Conv - BN
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super(StdConv, self).__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.net(x)


@register_module()
class FacConv(nn.Module):
    """
    Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
    """
    def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
        super(FacConv, self).__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
            nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.net(x)


@register_module()
class DilConv(nn.Module):
    """
    (Dilated) depthwise separable conv.
    ReLU - (Dilated) depthwise separable - Pointwise - BN.
    If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
        super(DilConv, self).__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.net(x)


@register_module()
class SepConv(nn.Module):
    """
    Depthwise separable conv.
    DilConv(dilation=1) * 2.
    """
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super(SepConv, self).__init__()
        self.net = nn.Sequential(
            DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
            DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
        )

    def forward(self, x):
        return self.net(x)


@register_module()
class FactorizedReduce(nn.Module):
    """
    Reduce feature map size by factorized pointwise (stride=2).
    """
    def __init__(self, C_in, C_out, affine=True):
        super(FactorizedReduce, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        x = self.relu(x)
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out
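As a quick sanity check on the DropPath scaling defined above (a standalone sketch using plain torch, not part of this commit), dividing by keep_prob keeps the expected activation unchanged when whole samples are randomly zeroed:

    # Standalone sketch: the x / keep_prob * mask trick preserves the input's expectation.
    import torch

    p = 0.2
    keep_prob = 1. - p
    x = torch.ones(10000, 1, 1, 1)
    mask = torch.zeros((x.size(0), 1, 1, 1)).bernoulli_(keep_prob)
    dropped = x / keep_prob * mask
    print(x.mean().item(), dropped.mean().item())   # both should be close to 1.0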
test/retiarii_test/darts/test.py
0 → 100644
View file @
192a807b
import json
import os
import sys
import torch
from pathlib import Path

from nni.retiarii.experiment import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.strategies import TPEStrategy
from nni.retiarii.trainer import PyTorchImageClassificationTrainer

from darts_model import CNN

if __name__ == '__main__':
    base_model = CNN(32, 3, 16, 10, 8)
    trainer = PyTorchImageClassificationTrainer(
        base_model,
        dataset_cls="CIFAR10",
        dataset_kwargs={"root": "data/cifar10", "download": True},
        dataloader_kwargs={"batch_size": 32},
        optimizer_kwargs={"lr": 1e-3},
        trainer_kwargs={"max_epochs": 1})
    simple_strategy = TPEStrategy()

    exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)

    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'darts_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 10
    exp_config.trial_gpu_number = 1
    exp_config.training_service.use_active_gpu = True
    exp_config.training_service.gpu_indices = [1, 2]

    exp.run(exp_config, 8081, debug=True)
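The script above is the experiment entry point: running it directly (e.g. python test.py) launches the Retiarii experiment on the local machine, with the port argument (8081 here) used for the NNI web UI. For a quick smoke test without GPUs, a trimmed-down configuration might look like the following sketch (field names follow the ones used above; the values are illustrative and would replace the config block in the script):

    # Sketch of a CPU-only smoke-test configuration (illustrative values; GPU-related fields omitted).
    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'darts_search_smoke'
    exp_config.trial_concurrency = 1
    exp_config.max_trial_number = 2
    exp.run(exp_config, 8081)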