Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
ba771871
Unverified
Commit
ba771871
authored
Mar 23, 2022
by
Yuge Zhang
Committed by
GitHub
Mar 23, 2022
Browse files
Support ValueChoice as depth in Repeat (#4598)
parent
c5e3bad9
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
157 additions
and
43 deletions
+157
-43
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+2
-1
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+25
-0
nni/retiarii/nn/pytorch/component.py
nni/retiarii/nn/pytorch/component.py
+53
-18
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+10
-11
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+67
-13
No files found.
nni/retiarii/converter/graph_gen.py
View file @
ba771871
...
...
@@ -660,8 +660,9 @@ class GraphConverter:
attrs
=
{
'mutation'
:
'repeat'
,
'label'
:
module
.
label
,
'depth'
:
module
.
depth_choice
,
'max_depth'
:
module
.
max_depth
,
'min_depth'
:
module
.
min_depth
,
'max_depth'
:
module
.
max_depth
}
return
ir_graph
,
attrs
...
...
nni/retiarii/nn/pytorch/api.py
View file @
ba771871
...
...
@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import
math
import
itertools
import
operator
import
warnings
from
typing
import
Any
,
List
,
Union
,
Dict
,
Optional
,
Callable
,
Iterable
,
NoReturn
,
TypeVar
...
...
@@ -439,6 +440,30 @@ class ValueChoiceX(Translatable):
# values are not used
return
self
.
_evaluate
(
iter
([]),
True
)
def
all_options
(
self
)
->
Iterable
[
Any
]:
"""Explore all possibilities of a value choice.
"""
# Record all inner choices: label -> candidates, no duplicates.
dedup_inner_choices
:
Dict
[
str
,
List
[
Any
]]
=
{}
# All labels of leaf nodes on tree, possibly duplicates.
all_labels
:
List
[
str
]
=
[]
for
choice
in
self
.
inner_choices
():
all_labels
.
append
(
choice
.
label
)
if
choice
.
label
in
dedup_inner_choices
:
if
choice
.
candidates
!=
dedup_inner_choices
[
choice
.
label
]:
# check for choice with the same label
raise
ValueError
(
f
'"
{
choice
.
candidates
}
" is not equal to "
{
dedup_inner_choices
[
choice
.
label
]
}
", '
f
'but they share the same label:
{
choice
.
label
}
'
)
else
:
dedup_inner_choices
[
choice
.
label
]
=
choice
.
candidates
dedup_labels
,
dedup_candidates
=
list
(
dedup_inner_choices
.
keys
()),
list
(
dedup_inner_choices
.
values
())
for
chosen
in
itertools
.
product
(
*
dedup_candidates
):
chosen
=
dict
(
zip
(
dedup_labels
,
chosen
))
yield
self
.
evaluate
([
chosen
[
label
]
for
label
in
all_labels
])
def
evaluate
(
self
,
values
:
Iterable
[
Any
])
->
Any
:
"""
Evaluate the result of this group.
...
...
nni/retiarii/nn/pytorch/component.py
View file @
ba771871
import
copy
import
warnings
from
collections
import
OrderedDict
from
typing
import
Callable
,
List
,
Union
,
Tuple
,
Optional
import
torch
import
torch.nn
as
nn
from
nni.retiarii.utils
import
STATE_DICT_PY_MAPPING_PARTIAL
from
nni.retiarii.utils
import
NoContextError
,
STATE_DICT_PY_MAPPING_PARTIAL
from
.api
import
LayerChoice
from
.api
import
LayerChoice
,
ValueChoice
,
ValueChoiceX
from
.cell
import
Cell
from
.nasbench101
import
NasBench101Cell
,
NasBench101Mutator
from
.mutation_utils
import
Mutable
,
generate_new_label
,
get_fixed_value
...
...
@@ -30,7 +31,7 @@ class Repeat(Mutable):
depth : int or tuple of int
If one number, the block will be repeated by a fixed number of times. If a tuple, it should be (min, max),
meaning that the block will be repeated at least ``min`` times and at most ``max`` times.
If a ValueChoice, it should choose from a series of positive integers.
Examples
--------
...
...
@@ -51,6 +52,10 @@ class Repeat(Mutable):
we need a factory function that accepts index (0, 1, 2, ...) and returns the module of the ``index``-th layer. ::
self.blocks = nn.Repeat(lambda index: nn.LayerChoice([...], label=f'layer{index}'), (1, 3))
Depth can be a ValueChoice to support arbitrary depth candidate list. ::
self.blocks = nn.Repeat(Block(), nn.ValueChoice([1, 3, 5]))
"""
@
classmethod
...
...
@@ -59,17 +64,26 @@ class Repeat(Mutable):
List
[
Callable
[[
int
],
nn
.
Module
]],
nn
.
Module
,
List
[
nn
.
Module
]],
depth
:
Union
[
int
,
Tuple
[
int
,
int
]],
*
,
label
:
Optional
[
str
]
=
None
):
repeat
=
get_fixed_value
(
label
)
result
=
nn
.
Sequential
(
*
cls
.
_replicate_and_instantiate
(
blocks
,
repeat
))
if
hasattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
):
# already has a mapping, will merge with it
prev_mapping
=
getattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
)
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
k
:
f
'blocks.
{
v
}
'
for
k
,
v
in
prev_mapping
.
items
()})
else
:
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
'__self__'
:
'blocks'
})
return
result
depth
:
Union
[
int
,
Tuple
[
int
,
int
],
ValueChoice
],
*
,
label
:
Optional
[
str
]
=
None
):
if
isinstance
(
depth
,
tuple
):
# we can't create a value choice here,
# otherwise we will have two value choices, one created here, another in init.
depth
=
get_fixed_value
(
label
)
if
isinstance
(
depth
,
int
):
# if depth is a valuechoice, it should be already an int
result
=
nn
.
Sequential
(
*
cls
.
_replicate_and_instantiate
(
blocks
,
depth
))
if
hasattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
):
# already has a mapping, will merge with it
prev_mapping
=
getattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
)
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
k
:
f
'blocks.
{
v
}
'
for
k
,
v
in
prev_mapping
.
items
()})
else
:
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
'__self__'
:
'blocks'
})
return
result
raise
NoContextError
(
f
'Not in fixed mode, or
{
depth
}
not an integer.'
)
def
__init__
(
self
,
blocks
:
Union
[
Callable
[[
int
],
nn
.
Module
],
...
...
@@ -78,15 +92,32 @@ class Repeat(Mutable):
List
[
nn
.
Module
]],
depth
:
Union
[
int
,
Tuple
[
int
,
int
]],
*
,
label
:
Optional
[
str
]
=
None
):
super
().
__init__
()
self
.
_label
=
generate_new_label
(
label
)
self
.
min_depth
=
depth
if
isinstance
(
depth
,
int
)
else
depth
[
0
]
self
.
max_depth
=
depth
if
isinstance
(
depth
,
int
)
else
depth
[
1
]
if
isinstance
(
depth
,
ValueChoiceX
):
if
label
is
not
None
:
warnings
.
warn
(
'In repeat, `depth` is already a ValueChoice, but `label` is still set. It will be ignored.'
,
RuntimeWarning
)
self
.
depth_choice
=
depth
all_values
=
list
(
self
.
depth_choice
.
all_options
())
self
.
min_depth
=
min
(
all_values
)
self
.
max_depth
=
max
(
all_values
)
elif
isinstance
(
depth
,
tuple
):
self
.
min_depth
=
depth
if
isinstance
(
depth
,
int
)
else
depth
[
0
]
self
.
max_depth
=
depth
if
isinstance
(
depth
,
int
)
else
depth
[
1
]
self
.
depth_choice
=
ValueChoice
(
list
(
range
(
self
.
min_depth
,
self
.
max_depth
+
1
)),
label
=
label
)
elif
isinstance
(
depth
,
int
):
self
.
min_depth
=
self
.
max_depth
=
depth
self
.
depth_choice
=
depth
else
:
raise
TypeError
(
f
'Unsupported "depth" type:
{
type
(
depth
)
}
'
)
assert
self
.
max_depth
>=
self
.
min_depth
>
0
self
.
blocks
=
nn
.
ModuleList
(
self
.
_replicate_and_instantiate
(
blocks
,
self
.
max_depth
))
@
property
def
label
(
self
):
return
self
.
_
label
return
self
.
depth_choice
.
label
def
forward
(
self
,
x
):
for
block
in
self
.
blocks
:
...
...
@@ -107,6 +138,10 @@ class Repeat(Mutable):
blocks
=
[
b
(
i
)
for
i
,
b
in
enumerate
(
blocks
)]
return
blocks
def
__getitem__
(
self
,
index
):
# shortcut for blocks[index]
return
self
.
blocks
[
index
]
class
NasBench201Cell
(
nn
.
Module
):
"""
...
...
nni/retiarii/nn/pytorch/mutator.py
View file @
ba771871
...
...
@@ -14,7 +14,7 @@ from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
from
nni.retiarii.utils
import
uid
from
.api
import
LayerChoice
,
InputChoice
,
ValueChoice
,
ValueChoiceX
,
Placeholder
from
.component
import
Repeat
,
NasBench101Cell
,
NasBench101Mutator
from
.component
import
NasBench101Cell
,
NasBench101Mutator
class
LayerChoiceMutator
(
Mutator
):
...
...
@@ -144,14 +144,15 @@ class RepeatMutator(Mutator):
return
chain
def
mutate
(
self
,
model
):
min_depth
=
self
.
nodes
[
0
].
operation
.
parameters
[
'min_depth'
]
max_depth
=
self
.
nodes
[
0
].
operation
.
parameters
[
'max_depth'
]
if
min_depth
<
max_depth
:
chosen_depth
=
self
.
choice
(
list
(
range
(
min_depth
,
max_depth
+
1
)))
for
node
in
self
.
nodes
:
# the logic here is similar to layer choice. We find cell attached to each node.
target
:
Graph
=
model
.
graphs
[
node
.
operation
.
cell_name
]
chain
=
self
.
_retrieve_chain_from_graph
(
target
)
# and we get the chosen depth (by value choice)
node_in_model
=
model
.
get_node_by_name
(
node
.
name
)
# depth is a value choice in base model
# but it's already mutated by a ParameterChoiceMutator here
chosen_depth
=
node_in_model
.
operation
.
parameters
[
'depth'
]
for
edge
in
chain
[
chosen_depth
-
1
].
outgoing_edges
:
edge
.
remove
()
target
.
add_edge
((
chain
[
chosen_depth
-
1
],
None
),
(
target
.
output_node
,
None
))
...
...
@@ -184,6 +185,8 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
# `pc_nodes` are arguments of basic units. They can be compositions.
pc_nodes
:
List
[
Tuple
[
Node
,
str
,
ValueChoiceX
]]
=
[]
for
node
in
model
.
get_nodes
():
# arguments used in operators like Conv2d
# argument `valuechoice` used in generated repeat cell
for
name
,
choice
in
node
.
operation
.
parameters
.
items
():
if
isinstance
(
choice
,
ValueChoiceX
):
# e.g., (conv_node, "out_channels", ValueChoice([1, 3]))
...
...
@@ -219,9 +222,10 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
repeat_nodes
=
_group_by_label
(
filter
(
lambda
d
:
d
.
operation
.
parameters
.
get
(
'mutation'
)
==
'repeat'
,
model
.
get_nodes_by_type
(
'_cell'
)))
for
node_list
in
repeat_nodes
:
# this check is not completely reliable, because it only checks max and min
assert
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'max_depth'
],
node_list
))
and
\
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'min_depth'
],
node_list
)),
\
'Repeat with the same label must have the same
number of
candidates.'
'Repeat with the same label must have the same candidates.'
mutator
=
RepeatMutator
(
node_list
)
applied_mutators
.
append
(
mutator
)
...
...
@@ -303,11 +307,6 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
if
isinstance
(
module
,
ValueChoice
):
node
=
graph
.
add_node
(
name
,
'ValueChoice'
,
{
'candidates'
:
module
.
candidates
})
node
.
label
=
module
.
label
if
isinstance
(
module
,
Repeat
)
and
module
.
min_depth
<=
module
.
max_depth
:
node
=
graph
.
add_node
(
name
,
'Repeat'
,
{
'candidates'
:
list
(
range
(
module
.
min_depth
,
module
.
max_depth
+
1
))
})
node
.
label
=
module
.
label
if
isinstance
(
module
,
NasBench101Cell
):
node
=
graph
.
add_node
(
name
,
'NasBench101Cell'
,
{
'max_num_edges'
:
module
.
max_num_edges
...
...
test/ut/retiarii/test_highlevel_apis.py
View file @
ba771871
...
...
@@ -66,6 +66,8 @@ def _apply_all_mutators(model, mutators, samplers):
class
GraphIR
(
unittest
.
TestCase
):
# graph engine will have an extra mutator for parameter choices
value_choice_incr
=
1
# graph engine has an extra mutator to apply the depth choice to nodes
repeat_incr
=
1
def
_convert_to_ir
(
self
,
model
):
script_module
=
torch
.
jit
.
script
(
model
)
...
...
@@ -578,14 +580,39 @@ class GraphIR(unittest.TestCase):
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
1
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
model1
=
mutator
.
apply
(
model
)
model2
=
mutator
.
apply
(
model
)
model3
=
mutator
.
apply
(
model
)
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
zeros
(
1
,
16
))
==
3
).
all
())
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
zeros
(
1
,
16
))
==
4
).
all
())
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
model3
)(
torch
.
zeros
(
1
,
16
))
==
5
).
all
())
self
.
assertEqual
(
len
(
mutators
),
1
+
self
.
repeat_incr
+
self
.
value_choice_incr
)
samplers
=
[
EnumerateSampler
()
for
_
in
range
(
len
(
mutators
))]
for
target
in
[
3
,
4
,
5
]:
new_model
=
_apply_all_mutators
(
model
,
mutators
,
samplers
)
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
new_model
)(
torch
.
zeros
(
1
,
16
))
==
target
).
all
())
def
test_repeat_static
(
self
):
class
AddOne
(
nn
.
Module
):
def
forward
(
self
,
x
):
return
x
+
1
@
model_wrapper
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
block
=
nn
.
Repeat
(
lambda
index
:
nn
.
LayerChoice
([
AddOne
(),
nn
.
Identity
()]),
4
)
def
forward
(
self
,
x
):
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
4
)
sampler
=
RandomSampler
()
result
=
[]
for
_
in
range
(
50
):
new_model
=
model
for
mutator
in
mutators
:
new_model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
new_model
)
result
.
append
(
self
.
_get_converted_pytorch_model
(
new_model
)(
torch
.
zeros
(
1
,
1
)).
item
())
for
x
in
[
1
,
2
,
3
]:
self
.
assertIn
(
float
(
x
),
result
)
def
test_repeat_complex
(
self
):
class
AddOne
(
nn
.
Module
):
...
...
@@ -602,8 +629,8 @@ class GraphIR(unittest.TestCase):
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
2
)
self
.
assertEqual
(
set
([
mutator
.
label
for
mutator
in
mutators
]),
{
'lc'
,
'rep'
})
self
.
assertEqual
(
len
(
mutators
),
2
+
self
.
repeat_incr
+
self
.
value_choice_incr
)
self
.
assertEqual
(
set
([
mutator
.
label
for
mutator
in
mutators
if
mutator
.
label
is
not
None
]),
{
'lc'
,
'rep'
})
sampler
=
RandomSampler
()
for
_
in
range
(
10
):
...
...
@@ -624,7 +651,7 @@ class GraphIR(unittest.TestCase):
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
4
)
self
.
assertEqual
(
len
(
mutators
),
4
+
self
.
repeat_incr
+
self
.
value_choice_incr
)
result
=
[]
for
_
in
range
(
20
):
...
...
@@ -635,6 +662,27 @@ class GraphIR(unittest.TestCase):
self
.
assertIn
(
1.
,
result
)
def
test_repeat_valuechoice
(
self
):
class
AddOne
(
nn
.
Module
):
def
forward
(
self
,
x
):
return
x
+
1
@
model_wrapper
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
block
=
nn
.
Repeat
(
AddOne
(),
nn
.
ValueChoice
([
1
,
3
,
5
]))
def
forward
(
self
,
x
):
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
1
+
self
.
repeat_incr
+
self
.
value_choice_incr
)
samplers
=
[
EnumerateSampler
()
for
_
in
range
(
len
(
mutators
))]
for
target
in
[
1
,
3
,
5
]:
new_model
=
_apply_all_mutators
(
model
,
mutators
,
samplers
)
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
new_model
)(
torch
.
zeros
(
1
,
16
))
==
target
).
all
())
def
test_repeat_weight_inheritance
(
self
):
@
model_wrapper
class
Net
(
nn
.
Module
):
...
...
@@ -647,11 +695,11 @@ class GraphIR(unittest.TestCase):
orig_model
=
Net
()
model
,
mutators
=
self
.
_get_model_with_mutators
(
orig_model
)
mutator
=
mutators
[
0
].
bind_
sampler
(
EnumerateSampler
()
)
sampler
s
=
[
EnumerateSampler
()
for
_
in
range
(
len
(
mutators
))]
inp
=
torch
.
randn
(
1
,
3
,
5
,
5
)
for
i
in
range
(
4
):
model_new
=
self
.
_get_converted_pytorch_model
(
mutator
.
apply
(
model
))
model_new
=
self
.
_get_converted_pytorch_model
(
_apply_all_mutators
(
model
,
mutators
,
samplers
))
with
original_state_dict_hooks
(
model_new
):
model_new
.
load_state_dict
(
orig_model
.
state_dict
(),
strict
=
False
)
...
...
@@ -778,6 +826,7 @@ class GraphIR(unittest.TestCase):
class
Python
(
GraphIR
):
# Python engine doesn't have the extra mutator
value_choice_incr
=
0
repeat_incr
=
0
def
_get_converted_pytorch_model
(
self
,
model_ir
):
mutation
=
{
mut
.
mutator
.
label
:
_unpack_if_only_one
(
mut
.
samples
)
for
mut
in
model_ir
.
history
}
...
...
@@ -891,6 +940,8 @@ class Shared(unittest.TestCase):
elif
i
==
2
:
assert
choice
.
candidates
==
[
5
,
6
]
assert
d
.
evaluate
([
2
,
3
,
5
])
==
20
expect
=
[
x
+
y
+
3
*
z
for
x
in
[
1
,
2
]
for
y
in
[
3
,
4
]
for
z
in
[
5
,
6
]]
assert
list
(
d
.
all_options
())
==
expect
a
=
nn
.
ValueChoice
([
'cat'
,
'dog'
])
b
=
nn
.
ValueChoice
([
'milk'
,
'coffee'
])
...
...
@@ -967,6 +1018,9 @@ class Shared(unittest.TestCase):
lst
=
[
value
if
choice
.
label
==
'value'
else
divisor
for
choice
in
result
.
inner_choices
()]
assert
result
.
evaluate
(
lst
)
==
original_make_divisible
(
value
,
divisor
)
assert
len
(
list
(
result
.
all_options
()))
==
30
assert
max
(
result
.
all_options
())
==
135
def
test_valuechoice_in_evaluator
(
self
):
def
foo
():
pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment