Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
ba771871
"docs/en_US/vscode:/vscode.git/clone" did not exist on "f1ce1648b24d2668c2eb8fa02b158a7b6da80ea4"
Unverified
Commit
ba771871
authored
Mar 23, 2022
by
Yuge Zhang
Committed by
GitHub
Mar 23, 2022
Browse files
Support ValueChoice as depth in Repeat (#4598)
parent
c5e3bad9
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
157 additions
and
43 deletions
+157
-43
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+2
-1
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+25
-0
nni/retiarii/nn/pytorch/component.py
nni/retiarii/nn/pytorch/component.py
+53
-18
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+10
-11
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+67
-13
No files found.
nni/retiarii/converter/graph_gen.py
View file @
ba771871
...
...
@@ -660,8 +660,9 @@ class GraphConverter:
attrs
=
{
'mutation'
:
'repeat'
,
'label'
:
module
.
label
,
'depth'
:
module
.
depth_choice
,
'max_depth'
:
module
.
max_depth
,
'min_depth'
:
module
.
min_depth
,
'max_depth'
:
module
.
max_depth
}
return
ir_graph
,
attrs
...
...
nni/retiarii/nn/pytorch/api.py
View file @
ba771871
...
...
@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import
math
import
itertools
import
operator
import
warnings
from
typing
import
Any
,
List
,
Union
,
Dict
,
Optional
,
Callable
,
Iterable
,
NoReturn
,
TypeVar
...
...
@@ -439,6 +440,30 @@ class ValueChoiceX(Translatable):
# values are not used
return
self
.
_evaluate
(
iter
([]),
True
)
def
all_options
(
self
)
->
Iterable
[
Any
]:
"""Explore all possibilities of a value choice.
"""
# Record all inner choices: label -> candidates, no duplicates.
dedup_inner_choices
:
Dict
[
str
,
List
[
Any
]]
=
{}
# All labels of leaf nodes on tree, possibly duplicates.
all_labels
:
List
[
str
]
=
[]
for
choice
in
self
.
inner_choices
():
all_labels
.
append
(
choice
.
label
)
if
choice
.
label
in
dedup_inner_choices
:
if
choice
.
candidates
!=
dedup_inner_choices
[
choice
.
label
]:
# check for choice with the same label
raise
ValueError
(
f
'"
{
choice
.
candidates
}
" is not equal to "
{
dedup_inner_choices
[
choice
.
label
]
}
", '
f
'but they share the same label:
{
choice
.
label
}
'
)
else
:
dedup_inner_choices
[
choice
.
label
]
=
choice
.
candidates
dedup_labels
,
dedup_candidates
=
list
(
dedup_inner_choices
.
keys
()),
list
(
dedup_inner_choices
.
values
())
for
chosen
in
itertools
.
product
(
*
dedup_candidates
):
chosen
=
dict
(
zip
(
dedup_labels
,
chosen
))
yield
self
.
evaluate
([
chosen
[
label
]
for
label
in
all_labels
])
def
evaluate
(
self
,
values
:
Iterable
[
Any
])
->
Any
:
"""
Evaluate the result of this group.
...
...
nni/retiarii/nn/pytorch/component.py
View file @
ba771871
import
copy
import
warnings
from
collections
import
OrderedDict
from
typing
import
Callable
,
List
,
Union
,
Tuple
,
Optional
import
torch
import
torch.nn
as
nn
from
nni.retiarii.utils
import
STATE_DICT_PY_MAPPING_PARTIAL
from
nni.retiarii.utils
import
NoContextError
,
STATE_DICT_PY_MAPPING_PARTIAL
from
.api
import
LayerChoice
from
.api
import
LayerChoice
,
ValueChoice
,
ValueChoiceX
from
.cell
import
Cell
from
.nasbench101
import
NasBench101Cell
,
NasBench101Mutator
from
.mutation_utils
import
Mutable
,
generate_new_label
,
get_fixed_value
...
...
@@ -30,7 +31,7 @@ class Repeat(Mutable):
depth : int or tuple of int
If one number, the block will be repeated by a fixed number of times. If a tuple, it should be (min, max),
meaning that the block will be repeated at least ``min`` times and at most ``max`` times.
If a ValueChoice, it should choose from a series of positive integers.
Examples
--------
...
...
@@ -51,6 +52,10 @@ class Repeat(Mutable):
we need a factory function that accepts index (0, 1, 2, ...) and returns the module of the ``index``-th layer. ::
self.blocks = nn.Repeat(lambda index: nn.LayerChoice([...], label=f'layer{index}'), (1, 3))
Depth can be a ValueChoice to support arbitrary depth candidate list. ::
self.blocks = nn.Repeat(Block(), nn.ValueChoice([1, 3, 5]))
"""
@
classmethod
...
...
@@ -59,9 +64,15 @@ class Repeat(Mutable):
List
[
Callable
[[
int
],
nn
.
Module
]],
nn
.
Module
,
List
[
nn
.
Module
]],
depth
:
Union
[
int
,
Tuple
[
int
,
int
]],
*
,
label
:
Optional
[
str
]
=
None
):
repeat
=
get_fixed_value
(
label
)
result
=
nn
.
Sequential
(
*
cls
.
_replicate_and_instantiate
(
blocks
,
repeat
))
depth
:
Union
[
int
,
Tuple
[
int
,
int
],
ValueChoice
],
*
,
label
:
Optional
[
str
]
=
None
):
if
isinstance
(
depth
,
tuple
):
# we can't create a value choice here,
# otherwise we will have two value choices, one created here, another in init.
depth
=
get_fixed_value
(
label
)
if
isinstance
(
depth
,
int
):
# if depth is a valuechoice, it should be already an int
result
=
nn
.
Sequential
(
*
cls
.
_replicate_and_instantiate
(
blocks
,
depth
))
if
hasattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
):
# already has a mapping, will merge with it
...
...
@@ -69,8 +80,11 @@ class Repeat(Mutable):
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
k
:
f
'blocks.
{
v
}
'
for
k
,
v
in
prev_mapping
.
items
()})
else
:
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
'__self__'
:
'blocks'
})
return
result
raise
NoContextError
(
f
'Not in fixed mode, or
{
depth
}
not an integer.'
)
def
__init__
(
self
,
blocks
:
Union
[
Callable
[[
int
],
nn
.
Module
],
List
[
Callable
[[
int
],
nn
.
Module
]],
...
...
@@ -78,15 +92,32 @@ class Repeat(Mutable):
List
[
nn
.
Module
]],
depth
:
Union
[
int
,
Tuple
[
int
,
int
]],
*
,
label
:
Optional
[
str
]
=
None
):
super
().
__init__
()
self
.
_label
=
generate_new_label
(
label
)
if
isinstance
(
depth
,
ValueChoiceX
):
if
label
is
not
None
:
warnings
.
warn
(
'In repeat, `depth` is already a ValueChoice, but `label` is still set. It will be ignored.'
,
RuntimeWarning
)
self
.
depth_choice
=
depth
all_values
=
list
(
self
.
depth_choice
.
all_options
())
self
.
min_depth
=
min
(
all_values
)
self
.
max_depth
=
max
(
all_values
)
elif
isinstance
(
depth
,
tuple
):
self
.
min_depth
=
depth
if
isinstance
(
depth
,
int
)
else
depth
[
0
]
self
.
max_depth
=
depth
if
isinstance
(
depth
,
int
)
else
depth
[
1
]
self
.
depth_choice
=
ValueChoice
(
list
(
range
(
self
.
min_depth
,
self
.
max_depth
+
1
)),
label
=
label
)
elif
isinstance
(
depth
,
int
):
self
.
min_depth
=
self
.
max_depth
=
depth
self
.
depth_choice
=
depth
else
:
raise
TypeError
(
f
'Unsupported "depth" type:
{
type
(
depth
)
}
'
)
assert
self
.
max_depth
>=
self
.
min_depth
>
0
self
.
blocks
=
nn
.
ModuleList
(
self
.
_replicate_and_instantiate
(
blocks
,
self
.
max_depth
))
@
property
def
label
(
self
):
return
self
.
_
label
return
self
.
depth_choice
.
label
def
forward
(
self
,
x
):
for
block
in
self
.
blocks
:
...
...
@@ -107,6 +138,10 @@ class Repeat(Mutable):
blocks
=
[
b
(
i
)
for
i
,
b
in
enumerate
(
blocks
)]
return
blocks
def
__getitem__
(
self
,
index
):
# shortcut for blocks[index]
return
self
.
blocks
[
index
]
class
NasBench201Cell
(
nn
.
Module
):
"""
...
...
nni/retiarii/nn/pytorch/mutator.py
View file @
ba771871
...
...
@@ -14,7 +14,7 @@ from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
from
nni.retiarii.utils
import
uid
from
.api
import
LayerChoice
,
InputChoice
,
ValueChoice
,
ValueChoiceX
,
Placeholder
from
.component
import
Repeat
,
NasBench101Cell
,
NasBench101Mutator
from
.component
import
NasBench101Cell
,
NasBench101Mutator
class
LayerChoiceMutator
(
Mutator
):
...
...
@@ -144,14 +144,15 @@ class RepeatMutator(Mutator):
return
chain
def
mutate
(
self
,
model
):
min_depth
=
self
.
nodes
[
0
].
operation
.
parameters
[
'min_depth'
]
max_depth
=
self
.
nodes
[
0
].
operation
.
parameters
[
'max_depth'
]
if
min_depth
<
max_depth
:
chosen_depth
=
self
.
choice
(
list
(
range
(
min_depth
,
max_depth
+
1
)))
for
node
in
self
.
nodes
:
# the logic here is similar to layer choice. We find cell attached to each node.
target
:
Graph
=
model
.
graphs
[
node
.
operation
.
cell_name
]
chain
=
self
.
_retrieve_chain_from_graph
(
target
)
# and we get the chosen depth (by value choice)
node_in_model
=
model
.
get_node_by_name
(
node
.
name
)
# depth is a value choice in base model
# but it's already mutated by a ParameterChoiceMutator here
chosen_depth
=
node_in_model
.
operation
.
parameters
[
'depth'
]
for
edge
in
chain
[
chosen_depth
-
1
].
outgoing_edges
:
edge
.
remove
()
target
.
add_edge
((
chain
[
chosen_depth
-
1
],
None
),
(
target
.
output_node
,
None
))
...
...
@@ -184,6 +185,8 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
# `pc_nodes` are arguments of basic units. They can be compositions.
pc_nodes
:
List
[
Tuple
[
Node
,
str
,
ValueChoiceX
]]
=
[]
for
node
in
model
.
get_nodes
():
# arguments used in operators like Conv2d
# argument `valuechoice` used in generated repeat cell
for
name
,
choice
in
node
.
operation
.
parameters
.
items
():
if
isinstance
(
choice
,
ValueChoiceX
):
# e.g., (conv_node, "out_channels", ValueChoice([1, 3]))
...
...
@@ -219,9 +222,10 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
repeat_nodes
=
_group_by_label
(
filter
(
lambda
d
:
d
.
operation
.
parameters
.
get
(
'mutation'
)
==
'repeat'
,
model
.
get_nodes_by_type
(
'_cell'
)))
for
node_list
in
repeat_nodes
:
# this check is not completely reliable, because it only checks max and min
assert
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'max_depth'
],
node_list
))
and
\
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'min_depth'
],
node_list
)),
\
'Repeat with the same label must have the same
number of
candidates.'
'Repeat with the same label must have the same candidates.'
mutator
=
RepeatMutator
(
node_list
)
applied_mutators
.
append
(
mutator
)
...
...
@@ -303,11 +307,6 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
if
isinstance
(
module
,
ValueChoice
):
node
=
graph
.
add_node
(
name
,
'ValueChoice'
,
{
'candidates'
:
module
.
candidates
})
node
.
label
=
module
.
label
if
isinstance
(
module
,
Repeat
)
and
module
.
min_depth
<=
module
.
max_depth
:
node
=
graph
.
add_node
(
name
,
'Repeat'
,
{
'candidates'
:
list
(
range
(
module
.
min_depth
,
module
.
max_depth
+
1
))
})
node
.
label
=
module
.
label
if
isinstance
(
module
,
NasBench101Cell
):
node
=
graph
.
add_node
(
name
,
'NasBench101Cell'
,
{
'max_num_edges'
:
module
.
max_num_edges
...
...
test/ut/retiarii/test_highlevel_apis.py
View file @
ba771871
...
...
@@ -66,6 +66,8 @@ def _apply_all_mutators(model, mutators, samplers):
class
GraphIR
(
unittest
.
TestCase
):
# graph engine will have an extra mutator for parameter choices
value_choice_incr
=
1
# graph engine has an extra mutator to apply the depth choice to nodes
repeat_incr
=
1
def
_convert_to_ir
(
self
,
model
):
script_module
=
torch
.
jit
.
script
(
model
)
...
...
@@ -578,14 +580,39 @@ class GraphIR(unittest.TestCase):
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
1
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
model1
=
mutator
.
apply
(
model
)
model2
=
mutator
.
apply
(
model
)
model3
=
mutator
.
apply
(
model
)
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
zeros
(
1
,
16
))
==
3
).
all
())
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
zeros
(
1
,
16
))
==
4
).
all
())
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
model3
)(
torch
.
zeros
(
1
,
16
))
==
5
).
all
())
self
.
assertEqual
(
len
(
mutators
),
1
+
self
.
repeat_incr
+
self
.
value_choice_incr
)
samplers
=
[
EnumerateSampler
()
for
_
in
range
(
len
(
mutators
))]
for
target
in
[
3
,
4
,
5
]:
new_model
=
_apply_all_mutators
(
model
,
mutators
,
samplers
)
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
new_model
)(
torch
.
zeros
(
1
,
16
))
==
target
).
all
())
def
test_repeat_static
(
self
):
class
AddOne
(
nn
.
Module
):
def
forward
(
self
,
x
):
return
x
+
1
@
model_wrapper
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
block
=
nn
.
Repeat
(
lambda
index
:
nn
.
LayerChoice
([
AddOne
(),
nn
.
Identity
()]),
4
)
def
forward
(
self
,
x
):
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
4
)
sampler
=
RandomSampler
()
result
=
[]
for
_
in
range
(
50
):
new_model
=
model
for
mutator
in
mutators
:
new_model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
new_model
)
result
.
append
(
self
.
_get_converted_pytorch_model
(
new_model
)(
torch
.
zeros
(
1
,
1
)).
item
())
for
x
in
[
1
,
2
,
3
]:
self
.
assertIn
(
float
(
x
),
result
)
def
test_repeat_complex
(
self
):
class
AddOne
(
nn
.
Module
):
...
...
@@ -602,8 +629,8 @@ class GraphIR(unittest.TestCase):
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
2
)
self
.
assertEqual
(
set
([
mutator
.
label
for
mutator
in
mutators
]),
{
'lc'
,
'rep'
})
self
.
assertEqual
(
len
(
mutators
),
2
+
self
.
repeat_incr
+
self
.
value_choice_incr
)
self
.
assertEqual
(
set
([
mutator
.
label
for
mutator
in
mutators
if
mutator
.
label
is
not
None
]),
{
'lc'
,
'rep'
})
sampler
=
RandomSampler
()
for
_
in
range
(
10
):
...
...
@@ -624,7 +651,7 @@ class GraphIR(unittest.TestCase):
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
4
)
self
.
assertEqual
(
len
(
mutators
),
4
+
self
.
repeat_incr
+
self
.
value_choice_incr
)
result
=
[]
for
_
in
range
(
20
):
...
...
@@ -635,6 +662,27 @@ class GraphIR(unittest.TestCase):
self
.
assertIn
(
1.
,
result
)
def
test_repeat_valuechoice
(
self
):
class
AddOne
(
nn
.
Module
):
def
forward
(
self
,
x
):
return
x
+
1
@
model_wrapper
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
block
=
nn
.
Repeat
(
AddOne
(),
nn
.
ValueChoice
([
1
,
3
,
5
]))
def
forward
(
self
,
x
):
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
1
+
self
.
repeat_incr
+
self
.
value_choice_incr
)
samplers
=
[
EnumerateSampler
()
for
_
in
range
(
len
(
mutators
))]
for
target
in
[
1
,
3
,
5
]:
new_model
=
_apply_all_mutators
(
model
,
mutators
,
samplers
)
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
new_model
)(
torch
.
zeros
(
1
,
16
))
==
target
).
all
())
def
test_repeat_weight_inheritance
(
self
):
@
model_wrapper
class
Net
(
nn
.
Module
):
...
...
@@ -647,11 +695,11 @@ class GraphIR(unittest.TestCase):
orig_model
=
Net
()
model
,
mutators
=
self
.
_get_model_with_mutators
(
orig_model
)
mutator
=
mutators
[
0
].
bind_
sampler
(
EnumerateSampler
()
)
sampler
s
=
[
EnumerateSampler
()
for
_
in
range
(
len
(
mutators
))]
inp
=
torch
.
randn
(
1
,
3
,
5
,
5
)
for
i
in
range
(
4
):
model_new
=
self
.
_get_converted_pytorch_model
(
mutator
.
apply
(
model
))
model_new
=
self
.
_get_converted_pytorch_model
(
_apply_all_mutators
(
model
,
mutators
,
samplers
))
with
original_state_dict_hooks
(
model_new
):
model_new
.
load_state_dict
(
orig_model
.
state_dict
(),
strict
=
False
)
...
...
@@ -778,6 +826,7 @@ class GraphIR(unittest.TestCase):
class
Python
(
GraphIR
):
# Python engine doesn't have the extra mutator
value_choice_incr
=
0
repeat_incr
=
0
def
_get_converted_pytorch_model
(
self
,
model_ir
):
mutation
=
{
mut
.
mutator
.
label
:
_unpack_if_only_one
(
mut
.
samples
)
for
mut
in
model_ir
.
history
}
...
...
@@ -891,6 +940,8 @@ class Shared(unittest.TestCase):
elif
i
==
2
:
assert
choice
.
candidates
==
[
5
,
6
]
assert
d
.
evaluate
([
2
,
3
,
5
])
==
20
expect
=
[
x
+
y
+
3
*
z
for
x
in
[
1
,
2
]
for
y
in
[
3
,
4
]
for
z
in
[
5
,
6
]]
assert
list
(
d
.
all_options
())
==
expect
a
=
nn
.
ValueChoice
([
'cat'
,
'dog'
])
b
=
nn
.
ValueChoice
([
'milk'
,
'coffee'
])
...
...
@@ -967,6 +1018,9 @@ class Shared(unittest.TestCase):
lst
=
[
value
if
choice
.
label
==
'value'
else
divisor
for
choice
in
result
.
inner_choices
()]
assert
result
.
evaluate
(
lst
)
==
original_make_divisible
(
value
,
divisor
)
assert
len
(
list
(
result
.
all_options
()))
==
30
assert
max
(
result
.
all_options
())
==
135
def
test_valuechoice_in_evaluator
(
self
):
def
foo
():
pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment