Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
ba771871
Unverified
Commit
ba771871
authored
Mar 23, 2022
by
Yuge Zhang
Committed by
GitHub
Mar 23, 2022
Browse files
Support ValueChoice as depth in Repeat (#4598)
parent
c5e3bad9
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
157 additions
and
43 deletions
+157
-43
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+2
-1
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+25
-0
nni/retiarii/nn/pytorch/component.py
nni/retiarii/nn/pytorch/component.py
+53
-18
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+10
-11
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+67
-13
No files found.
nni/retiarii/converter/graph_gen.py
View file @
ba771871
...
@@ -660,8 +660,9 @@ class GraphConverter:
...
@@ -660,8 +660,9 @@ class GraphConverter:
attrs
=
{
attrs
=
{
'mutation'
:
'repeat'
,
'mutation'
:
'repeat'
,
'label'
:
module
.
label
,
'label'
:
module
.
label
,
'depth'
:
module
.
depth_choice
,
'max_depth'
:
module
.
max_depth
,
'min_depth'
:
module
.
min_depth
,
'min_depth'
:
module
.
min_depth
,
'max_depth'
:
module
.
max_depth
}
}
return
ir_graph
,
attrs
return
ir_graph
,
attrs
...
...
nni/retiarii/nn/pytorch/api.py
View file @
ba771871
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
math
import
math
import
itertools
import
operator
import
operator
import
warnings
import
warnings
from
typing
import
Any
,
List
,
Union
,
Dict
,
Optional
,
Callable
,
Iterable
,
NoReturn
,
TypeVar
from
typing
import
Any
,
List
,
Union
,
Dict
,
Optional
,
Callable
,
Iterable
,
NoReturn
,
TypeVar
...
@@ -439,6 +440,30 @@ class ValueChoiceX(Translatable):
...
@@ -439,6 +440,30 @@ class ValueChoiceX(Translatable):
# values are not used
# values are not used
return
self
.
_evaluate
(
iter
([]),
True
)
return
self
.
_evaluate
(
iter
([]),
True
)
def
all_options
(
self
)
->
Iterable
[
Any
]:
"""Explore all possibilities of a value choice.
"""
# Record all inner choices: label -> candidates, no duplicates.
dedup_inner_choices
:
Dict
[
str
,
List
[
Any
]]
=
{}
# All labels of leaf nodes on tree, possibly duplicates.
all_labels
:
List
[
str
]
=
[]
for
choice
in
self
.
inner_choices
():
all_labels
.
append
(
choice
.
label
)
if
choice
.
label
in
dedup_inner_choices
:
if
choice
.
candidates
!=
dedup_inner_choices
[
choice
.
label
]:
# check for choice with the same label
raise
ValueError
(
f
'"
{
choice
.
candidates
}
" is not equal to "
{
dedup_inner_choices
[
choice
.
label
]
}
", '
f
'but they share the same label:
{
choice
.
label
}
'
)
else
:
dedup_inner_choices
[
choice
.
label
]
=
choice
.
candidates
dedup_labels
,
dedup_candidates
=
list
(
dedup_inner_choices
.
keys
()),
list
(
dedup_inner_choices
.
values
())
for
chosen
in
itertools
.
product
(
*
dedup_candidates
):
chosen
=
dict
(
zip
(
dedup_labels
,
chosen
))
yield
self
.
evaluate
([
chosen
[
label
]
for
label
in
all_labels
])
def
evaluate
(
self
,
values
:
Iterable
[
Any
])
->
Any
:
def
evaluate
(
self
,
values
:
Iterable
[
Any
])
->
Any
:
"""
"""
Evaluate the result of this group.
Evaluate the result of this group.
...
...
nni/retiarii/nn/pytorch/component.py
View file @
ba771871
import
copy
import
copy
import
warnings
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
typing
import
Callable
,
List
,
Union
,
Tuple
,
Optional
from
typing
import
Callable
,
List
,
Union
,
Tuple
,
Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
nni.retiarii.utils
import
STATE_DICT_PY_MAPPING_PARTIAL
from
nni.retiarii.utils
import
NoContextError
,
STATE_DICT_PY_MAPPING_PARTIAL
from
.api
import
LayerChoice
from
.api
import
LayerChoice
,
ValueChoice
,
ValueChoiceX
from
.cell
import
Cell
from
.cell
import
Cell
from
.nasbench101
import
NasBench101Cell
,
NasBench101Mutator
from
.nasbench101
import
NasBench101Cell
,
NasBench101Mutator
from
.mutation_utils
import
Mutable
,
generate_new_label
,
get_fixed_value
from
.mutation_utils
import
Mutable
,
generate_new_label
,
get_fixed_value
...
@@ -30,7 +31,7 @@ class Repeat(Mutable):
...
@@ -30,7 +31,7 @@ class Repeat(Mutable):
depth : int or tuple of int
depth : int or tuple of int
If one number, the block will be repeated by a fixed number of times. If a tuple, it should be (min, max),
If one number, the block will be repeated by a fixed number of times. If a tuple, it should be (min, max),
meaning that the block will be repeated at least ``min`` times and at most ``max`` times.
meaning that the block will be repeated at least ``min`` times and at most ``max`` times.
If a ValueChoice, it should choose from a series of positive integers.
Examples
Examples
--------
--------
...
@@ -51,6 +52,10 @@ class Repeat(Mutable):
...
@@ -51,6 +52,10 @@ class Repeat(Mutable):
we need a factory function that accepts index (0, 1, 2, ...) and returns the module of the ``index``-th layer. ::
we need a factory function that accepts index (0, 1, 2, ...) and returns the module of the ``index``-th layer. ::
self.blocks = nn.Repeat(lambda index: nn.LayerChoice([...], label=f'layer{index}'), (1, 3))
self.blocks = nn.Repeat(lambda index: nn.LayerChoice([...], label=f'layer{index}'), (1, 3))
Depth can be a ValueChoice to support arbitrary depth candidate list. ::
self.blocks = nn.Repeat(Block(), nn.ValueChoice([1, 3, 5]))
"""
"""
@
classmethod
@
classmethod
...
@@ -59,17 +64,26 @@ class Repeat(Mutable):
...
@@ -59,17 +64,26 @@ class Repeat(Mutable):
List
[
Callable
[[
int
],
nn
.
Module
]],
List
[
Callable
[[
int
],
nn
.
Module
]],
nn
.
Module
,
nn
.
Module
,
List
[
nn
.
Module
]],
List
[
nn
.
Module
]],
depth
:
Union
[
int
,
Tuple
[
int
,
int
]],
*
,
label
:
Optional
[
str
]
=
None
):
depth
:
Union
[
int
,
Tuple
[
int
,
int
],
ValueChoice
],
*
,
label
:
Optional
[
str
]
=
None
):
repeat
=
get_fixed_value
(
label
)
if
isinstance
(
depth
,
tuple
):
result
=
nn
.
Sequential
(
*
cls
.
_replicate_and_instantiate
(
blocks
,
repeat
))
# we can't create a value choice here,
# otherwise we will have two value choices, one created here, another in init.
if
hasattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
):
depth
=
get_fixed_value
(
label
)
# already has a mapping, will merge with it
prev_mapping
=
getattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
)
if
isinstance
(
depth
,
int
):
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
k
:
f
'blocks.
{
v
}
'
for
k
,
v
in
prev_mapping
.
items
()})
# if depth is a valuechoice, it should be already an int
else
:
result
=
nn
.
Sequential
(
*
cls
.
_replicate_and_instantiate
(
blocks
,
depth
))
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
'__self__'
:
'blocks'
})
return
result
if
hasattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
):
# already has a mapping, will merge with it
prev_mapping
=
getattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
)
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
k
:
f
'blocks.
{
v
}
'
for
k
,
v
in
prev_mapping
.
items
()})
else
:
setattr
(
result
,
STATE_DICT_PY_MAPPING_PARTIAL
,
{
'__self__'
:
'blocks'
})
return
result
raise
NoContextError
(
f
'Not in fixed mode, or
{
depth
}
not an integer.'
)
def
__init__
(
self
,
def
__init__
(
self
,
blocks
:
Union
[
Callable
[[
int
],
nn
.
Module
],
blocks
:
Union
[
Callable
[[
int
],
nn
.
Module
],
...
@@ -78,15 +92,32 @@ class Repeat(Mutable):
...
@@ -78,15 +92,32 @@ class Repeat(Mutable):
List
[
nn
.
Module
]],
List
[
nn
.
Module
]],
depth
:
Union
[
int
,
Tuple
[
int
,
int
]],
*
,
label
:
Optional
[
str
]
=
None
):
depth
:
Union
[
int
,
Tuple
[
int
,
int
]],
*
,
label
:
Optional
[
str
]
=
None
):
super
().
__init__
()
super
().
__init__
()
self
.
_label
=
generate_new_label
(
label
)
self
.
min_depth
=
depth
if
isinstance
(
depth
,
int
)
else
depth
[
0
]
if
isinstance
(
depth
,
ValueChoiceX
):
self
.
max_depth
=
depth
if
isinstance
(
depth
,
int
)
else
depth
[
1
]
if
label
is
not
None
:
warnings
.
warn
(
'In repeat, `depth` is already a ValueChoice, but `label` is still set. It will be ignored.'
,
RuntimeWarning
)
self
.
depth_choice
=
depth
all_values
=
list
(
self
.
depth_choice
.
all_options
())
self
.
min_depth
=
min
(
all_values
)
self
.
max_depth
=
max
(
all_values
)
elif
isinstance
(
depth
,
tuple
):
self
.
min_depth
=
depth
if
isinstance
(
depth
,
int
)
else
depth
[
0
]
self
.
max_depth
=
depth
if
isinstance
(
depth
,
int
)
else
depth
[
1
]
self
.
depth_choice
=
ValueChoice
(
list
(
range
(
self
.
min_depth
,
self
.
max_depth
+
1
)),
label
=
label
)
elif
isinstance
(
depth
,
int
):
self
.
min_depth
=
self
.
max_depth
=
depth
self
.
depth_choice
=
depth
else
:
raise
TypeError
(
f
'Unsupported "depth" type:
{
type
(
depth
)
}
'
)
assert
self
.
max_depth
>=
self
.
min_depth
>
0
assert
self
.
max_depth
>=
self
.
min_depth
>
0
self
.
blocks
=
nn
.
ModuleList
(
self
.
_replicate_and_instantiate
(
blocks
,
self
.
max_depth
))
self
.
blocks
=
nn
.
ModuleList
(
self
.
_replicate_and_instantiate
(
blocks
,
self
.
max_depth
))
@
property
@
property
def
label
(
self
):
def
label
(
self
):
return
self
.
_
label
return
self
.
depth_choice
.
label
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
for
block
in
self
.
blocks
:
for
block
in
self
.
blocks
:
...
@@ -107,6 +138,10 @@ class Repeat(Mutable):
...
@@ -107,6 +138,10 @@ class Repeat(Mutable):
blocks
=
[
b
(
i
)
for
i
,
b
in
enumerate
(
blocks
)]
blocks
=
[
b
(
i
)
for
i
,
b
in
enumerate
(
blocks
)]
return
blocks
return
blocks
def
__getitem__
(
self
,
index
):
# shortcut for blocks[index]
return
self
.
blocks
[
index
]
class
NasBench201Cell
(
nn
.
Module
):
class
NasBench201Cell
(
nn
.
Module
):
"""
"""
...
...
nni/retiarii/nn/pytorch/mutator.py
View file @
ba771871
...
@@ -14,7 +14,7 @@ from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
...
@@ -14,7 +14,7 @@ from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
from
nni.retiarii.utils
import
uid
from
nni.retiarii.utils
import
uid
from
.api
import
LayerChoice
,
InputChoice
,
ValueChoice
,
ValueChoiceX
,
Placeholder
from
.api
import
LayerChoice
,
InputChoice
,
ValueChoice
,
ValueChoiceX
,
Placeholder
from
.component
import
Repeat
,
NasBench101Cell
,
NasBench101Mutator
from
.component
import
NasBench101Cell
,
NasBench101Mutator
class
LayerChoiceMutator
(
Mutator
):
class
LayerChoiceMutator
(
Mutator
):
...
@@ -144,14 +144,15 @@ class RepeatMutator(Mutator):
...
@@ -144,14 +144,15 @@ class RepeatMutator(Mutator):
return
chain
return
chain
def
mutate
(
self
,
model
):
def
mutate
(
self
,
model
):
min_depth
=
self
.
nodes
[
0
].
operation
.
parameters
[
'min_depth'
]
max_depth
=
self
.
nodes
[
0
].
operation
.
parameters
[
'max_depth'
]
if
min_depth
<
max_depth
:
chosen_depth
=
self
.
choice
(
list
(
range
(
min_depth
,
max_depth
+
1
)))
for
node
in
self
.
nodes
:
for
node
in
self
.
nodes
:
# the logic here is similar to layer choice. We find cell attached to each node.
# the logic here is similar to layer choice. We find cell attached to each node.
target
:
Graph
=
model
.
graphs
[
node
.
operation
.
cell_name
]
target
:
Graph
=
model
.
graphs
[
node
.
operation
.
cell_name
]
chain
=
self
.
_retrieve_chain_from_graph
(
target
)
chain
=
self
.
_retrieve_chain_from_graph
(
target
)
# and we get the chosen depth (by value choice)
node_in_model
=
model
.
get_node_by_name
(
node
.
name
)
# depth is a value choice in base model
# but it's already mutated by a ParameterChoiceMutator here
chosen_depth
=
node_in_model
.
operation
.
parameters
[
'depth'
]
for
edge
in
chain
[
chosen_depth
-
1
].
outgoing_edges
:
for
edge
in
chain
[
chosen_depth
-
1
].
outgoing_edges
:
edge
.
remove
()
edge
.
remove
()
target
.
add_edge
((
chain
[
chosen_depth
-
1
],
None
),
(
target
.
output_node
,
None
))
target
.
add_edge
((
chain
[
chosen_depth
-
1
],
None
),
(
target
.
output_node
,
None
))
...
@@ -184,6 +185,8 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
...
@@ -184,6 +185,8 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
# `pc_nodes` are arguments of basic units. They can be compositions.
# `pc_nodes` are arguments of basic units. They can be compositions.
pc_nodes
:
List
[
Tuple
[
Node
,
str
,
ValueChoiceX
]]
=
[]
pc_nodes
:
List
[
Tuple
[
Node
,
str
,
ValueChoiceX
]]
=
[]
for
node
in
model
.
get_nodes
():
for
node
in
model
.
get_nodes
():
# arguments used in operators like Conv2d
# argument `valuechoice` used in generated repeat cell
for
name
,
choice
in
node
.
operation
.
parameters
.
items
():
for
name
,
choice
in
node
.
operation
.
parameters
.
items
():
if
isinstance
(
choice
,
ValueChoiceX
):
if
isinstance
(
choice
,
ValueChoiceX
):
# e.g., (conv_node, "out_channels", ValueChoice([1, 3]))
# e.g., (conv_node, "out_channels", ValueChoice([1, 3]))
...
@@ -219,9 +222,10 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
...
@@ -219,9 +222,10 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
repeat_nodes
=
_group_by_label
(
filter
(
lambda
d
:
d
.
operation
.
parameters
.
get
(
'mutation'
)
==
'repeat'
,
repeat_nodes
=
_group_by_label
(
filter
(
lambda
d
:
d
.
operation
.
parameters
.
get
(
'mutation'
)
==
'repeat'
,
model
.
get_nodes_by_type
(
'_cell'
)))
model
.
get_nodes_by_type
(
'_cell'
)))
for
node_list
in
repeat_nodes
:
for
node_list
in
repeat_nodes
:
# this check is not completely reliable, because it only checks max and min
assert
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'max_depth'
],
node_list
))
and
\
assert
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'max_depth'
],
node_list
))
and
\
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'min_depth'
],
node_list
)),
\
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'min_depth'
],
node_list
)),
\
'Repeat with the same label must have the same
number of
candidates.'
'Repeat with the same label must have the same candidates.'
mutator
=
RepeatMutator
(
node_list
)
mutator
=
RepeatMutator
(
node_list
)
applied_mutators
.
append
(
mutator
)
applied_mutators
.
append
(
mutator
)
...
@@ -303,11 +307,6 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
...
@@ -303,11 +307,6 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
if
isinstance
(
module
,
ValueChoice
):
if
isinstance
(
module
,
ValueChoice
):
node
=
graph
.
add_node
(
name
,
'ValueChoice'
,
{
'candidates'
:
module
.
candidates
})
node
=
graph
.
add_node
(
name
,
'ValueChoice'
,
{
'candidates'
:
module
.
candidates
})
node
.
label
=
module
.
label
node
.
label
=
module
.
label
if
isinstance
(
module
,
Repeat
)
and
module
.
min_depth
<=
module
.
max_depth
:
node
=
graph
.
add_node
(
name
,
'Repeat'
,
{
'candidates'
:
list
(
range
(
module
.
min_depth
,
module
.
max_depth
+
1
))
})
node
.
label
=
module
.
label
if
isinstance
(
module
,
NasBench101Cell
):
if
isinstance
(
module
,
NasBench101Cell
):
node
=
graph
.
add_node
(
name
,
'NasBench101Cell'
,
{
node
=
graph
.
add_node
(
name
,
'NasBench101Cell'
,
{
'max_num_edges'
:
module
.
max_num_edges
'max_num_edges'
:
module
.
max_num_edges
...
...
test/ut/retiarii/test_highlevel_apis.py
View file @
ba771871
...
@@ -66,6 +66,8 @@ def _apply_all_mutators(model, mutators, samplers):
...
@@ -66,6 +66,8 @@ def _apply_all_mutators(model, mutators, samplers):
class
GraphIR
(
unittest
.
TestCase
):
class
GraphIR
(
unittest
.
TestCase
):
# graph engine will have an extra mutator for parameter choices
# graph engine will have an extra mutator for parameter choices
value_choice_incr
=
1
value_choice_incr
=
1
# graph engine has an extra mutator to apply the depth choice to nodes
repeat_incr
=
1
def
_convert_to_ir
(
self
,
model
):
def
_convert_to_ir
(
self
,
model
):
script_module
=
torch
.
jit
.
script
(
model
)
script_module
=
torch
.
jit
.
script
(
model
)
...
@@ -578,14 +580,39 @@ class GraphIR(unittest.TestCase):
...
@@ -578,14 +580,39 @@ class GraphIR(unittest.TestCase):
return
self
.
block
(
x
)
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
1
)
self
.
assertEqual
(
len
(
mutators
),
1
+
self
.
repeat_incr
+
self
.
value_choice_incr
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
samplers
=
[
EnumerateSampler
()
for
_
in
range
(
len
(
mutators
))]
model1
=
mutator
.
apply
(
model
)
for
target
in
[
3
,
4
,
5
]:
model2
=
mutator
.
apply
(
model
)
new_model
=
_apply_all_mutators
(
model
,
mutators
,
samplers
)
model3
=
mutator
.
apply
(
model
)
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
new_model
)(
torch
.
zeros
(
1
,
16
))
==
target
).
all
())
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
zeros
(
1
,
16
))
==
3
).
all
())
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
zeros
(
1
,
16
))
==
4
).
all
())
def
test_repeat_static
(
self
):
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
model3
)(
torch
.
zeros
(
1
,
16
))
==
5
).
all
())
class
AddOne
(
nn
.
Module
):
def
forward
(
self
,
x
):
return
x
+
1
@
model_wrapper
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
block
=
nn
.
Repeat
(
lambda
index
:
nn
.
LayerChoice
([
AddOne
(),
nn
.
Identity
()]),
4
)
def
forward
(
self
,
x
):
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
4
)
sampler
=
RandomSampler
()
result
=
[]
for
_
in
range
(
50
):
new_model
=
model
for
mutator
in
mutators
:
new_model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
new_model
)
result
.
append
(
self
.
_get_converted_pytorch_model
(
new_model
)(
torch
.
zeros
(
1
,
1
)).
item
())
for
x
in
[
1
,
2
,
3
]:
self
.
assertIn
(
float
(
x
),
result
)
def
test_repeat_complex
(
self
):
def
test_repeat_complex
(
self
):
class
AddOne
(
nn
.
Module
):
class
AddOne
(
nn
.
Module
):
...
@@ -602,8 +629,8 @@ class GraphIR(unittest.TestCase):
...
@@ -602,8 +629,8 @@ class GraphIR(unittest.TestCase):
return
self
.
block
(
x
)
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
2
)
self
.
assertEqual
(
len
(
mutators
),
2
+
self
.
repeat_incr
+
self
.
value_choice_incr
)
self
.
assertEqual
(
set
([
mutator
.
label
for
mutator
in
mutators
]),
{
'lc'
,
'rep'
})
self
.
assertEqual
(
set
([
mutator
.
label
for
mutator
in
mutators
if
mutator
.
label
is
not
None
]),
{
'lc'
,
'rep'
})
sampler
=
RandomSampler
()
sampler
=
RandomSampler
()
for
_
in
range
(
10
):
for
_
in
range
(
10
):
...
@@ -624,7 +651,7 @@ class GraphIR(unittest.TestCase):
...
@@ -624,7 +651,7 @@ class GraphIR(unittest.TestCase):
return
self
.
block
(
x
)
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
4
)
self
.
assertEqual
(
len
(
mutators
),
4
+
self
.
repeat_incr
+
self
.
value_choice_incr
)
result
=
[]
result
=
[]
for
_
in
range
(
20
):
for
_
in
range
(
20
):
...
@@ -635,6 +662,27 @@ class GraphIR(unittest.TestCase):
...
@@ -635,6 +662,27 @@ class GraphIR(unittest.TestCase):
self
.
assertIn
(
1.
,
result
)
self
.
assertIn
(
1.
,
result
)
def
test_repeat_valuechoice
(
self
):
class
AddOne
(
nn
.
Module
):
def
forward
(
self
,
x
):
return
x
+
1
@
model_wrapper
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
block
=
nn
.
Repeat
(
AddOne
(),
nn
.
ValueChoice
([
1
,
3
,
5
]))
def
forward
(
self
,
x
):
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
1
+
self
.
repeat_incr
+
self
.
value_choice_incr
)
samplers
=
[
EnumerateSampler
()
for
_
in
range
(
len
(
mutators
))]
for
target
in
[
1
,
3
,
5
]:
new_model
=
_apply_all_mutators
(
model
,
mutators
,
samplers
)
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
new_model
)(
torch
.
zeros
(
1
,
16
))
==
target
).
all
())
def
test_repeat_weight_inheritance
(
self
):
def
test_repeat_weight_inheritance
(
self
):
@
model_wrapper
@
model_wrapper
class
Net
(
nn
.
Module
):
class
Net
(
nn
.
Module
):
...
@@ -647,11 +695,11 @@ class GraphIR(unittest.TestCase):
...
@@ -647,11 +695,11 @@ class GraphIR(unittest.TestCase):
orig_model
=
Net
()
orig_model
=
Net
()
model
,
mutators
=
self
.
_get_model_with_mutators
(
orig_model
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
orig_model
)
mutator
=
mutators
[
0
].
bind_
sampler
(
EnumerateSampler
()
)
sampler
s
=
[
EnumerateSampler
()
for
_
in
range
(
len
(
mutators
))]
inp
=
torch
.
randn
(
1
,
3
,
5
,
5
)
inp
=
torch
.
randn
(
1
,
3
,
5
,
5
)
for
i
in
range
(
4
):
for
i
in
range
(
4
):
model_new
=
self
.
_get_converted_pytorch_model
(
mutator
.
apply
(
model
))
model_new
=
self
.
_get_converted_pytorch_model
(
_apply_all_mutators
(
model
,
mutators
,
samplers
))
with
original_state_dict_hooks
(
model_new
):
with
original_state_dict_hooks
(
model_new
):
model_new
.
load_state_dict
(
orig_model
.
state_dict
(),
strict
=
False
)
model_new
.
load_state_dict
(
orig_model
.
state_dict
(),
strict
=
False
)
...
@@ -778,6 +826,7 @@ class GraphIR(unittest.TestCase):
...
@@ -778,6 +826,7 @@ class GraphIR(unittest.TestCase):
class
Python
(
GraphIR
):
class
Python
(
GraphIR
):
# Python engine doesn't have the extra mutator
# Python engine doesn't have the extra mutator
value_choice_incr
=
0
value_choice_incr
=
0
repeat_incr
=
0
def
_get_converted_pytorch_model
(
self
,
model_ir
):
def
_get_converted_pytorch_model
(
self
,
model_ir
):
mutation
=
{
mut
.
mutator
.
label
:
_unpack_if_only_one
(
mut
.
samples
)
for
mut
in
model_ir
.
history
}
mutation
=
{
mut
.
mutator
.
label
:
_unpack_if_only_one
(
mut
.
samples
)
for
mut
in
model_ir
.
history
}
...
@@ -891,6 +940,8 @@ class Shared(unittest.TestCase):
...
@@ -891,6 +940,8 @@ class Shared(unittest.TestCase):
elif
i
==
2
:
elif
i
==
2
:
assert
choice
.
candidates
==
[
5
,
6
]
assert
choice
.
candidates
==
[
5
,
6
]
assert
d
.
evaluate
([
2
,
3
,
5
])
==
20
assert
d
.
evaluate
([
2
,
3
,
5
])
==
20
expect
=
[
x
+
y
+
3
*
z
for
x
in
[
1
,
2
]
for
y
in
[
3
,
4
]
for
z
in
[
5
,
6
]]
assert
list
(
d
.
all_options
())
==
expect
a
=
nn
.
ValueChoice
([
'cat'
,
'dog'
])
a
=
nn
.
ValueChoice
([
'cat'
,
'dog'
])
b
=
nn
.
ValueChoice
([
'milk'
,
'coffee'
])
b
=
nn
.
ValueChoice
([
'milk'
,
'coffee'
])
...
@@ -967,6 +1018,9 @@ class Shared(unittest.TestCase):
...
@@ -967,6 +1018,9 @@ class Shared(unittest.TestCase):
lst
=
[
value
if
choice
.
label
==
'value'
else
divisor
for
choice
in
result
.
inner_choices
()]
lst
=
[
value
if
choice
.
label
==
'value'
else
divisor
for
choice
in
result
.
inner_choices
()]
assert
result
.
evaluate
(
lst
)
==
original_make_divisible
(
value
,
divisor
)
assert
result
.
evaluate
(
lst
)
==
original_make_divisible
(
value
,
divisor
)
assert
len
(
list
(
result
.
all_options
()))
==
30
assert
max
(
result
.
all_options
())
==
135
def
test_valuechoice_in_evaluator
(
self
):
def
test_valuechoice_in_evaluator
(
self
):
def
foo
():
def
foo
():
pass
pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment