Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
6808708d
Unverified
Commit
6808708d
authored
Apr 09, 2021
by
Yuge Zhang
Committed by
GitHub
Apr 09, 2021
Browse files
[Retiarii] Nest `ValueChoice` in `LayerChoice` and dict/list in `ValueChoice` (#3508)
parent
b7062b5d
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
196 additions
and
45 deletions
+196
-45
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+11
-14
nni/retiarii/graph.py
nni/retiarii/graph.py
+7
-5
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+29
-2
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+31
-16
nni/retiarii/operation_def/torch_op_def.py
nni/retiarii/operation_def/torch_op_def.py
+3
-0
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+109
-2
test/ut/retiarii/test_strategy.py
test/ut/retiarii/test_strategy.py
+6
-6
No files found.
nni/retiarii/converter/graph_gen.py
View file @
6808708d
...
...
@@ -6,7 +6,7 @@ import re
import
torch
from
..graph
import
Graph
,
Model
,
Node
from
..nn.pytorch
import
InputChoice
,
LayerChoice
,
Placeholder
from
..nn.pytorch
import
InputChoice
,
Placeholder
from
..operation
import
Cell
,
Operation
from
..serializer
import
get_init_parameters_or_fail
from
..utils
import
get_importable_name
...
...
@@ -343,7 +343,7 @@ class GraphConverter:
subcell
=
ir_graph
.
add_node
(
submodule_full_name
,
submodule_type_str
,
sub_m_attrs
)
if
isinstance
(
submodule_obj
,
Placeholder
):
subcell
.
update_label
(
submodule_obj
.
label
)
elif
isinstance
(
submodule_obj
,
(
LayerChoice
,
InputChoice
)
)
:
elif
isinstance
(
submodule_obj
,
InputChoice
):
subcell
.
update_label
(
sub_m_attrs
[
'label'
])
else
:
# Graph already created, create Cell for it
...
...
@@ -536,16 +536,6 @@ class GraphConverter:
self
.
remove_unconnected_nodes
(
ir_graph
,
targeted_type
=
'prim::GetAttr'
)
self
.
merge_aten_slices
(
ir_graph
)
def
_handle_layerchoice
(
self
,
module
):
choices
=
[]
for
cand
in
list
(
module
):
cand_type
=
'__torch__.'
+
get_importable_name
(
cand
.
__class__
)
choices
.
append
({
'type'
:
cand_type
,
'parameters'
:
get_init_parameters_or_fail
(
cand
)})
return
{
'candidates'
:
choices
,
'label'
:
module
.
label
}
def
_handle_inputchoice
(
self
,
module
):
return
{
'n_candidates'
:
module
.
n_candidates
,
...
...
@@ -557,7 +547,8 @@ class GraphConverter:
def
_handle_valuechoice
(
self
,
module
):
return
{
'candidates'
:
module
.
candidates
,
'label'
:
module
.
label
'label'
:
module
.
label
,
'accessor'
:
module
.
_accessor
}
def
convert_module
(
self
,
script_module
,
module
,
module_name
,
ir_model
):
...
...
@@ -590,7 +581,13 @@ class GraphConverter:
if
original_type_name
in
MODULE_EXCEPT_LIST
:
pass
# do nothing
elif
original_type_name
==
OpTypeName
.
LayerChoice
:
m_attrs
=
self
.
_handle_layerchoice
(
module
)
graph
=
Graph
(
ir_model
,
-
100
,
module_name
,
_internal
=
True
)
# graph_id is not used now
candidate_name_list
=
[
f
'layerchoice_
{
module
.
label
}
_
{
cand_name
}
'
for
cand_name
in
module
.
names
]
for
cand_name
,
cand
in
zip
(
candidate_name_list
,
module
):
cand_type
=
'__torch__.'
+
get_importable_name
(
cand
.
__class__
)
graph
.
add_node
(
cand_name
,
cand_type
,
get_init_parameters_or_fail
(
cand
))
graph
.
_register
()
return
graph
,
{
'mutation'
:
'layerchoice'
,
'label'
:
module
.
label
,
'candidates'
:
candidate_name_list
}
elif
original_type_name
==
OpTypeName
.
InputChoice
:
m_attrs
=
self
.
_handle_inputchoice
(
module
)
elif
original_type_name
==
OpTypeName
.
ValueChoice
:
...
...
nni/retiarii/graph.py
View file @
6808708d
...
...
@@ -144,15 +144,17 @@ class Model:
for
graph_name
,
graph_data
in
ir
.
items
():
if
graph_name
!=
'_evaluator'
:
Graph
.
_load
(
model
,
graph_name
,
graph_data
).
_register
()
model
.
evaluator
=
Evaluator
.
_load_with_type
(
ir
[
'_evaluator'
][
'__type__'
],
ir
[
'_evaluator'
])
if
'_evaluator'
in
ir
:
model
.
evaluator
=
Evaluator
.
_load_with_type
(
ir
[
'_evaluator'
][
'__type__'
],
ir
[
'_evaluator'
])
return
model
def
_dump
(
self
)
->
Any
:
ret
=
{
name
:
graph
.
_dump
()
for
name
,
graph
in
self
.
graphs
.
items
()}
ret
[
'_evaluator'
]
=
{
'__type__'
:
get_importable_name
(
self
.
evaluator
.
__class__
),
**
self
.
evaluator
.
_dump
()
}
if
self
.
evaluator
is
not
None
:
ret
[
'_evaluator'
]
=
{
'__type__'
:
get_importable_name
(
self
.
evaluator
.
__class__
),
**
self
.
evaluator
.
_dump
()
}
return
ret
def
get_nodes
(
self
)
->
Iterable
[
'Node'
]:
...
...
nni/retiarii/nn/pytorch/api.py
View file @
6808708d
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
copy
import
warnings
from
collections
import
OrderedDict
from
typing
import
Any
,
List
,
Union
,
Dict
import
warnings
import
torch
import
torch.nn
as
nn
...
...
@@ -268,6 +269,7 @@ class ValueChoice(Translatable, nn.Module):
super
().
__init__
()
self
.
candidates
=
candidates
self
.
_label
=
label
if
label
is
not
None
else
f
'valuechoice_
{
uid
()
}
'
self
.
_accessor
=
[]
@
property
def
label
(
self
):
...
...
@@ -279,11 +281,36 @@ class ValueChoice(Translatable, nn.Module):
def
_translate
(
self
):
# Will function as a value when used in serializer.
return
self
.
candidates
[
0
]
return
self
.
access
(
self
.
candidates
[
0
]
)
def
__repr__
(
self
):
return
f
'ValueChoice(
{
self
.
candidates
}
, label=
{
repr
(
self
.
label
)
}
)'
def
access
(
self
,
value
):
if
not
self
.
_accessor
:
return
value
try
:
v
=
value
for
a
in
self
.
_accessor
:
v
=
v
[
a
]
except
KeyError
:
raise
KeyError
(
''
.
join
([
f
'[
{
a
}
]'
for
a
in
self
.
_accessor
])
+
f
' does not work on
{
value
}
'
)
return
v
def
__getitem__
(
self
,
item
):
"""
Get a sub-element of value choice.
The underlying implementation is to clone the current instance, and append item to "accessor", which records all
the history getitem calls. For example, when accessor is ``[a, b, c]``, the value choice will return ``vc[a][b][c]``
where ``vc`` is the original value choice.
"""
access
=
copy
.
deepcopy
(
self
)
access
.
_accessor
.
append
(
item
)
for
candidate
in
self
.
candidates
:
access
.
access
(
candidate
)
return
access
@
basic_unit
class
Placeholder
(
nn
.
Module
):
...
...
nni/retiarii/nn/pytorch/mutator.py
View file @
6808708d
...
...
@@ -4,7 +4,7 @@
from
typing
import
Any
,
List
,
Optional
,
Tuple
from
...mutator
import
Mutator
from
...graph
import
Model
,
Node
from
...graph
import
Cell
,
Model
,
Node
from
.api
import
ValueChoice
...
...
@@ -14,13 +14,23 @@ class LayerChoiceMutator(Mutator):
self
.
nodes
=
nodes
def
mutate
(
self
,
model
):
n_candidates
=
len
(
self
.
nodes
[
0
].
operation
.
parameters
[
'candidates'
])
indices
=
list
(
range
(
n_candidates
))
chosen_index
=
self
.
choice
(
indices
)
candidates
=
self
.
nodes
[
0
].
operation
.
parameters
[
'candidates'
]
chosen
=
self
.
choice
(
candidates
)
for
node
in
self
.
nodes
:
target
=
model
.
get_node_by_name
(
node
.
name
)
chosen_cand
=
node
.
operation
.
parameters
[
'candidates'
][
chosen_index
]
target
.
update_operation
(
chosen_cand
[
'type'
],
chosen_cand
[
'parameters'
])
# Each layer choice corresponds to a cell, which is unconnected in the base graph.
# We add the connections here in the mutation logic.
# Thus, the mutated model should not be mutated again. Everything should be based on the original base graph.
target
=
model
.
graphs
[
node
.
operation
.
cell_name
]
chosen_node
=
target
.
get_node_by_name
(
chosen
)
assert
chosen_node
is
not
None
target
.
add_edge
((
target
.
input_node
,
0
),
(
chosen_node
,
None
))
target
.
add_edge
((
chosen_node
,
None
),
(
target
.
output_node
,
None
))
model
.
get_node_by_name
(
node
.
name
).
update_operation
(
Cell
(
node
.
operation
.
cell_name
))
# remove redundant nodes
for
rm_node
in
target
.
hidden_nodes
:
if
rm_node
.
name
!=
chosen_node
.
name
:
rm_node
.
remove
()
class
InputChoiceMutator
(
Mutator
):
...
...
@@ -61,20 +71,14 @@ class ParameterChoiceMutator(Mutator):
def
mutate
(
self
,
model
):
chosen
=
self
.
choice
(
self
.
candidates
)
for
node
,
argname
in
self
.
nodes
:
chosen_value
=
node
.
operation
.
parameters
[
argname
].
access
(
chosen
)
target
=
model
.
get_node_by_name
(
node
.
name
)
target
.
update_operation
(
target
.
operation
.
type
,
{
**
target
.
operation
.
parameters
,
argname
:
chosen
})
target
.
update_operation
(
target
.
operation
.
type
,
{
**
target
.
operation
.
parameters
,
argname
:
chosen
_value
})
def
process_inline_mutation
(
model
:
Model
)
->
Optional
[
List
[
Mutator
]]:
applied_mutators
=
[]
lc_nodes
=
_group_by_label
(
model
.
get_nodes_by_type
(
'__torch__.nni.retiarii.nn.pytorch.api.LayerChoice'
))
for
node_list
in
lc_nodes
:
assert
_is_all_equal
(
map
(
lambda
node
:
len
(
node
.
operation
.
parameters
[
'candidates'
]),
node_list
)),
\
'Layer choice with the same label must have the same number of candidates.'
mutator
=
LayerChoiceMutator
(
node_list
)
applied_mutators
.
append
(
mutator
)
ic_nodes
=
_group_by_label
(
model
.
get_nodes_by_type
(
'__torch__.nni.retiarii.nn.pytorch.api.InputChoice'
))
for
node_list
in
ic_nodes
:
assert
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'n_candidates'
],
node_list
))
and
\
...
...
@@ -99,9 +103,20 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
for
node_list
in
pc_nodes
:
assert
_is_all_equal
([
node
.
operation
.
parameters
[
name
].
candidates
for
node
,
name
in
node_list
]),
\
'Value choice with the same label must have the same candidates.'
mutator
=
ParameterChoiceMutator
(
node_list
,
node_list
[
0
][
0
].
operation
.
parameters
[
node_list
[
0
][
1
]].
candidates
)
first_node
,
first_argname
=
node_list
[
0
]
mutator
=
ParameterChoiceMutator
(
node_list
,
first_node
.
operation
.
parameters
[
first_argname
].
candidates
)
applied_mutators
.
append
(
mutator
)
# apply layer choice at last as it will delete some nodes
lc_nodes
=
_group_by_label
(
filter
(
lambda
d
:
d
.
operation
.
parameters
.
get
(
'mutation'
)
==
'layerchoice'
,
model
.
get_nodes_by_type
(
'_cell'
)))
for
node_list
in
lc_nodes
:
assert
_is_all_equal
(
map
(
lambda
node
:
len
(
node
.
operation
.
parameters
[
'candidates'
]),
node_list
)),
\
'Layer choice with the same label must have the same number of candidates.'
mutator
=
LayerChoiceMutator
(
node_list
)
applied_mutators
.
append
(
mutator
)
if
applied_mutators
:
return
applied_mutators
return
None
...
...
nni/retiarii/operation_def/torch_op_def.py
View file @
6808708d
...
...
@@ -69,6 +69,9 @@ class PrimConstant(PyTorchOperation):
elif
self
.
parameters
[
'type'
]
==
'Device'
:
value
=
self
.
parameters
[
'value'
]
return
f
'
{
output
}
= torch.device("
{
value
}
")'
elif
self
.
parameters
[
'type'
]
in
(
'dict'
,
'list'
,
'tuple'
):
# TODO: prim::TupleIndex is not supported yet
return
f
'
{
output
}
=
{
repr
(
self
.
parameters
[
"value"
])
}
'
else
:
raise
RuntimeError
(
f
'unsupported type of prim::Constant:
{
self
.
parameters
[
"type"
]
}
'
)
...
...
test/ut/retiarii/test_highlevel_apis.py
View file @
6808708d
import
random
import
unittest
from
collections
import
Counter
import
nni.retiarii.nn.pytorch
as
nn
import
torch
...
...
@@ -252,6 +253,30 @@ class TestHighLevelAPI(unittest.TestCase):
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertAlmostEqual
(
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
abs
().
sum
().
item
(),
0
)
def
test_value_choice_in_layer_choice
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
linear
=
nn
.
LayerChoice
([
nn
.
Linear
(
3
,
nn
.
ValueChoice
([
10
,
20
])),
nn
.
Linear
(
3
,
nn
.
ValueChoice
([
30
,
40
]))
])
def
forward
(
self
,
x
):
return
self
.
linear
(
x
)
model
=
self
.
_convert_to_ir
(
Net
())
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
3
)
sz_counter
=
Counter
()
sampler
=
RandomSampler
()
for
i
in
range
(
100
):
model_new
=
model
for
mutator
in
mutators
:
model_new
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model_new
)
sz_counter
[
self
.
_get_converted_pytorch_model
(
model_new
)(
torch
.
randn
(
1
,
3
)).
size
(
1
)]
+=
1
self
.
assertEqual
(
len
(
sz_counter
),
4
)
def
test_shared
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
,
shared
=
True
):
...
...
@@ -284,12 +309,94 @@ class TestHighLevelAPI(unittest.TestCase):
# repeat test. Expectation: sometimes succeeds, sometimes fails.
failed_count
=
0
for
i
in
range
(
30
):
model_new
=
model
for
mutator
in
mutators
:
model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model
)
model
_new
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model
_new
)
self
.
assertEqual
(
sampler
.
counter
,
2
*
(
i
+
1
))
try
:
self
.
_get_converted_pytorch_model
(
model
)(
torch
.
randn
(
1
,
3
,
3
,
3
))
self
.
_get_converted_pytorch_model
(
model
_new
)(
torch
.
randn
(
1
,
3
,
3
,
3
))
except
RuntimeError
:
failed_count
+=
1
self
.
assertGreater
(
failed_count
,
0
)
self
.
assertLess
(
failed_count
,
30
)
def
test_valuechoice_access
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
vc
=
nn
.
ValueChoice
([(
6
,
3
),
(
8
,
5
)])
self
.
conv
=
nn
.
Conv2d
(
3
,
vc
[
0
],
kernel_size
=
vc
[
1
])
def
forward
(
self
,
x
):
return
self
.
conv
(
x
)
model
=
self
.
_convert_to_ir
(
Net
())
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
1
)
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
input
=
torch
.
randn
(
1
,
3
,
5
,
5
)
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
mutators
[
0
].
apply
(
model
))(
input
).
size
(),
torch
.
Size
([
1
,
6
,
3
,
3
]))
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
mutators
[
0
].
apply
(
model
))(
input
).
size
(),
torch
.
Size
([
1
,
8
,
1
,
1
]))
class
Net2
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
choices
=
[
{
'b'
:
[
3
],
'bp'
:
[
6
]},
{
'b'
:
[
6
],
'bp'
:
[
12
]}
]
self
.
conv
=
nn
.
Conv2d
(
3
,
nn
.
ValueChoice
(
choices
,
label
=
'a'
)[
'b'
][
0
],
1
)
self
.
conv1
=
nn
.
Conv2d
(
nn
.
ValueChoice
(
choices
,
label
=
'a'
)[
'bp'
][
0
],
3
,
1
)
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
return
self
.
conv1
(
torch
.
cat
((
x
,
x
),
1
))
model
=
self
.
_convert_to_ir
(
Net2
())
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
1
)
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
input
=
torch
.
randn
(
1
,
3
,
5
,
5
)
self
.
_get_converted_pytorch_model
(
mutators
[
0
].
apply
(
model
))(
input
)
def
test_valuechoice_access_functional
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
dropout_rate
=
nn
.
ValueChoice
([[
0.
,],
[
1.
,]])
def
forward
(
self
,
x
):
return
F
.
dropout
(
x
,
self
.
dropout_rate
()[
0
])
model
=
self
.
_convert_to_ir
(
Net
())
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
1
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
model1
=
mutator
.
apply
(
model
)
model2
=
mutator
.
apply
(
model
)
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
))
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertAlmostEqual
(
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
abs
().
sum
().
item
(),
0
)
def
test_valuechoice_access_functional_expression
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
dropout_rate
=
nn
.
ValueChoice
([[
1.05
,],
[
1.1
,]])
def
forward
(
self
,
x
):
# if expression failed, the exception would be:
# ValueError: dropout probability has to be between 0 and 1, but got 1.05
return
F
.
dropout
(
x
,
self
.
dropout_rate
()[
0
]
-
.
1
)
model
=
self
.
_convert_to_ir
(
Net
())
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
1
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
model1
=
mutator
.
apply
(
model
)
model2
=
mutator
.
apply
(
model
)
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
))
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertAlmostEqual
(
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
abs
().
sum
().
item
(),
0
)
test/ut/retiarii/test_strategy.py
View file @
6808708d
...
...
@@ -62,11 +62,11 @@ class Net(nn.Module):
self
.
fc1
=
nn
.
LayerChoice
([
nn
.
Linear
(
4
*
4
*
50
,
hidden_size
,
bias
=
True
),
nn
.
Linear
(
4
*
4
*
50
,
hidden_size
,
bias
=
False
)
])
]
,
label
=
'fc1'
)
self
.
fc2
=
nn
.
LayerChoice
([
nn
.
Linear
(
hidden_size
,
10
,
bias
=
False
),
nn
.
Linear
(
hidden_size
,
10
,
bias
=
True
)
])
]
,
label
=
'fc2'
)
def
forward
(
self
,
x
):
x
=
F
.
relu
(
self
.
conv1
(
x
))
...
...
@@ -97,8 +97,8 @@ def test_grid_search():
selection
=
set
()
for
model
in
engine
.
models
:
selection
.
add
((
model
.
g
et_node_by_name
(
'_model__fc1'
)
.
operation
.
parameters
[
'bias'
],
model
.
g
et_node_by_name
(
'_model__fc2'
)
.
operation
.
parameters
[
'bias'
]
model
.
g
raphs
[
'_model__fc1'
].
hidden_nodes
[
0
]
.
operation
.
parameters
[
'bias'
],
model
.
g
raphs
[
'_model__fc2'
].
hidden_nodes
[
0
]
.
operation
.
parameters
[
'bias'
]
))
assert
len
(
selection
)
==
4
_reset_execution_engine
()
...
...
@@ -113,8 +113,8 @@ def test_random_search():
selection
=
set
()
for
model
in
engine
.
models
:
selection
.
add
((
model
.
g
et_node_by_name
(
'_model__fc1'
)
.
operation
.
parameters
[
'bias'
],
model
.
g
et_node_by_name
(
'_model__fc2'
)
.
operation
.
parameters
[
'bias'
]
model
.
g
raphs
[
'_model__fc1'
].
hidden_nodes
[
0
]
.
operation
.
parameters
[
'bias'
],
model
.
g
raphs
[
'_model__fc2'
].
hidden_nodes
[
0
]
.
operation
.
parameters
[
'bias'
]
))
assert
len
(
selection
)
==
4
_reset_execution_engine
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment