Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
6808708d
Unverified
Commit
6808708d
authored
Apr 09, 2021
by
Yuge Zhang
Committed by
GitHub
Apr 09, 2021
Browse files
[Retiarii] Nest `ValueChoice` in `LayerChoice` and dict/list in `ValueChoice` (#3508)
parent
b7062b5d
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
196 additions
and
45 deletions
+196
-45
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+11
-14
nni/retiarii/graph.py
nni/retiarii/graph.py
+7
-5
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+29
-2
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+31
-16
nni/retiarii/operation_def/torch_op_def.py
nni/retiarii/operation_def/torch_op_def.py
+3
-0
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+109
-2
test/ut/retiarii/test_strategy.py
test/ut/retiarii/test_strategy.py
+6
-6
No files found.
nni/retiarii/converter/graph_gen.py
View file @
6808708d
...
@@ -6,7 +6,7 @@ import re
...
@@ -6,7 +6,7 @@ import re
import
torch
import
torch
from
..graph
import
Graph
,
Model
,
Node
from
..graph
import
Graph
,
Model
,
Node
from
..nn.pytorch
import
InputChoice
,
LayerChoice
,
Placeholder
from
..nn.pytorch
import
InputChoice
,
Placeholder
from
..operation
import
Cell
,
Operation
from
..operation
import
Cell
,
Operation
from
..serializer
import
get_init_parameters_or_fail
from
..serializer
import
get_init_parameters_or_fail
from
..utils
import
get_importable_name
from
..utils
import
get_importable_name
...
@@ -343,7 +343,7 @@ class GraphConverter:
...
@@ -343,7 +343,7 @@ class GraphConverter:
subcell
=
ir_graph
.
add_node
(
submodule_full_name
,
submodule_type_str
,
sub_m_attrs
)
subcell
=
ir_graph
.
add_node
(
submodule_full_name
,
submodule_type_str
,
sub_m_attrs
)
if
isinstance
(
submodule_obj
,
Placeholder
):
if
isinstance
(
submodule_obj
,
Placeholder
):
subcell
.
update_label
(
submodule_obj
.
label
)
subcell
.
update_label
(
submodule_obj
.
label
)
elif
isinstance
(
submodule_obj
,
(
LayerChoice
,
InputChoice
)
)
:
elif
isinstance
(
submodule_obj
,
InputChoice
):
subcell
.
update_label
(
sub_m_attrs
[
'label'
])
subcell
.
update_label
(
sub_m_attrs
[
'label'
])
else
:
else
:
# Graph already created, create Cell for it
# Graph already created, create Cell for it
...
@@ -536,16 +536,6 @@ class GraphConverter:
...
@@ -536,16 +536,6 @@ class GraphConverter:
self
.
remove_unconnected_nodes
(
ir_graph
,
targeted_type
=
'prim::GetAttr'
)
self
.
remove_unconnected_nodes
(
ir_graph
,
targeted_type
=
'prim::GetAttr'
)
self
.
merge_aten_slices
(
ir_graph
)
self
.
merge_aten_slices
(
ir_graph
)
def
_handle_layerchoice
(
self
,
module
):
choices
=
[]
for
cand
in
list
(
module
):
cand_type
=
'__torch__.'
+
get_importable_name
(
cand
.
__class__
)
choices
.
append
({
'type'
:
cand_type
,
'parameters'
:
get_init_parameters_or_fail
(
cand
)})
return
{
'candidates'
:
choices
,
'label'
:
module
.
label
}
def
_handle_inputchoice
(
self
,
module
):
def
_handle_inputchoice
(
self
,
module
):
return
{
return
{
'n_candidates'
:
module
.
n_candidates
,
'n_candidates'
:
module
.
n_candidates
,
...
@@ -557,7 +547,8 @@ class GraphConverter:
...
@@ -557,7 +547,8 @@ class GraphConverter:
def
_handle_valuechoice
(
self
,
module
):
def
_handle_valuechoice
(
self
,
module
):
return
{
return
{
'candidates'
:
module
.
candidates
,
'candidates'
:
module
.
candidates
,
'label'
:
module
.
label
'label'
:
module
.
label
,
'accessor'
:
module
.
_accessor
}
}
def
convert_module
(
self
,
script_module
,
module
,
module_name
,
ir_model
):
def
convert_module
(
self
,
script_module
,
module
,
module_name
,
ir_model
):
...
@@ -590,7 +581,13 @@ class GraphConverter:
...
@@ -590,7 +581,13 @@ class GraphConverter:
if
original_type_name
in
MODULE_EXCEPT_LIST
:
if
original_type_name
in
MODULE_EXCEPT_LIST
:
pass
# do nothing
pass
# do nothing
elif
original_type_name
==
OpTypeName
.
LayerChoice
:
elif
original_type_name
==
OpTypeName
.
LayerChoice
:
m_attrs
=
self
.
_handle_layerchoice
(
module
)
graph
=
Graph
(
ir_model
,
-
100
,
module_name
,
_internal
=
True
)
# graph_id is not used now
candidate_name_list
=
[
f
'layerchoice_
{
module
.
label
}
_
{
cand_name
}
'
for
cand_name
in
module
.
names
]
for
cand_name
,
cand
in
zip
(
candidate_name_list
,
module
):
cand_type
=
'__torch__.'
+
get_importable_name
(
cand
.
__class__
)
graph
.
add_node
(
cand_name
,
cand_type
,
get_init_parameters_or_fail
(
cand
))
graph
.
_register
()
return
graph
,
{
'mutation'
:
'layerchoice'
,
'label'
:
module
.
label
,
'candidates'
:
candidate_name_list
}
elif
original_type_name
==
OpTypeName
.
InputChoice
:
elif
original_type_name
==
OpTypeName
.
InputChoice
:
m_attrs
=
self
.
_handle_inputchoice
(
module
)
m_attrs
=
self
.
_handle_inputchoice
(
module
)
elif
original_type_name
==
OpTypeName
.
ValueChoice
:
elif
original_type_name
==
OpTypeName
.
ValueChoice
:
...
...
nni/retiarii/graph.py
View file @
6808708d
...
@@ -144,15 +144,17 @@ class Model:
...
@@ -144,15 +144,17 @@ class Model:
for
graph_name
,
graph_data
in
ir
.
items
():
for
graph_name
,
graph_data
in
ir
.
items
():
if
graph_name
!=
'_evaluator'
:
if
graph_name
!=
'_evaluator'
:
Graph
.
_load
(
model
,
graph_name
,
graph_data
).
_register
()
Graph
.
_load
(
model
,
graph_name
,
graph_data
).
_register
()
model
.
evaluator
=
Evaluator
.
_load_with_type
(
ir
[
'_evaluator'
][
'__type__'
],
ir
[
'_evaluator'
])
if
'_evaluator'
in
ir
:
model
.
evaluator
=
Evaluator
.
_load_with_type
(
ir
[
'_evaluator'
][
'__type__'
],
ir
[
'_evaluator'
])
return
model
return
model
def
_dump
(
self
)
->
Any
:
def
_dump
(
self
)
->
Any
:
ret
=
{
name
:
graph
.
_dump
()
for
name
,
graph
in
self
.
graphs
.
items
()}
ret
=
{
name
:
graph
.
_dump
()
for
name
,
graph
in
self
.
graphs
.
items
()}
ret
[
'_evaluator'
]
=
{
if
self
.
evaluator
is
not
None
:
'__type__'
:
get_importable_name
(
self
.
evaluator
.
__class__
),
ret
[
'_evaluator'
]
=
{
**
self
.
evaluator
.
_dump
()
'__type__'
:
get_importable_name
(
self
.
evaluator
.
__class__
),
}
**
self
.
evaluator
.
_dump
()
}
return
ret
return
ret
def
get_nodes
(
self
)
->
Iterable
[
'Node'
]:
def
get_nodes
(
self
)
->
Iterable
[
'Node'
]:
...
...
nni/retiarii/nn/pytorch/api.py
View file @
6808708d
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
import
copy
import
warnings
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
typing
import
Any
,
List
,
Union
,
Dict
from
typing
import
Any
,
List
,
Union
,
Dict
import
warnings
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -268,6 +269,7 @@ class ValueChoice(Translatable, nn.Module):
...
@@ -268,6 +269,7 @@ class ValueChoice(Translatable, nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
candidates
=
candidates
self
.
candidates
=
candidates
self
.
_label
=
label
if
label
is
not
None
else
f
'valuechoice_
{
uid
()
}
'
self
.
_label
=
label
if
label
is
not
None
else
f
'valuechoice_
{
uid
()
}
'
self
.
_accessor
=
[]
@
property
@
property
def
label
(
self
):
def
label
(
self
):
...
@@ -279,11 +281,36 @@ class ValueChoice(Translatable, nn.Module):
...
@@ -279,11 +281,36 @@ class ValueChoice(Translatable, nn.Module):
def
_translate
(
self
):
def
_translate
(
self
):
# Will function as a value when used in serializer.
# Will function as a value when used in serializer.
return
self
.
candidates
[
0
]
return
self
.
access
(
self
.
candidates
[
0
]
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
f
'ValueChoice(
{
self
.
candidates
}
, label=
{
repr
(
self
.
label
)
}
)'
return
f
'ValueChoice(
{
self
.
candidates
}
, label=
{
repr
(
self
.
label
)
}
)'
def
access
(
self
,
value
):
if
not
self
.
_accessor
:
return
value
try
:
v
=
value
for
a
in
self
.
_accessor
:
v
=
v
[
a
]
except
KeyError
:
raise
KeyError
(
''
.
join
([
f
'[
{
a
}
]'
for
a
in
self
.
_accessor
])
+
f
' does not work on
{
value
}
'
)
return
v
def
__getitem__
(
self
,
item
):
"""
Get a sub-element of value choice.
The underlying implementation is to clone the current instance, and append item to "accessor", which records all
the history getitem calls. For example, when accessor is ``[a, b, c]``, the value choice will return ``vc[a][b][c]``
where ``vc`` is the original value choice.
"""
access
=
copy
.
deepcopy
(
self
)
access
.
_accessor
.
append
(
item
)
for
candidate
in
self
.
candidates
:
access
.
access
(
candidate
)
return
access
@
basic_unit
@
basic_unit
class
Placeholder
(
nn
.
Module
):
class
Placeholder
(
nn
.
Module
):
...
...
nni/retiarii/nn/pytorch/mutator.py
View file @
6808708d
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
from
typing
import
Any
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
List
,
Optional
,
Tuple
from
...mutator
import
Mutator
from
...mutator
import
Mutator
from
...graph
import
Model
,
Node
from
...graph
import
Cell
,
Model
,
Node
from
.api
import
ValueChoice
from
.api
import
ValueChoice
...
@@ -14,13 +14,23 @@ class LayerChoiceMutator(Mutator):
...
@@ -14,13 +14,23 @@ class LayerChoiceMutator(Mutator):
self
.
nodes
=
nodes
self
.
nodes
=
nodes
def
mutate
(
self
,
model
):
def
mutate
(
self
,
model
):
n_candidates
=
len
(
self
.
nodes
[
0
].
operation
.
parameters
[
'candidates'
])
candidates
=
self
.
nodes
[
0
].
operation
.
parameters
[
'candidates'
]
indices
=
list
(
range
(
n_candidates
))
chosen
=
self
.
choice
(
candidates
)
chosen_index
=
self
.
choice
(
indices
)
for
node
in
self
.
nodes
:
for
node
in
self
.
nodes
:
target
=
model
.
get_node_by_name
(
node
.
name
)
# Each layer choice corresponds to a cell, which is unconnected in the base graph.
chosen_cand
=
node
.
operation
.
parameters
[
'candidates'
][
chosen_index
]
# We add the connections here in the mutation logic.
target
.
update_operation
(
chosen_cand
[
'type'
],
chosen_cand
[
'parameters'
])
# Thus, the mutated model should not be mutated again. Everything should be based on the original base graph.
target
=
model
.
graphs
[
node
.
operation
.
cell_name
]
chosen_node
=
target
.
get_node_by_name
(
chosen
)
assert
chosen_node
is
not
None
target
.
add_edge
((
target
.
input_node
,
0
),
(
chosen_node
,
None
))
target
.
add_edge
((
chosen_node
,
None
),
(
target
.
output_node
,
None
))
model
.
get_node_by_name
(
node
.
name
).
update_operation
(
Cell
(
node
.
operation
.
cell_name
))
# remove redundant nodes
for
rm_node
in
target
.
hidden_nodes
:
if
rm_node
.
name
!=
chosen_node
.
name
:
rm_node
.
remove
()
class
InputChoiceMutator
(
Mutator
):
class
InputChoiceMutator
(
Mutator
):
...
@@ -61,20 +71,14 @@ class ParameterChoiceMutator(Mutator):
...
@@ -61,20 +71,14 @@ class ParameterChoiceMutator(Mutator):
def
mutate
(
self
,
model
):
def
mutate
(
self
,
model
):
chosen
=
self
.
choice
(
self
.
candidates
)
chosen
=
self
.
choice
(
self
.
candidates
)
for
node
,
argname
in
self
.
nodes
:
for
node
,
argname
in
self
.
nodes
:
chosen_value
=
node
.
operation
.
parameters
[
argname
].
access
(
chosen
)
target
=
model
.
get_node_by_name
(
node
.
name
)
target
=
model
.
get_node_by_name
(
node
.
name
)
target
.
update_operation
(
target
.
operation
.
type
,
{
**
target
.
operation
.
parameters
,
argname
:
chosen
})
target
.
update_operation
(
target
.
operation
.
type
,
{
**
target
.
operation
.
parameters
,
argname
:
chosen
_value
})
def
process_inline_mutation
(
model
:
Model
)
->
Optional
[
List
[
Mutator
]]:
def
process_inline_mutation
(
model
:
Model
)
->
Optional
[
List
[
Mutator
]]:
applied_mutators
=
[]
applied_mutators
=
[]
lc_nodes
=
_group_by_label
(
model
.
get_nodes_by_type
(
'__torch__.nni.retiarii.nn.pytorch.api.LayerChoice'
))
for
node_list
in
lc_nodes
:
assert
_is_all_equal
(
map
(
lambda
node
:
len
(
node
.
operation
.
parameters
[
'candidates'
]),
node_list
)),
\
'Layer choice with the same label must have the same number of candidates.'
mutator
=
LayerChoiceMutator
(
node_list
)
applied_mutators
.
append
(
mutator
)
ic_nodes
=
_group_by_label
(
model
.
get_nodes_by_type
(
'__torch__.nni.retiarii.nn.pytorch.api.InputChoice'
))
ic_nodes
=
_group_by_label
(
model
.
get_nodes_by_type
(
'__torch__.nni.retiarii.nn.pytorch.api.InputChoice'
))
for
node_list
in
ic_nodes
:
for
node_list
in
ic_nodes
:
assert
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'n_candidates'
],
node_list
))
and
\
assert
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'n_candidates'
],
node_list
))
and
\
...
@@ -99,9 +103,20 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
...
@@ -99,9 +103,20 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
for
node_list
in
pc_nodes
:
for
node_list
in
pc_nodes
:
assert
_is_all_equal
([
node
.
operation
.
parameters
[
name
].
candidates
for
node
,
name
in
node_list
]),
\
assert
_is_all_equal
([
node
.
operation
.
parameters
[
name
].
candidates
for
node
,
name
in
node_list
]),
\
'Value choice with the same label must have the same candidates.'
'Value choice with the same label must have the same candidates.'
mutator
=
ParameterChoiceMutator
(
node_list
,
node_list
[
0
][
0
].
operation
.
parameters
[
node_list
[
0
][
1
]].
candidates
)
first_node
,
first_argname
=
node_list
[
0
]
mutator
=
ParameterChoiceMutator
(
node_list
,
first_node
.
operation
.
parameters
[
first_argname
].
candidates
)
applied_mutators
.
append
(
mutator
)
applied_mutators
.
append
(
mutator
)
# apply layer choice at last as it will delete some nodes
lc_nodes
=
_group_by_label
(
filter
(
lambda
d
:
d
.
operation
.
parameters
.
get
(
'mutation'
)
==
'layerchoice'
,
model
.
get_nodes_by_type
(
'_cell'
)))
for
node_list
in
lc_nodes
:
assert
_is_all_equal
(
map
(
lambda
node
:
len
(
node
.
operation
.
parameters
[
'candidates'
]),
node_list
)),
\
'Layer choice with the same label must have the same number of candidates.'
mutator
=
LayerChoiceMutator
(
node_list
)
applied_mutators
.
append
(
mutator
)
if
applied_mutators
:
if
applied_mutators
:
return
applied_mutators
return
applied_mutators
return
None
return
None
...
...
nni/retiarii/operation_def/torch_op_def.py
View file @
6808708d
...
@@ -69,6 +69,9 @@ class PrimConstant(PyTorchOperation):
...
@@ -69,6 +69,9 @@ class PrimConstant(PyTorchOperation):
elif
self
.
parameters
[
'type'
]
==
'Device'
:
elif
self
.
parameters
[
'type'
]
==
'Device'
:
value
=
self
.
parameters
[
'value'
]
value
=
self
.
parameters
[
'value'
]
return
f
'
{
output
}
= torch.device("
{
value
}
")'
return
f
'
{
output
}
= torch.device("
{
value
}
")'
elif
self
.
parameters
[
'type'
]
in
(
'dict'
,
'list'
,
'tuple'
):
# TODO: prim::TupleIndex is not supported yet
return
f
'
{
output
}
=
{
repr
(
self
.
parameters
[
"value"
])
}
'
else
:
else
:
raise
RuntimeError
(
f
'unsupported type of prim::Constant:
{
self
.
parameters
[
"type"
]
}
'
)
raise
RuntimeError
(
f
'unsupported type of prim::Constant:
{
self
.
parameters
[
"type"
]
}
'
)
...
...
test/ut/retiarii/test_highlevel_apis.py
View file @
6808708d
import
random
import
random
import
unittest
import
unittest
from
collections
import
Counter
import
nni.retiarii.nn.pytorch
as
nn
import
nni.retiarii.nn.pytorch
as
nn
import
torch
import
torch
...
@@ -252,6 +253,30 @@ class TestHighLevelAPI(unittest.TestCase):
...
@@ -252,6 +253,30 @@ class TestHighLevelAPI(unittest.TestCase):
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertAlmostEqual
(
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
abs
().
sum
().
item
(),
0
)
self
.
assertAlmostEqual
(
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
abs
().
sum
().
item
(),
0
)
def
test_value_choice_in_layer_choice
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
linear
=
nn
.
LayerChoice
([
nn
.
Linear
(
3
,
nn
.
ValueChoice
([
10
,
20
])),
nn
.
Linear
(
3
,
nn
.
ValueChoice
([
30
,
40
]))
])
def
forward
(
self
,
x
):
return
self
.
linear
(
x
)
model
=
self
.
_convert_to_ir
(
Net
())
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
3
)
sz_counter
=
Counter
()
sampler
=
RandomSampler
()
for
i
in
range
(
100
):
model_new
=
model
for
mutator
in
mutators
:
model_new
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model_new
)
sz_counter
[
self
.
_get_converted_pytorch_model
(
model_new
)(
torch
.
randn
(
1
,
3
)).
size
(
1
)]
+=
1
self
.
assertEqual
(
len
(
sz_counter
),
4
)
def
test_shared
(
self
):
def
test_shared
(
self
):
class
Net
(
nn
.
Module
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
,
shared
=
True
):
def
__init__
(
self
,
shared
=
True
):
...
@@ -284,12 +309,94 @@ class TestHighLevelAPI(unittest.TestCase):
...
@@ -284,12 +309,94 @@ class TestHighLevelAPI(unittest.TestCase):
# repeat test. Expectation: sometimes succeeds, sometimes fails.
# repeat test. Expectation: sometimes succeeds, sometimes fails.
failed_count
=
0
failed_count
=
0
for
i
in
range
(
30
):
for
i
in
range
(
30
):
model_new
=
model
for
mutator
in
mutators
:
for
mutator
in
mutators
:
model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model
)
model
_new
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model
_new
)
self
.
assertEqual
(
sampler
.
counter
,
2
*
(
i
+
1
))
self
.
assertEqual
(
sampler
.
counter
,
2
*
(
i
+
1
))
try
:
try
:
self
.
_get_converted_pytorch_model
(
model
)(
torch
.
randn
(
1
,
3
,
3
,
3
))
self
.
_get_converted_pytorch_model
(
model
_new
)(
torch
.
randn
(
1
,
3
,
3
,
3
))
except
RuntimeError
:
except
RuntimeError
:
failed_count
+=
1
failed_count
+=
1
self
.
assertGreater
(
failed_count
,
0
)
self
.
assertGreater
(
failed_count
,
0
)
self
.
assertLess
(
failed_count
,
30
)
self
.
assertLess
(
failed_count
,
30
)
def
test_valuechoice_access
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
vc
=
nn
.
ValueChoice
([(
6
,
3
),
(
8
,
5
)])
self
.
conv
=
nn
.
Conv2d
(
3
,
vc
[
0
],
kernel_size
=
vc
[
1
])
def
forward
(
self
,
x
):
return
self
.
conv
(
x
)
model
=
self
.
_convert_to_ir
(
Net
())
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
1
)
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
input
=
torch
.
randn
(
1
,
3
,
5
,
5
)
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
mutators
[
0
].
apply
(
model
))(
input
).
size
(),
torch
.
Size
([
1
,
6
,
3
,
3
]))
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
mutators
[
0
].
apply
(
model
))(
input
).
size
(),
torch
.
Size
([
1
,
8
,
1
,
1
]))
class
Net2
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
choices
=
[
{
'b'
:
[
3
],
'bp'
:
[
6
]},
{
'b'
:
[
6
],
'bp'
:
[
12
]}
]
self
.
conv
=
nn
.
Conv2d
(
3
,
nn
.
ValueChoice
(
choices
,
label
=
'a'
)[
'b'
][
0
],
1
)
self
.
conv1
=
nn
.
Conv2d
(
nn
.
ValueChoice
(
choices
,
label
=
'a'
)[
'bp'
][
0
],
3
,
1
)
def
forward
(
self
,
x
):
x
=
self
.
conv
(
x
)
return
self
.
conv1
(
torch
.
cat
((
x
,
x
),
1
))
model
=
self
.
_convert_to_ir
(
Net2
())
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
1
)
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
input
=
torch
.
randn
(
1
,
3
,
5
,
5
)
self
.
_get_converted_pytorch_model
(
mutators
[
0
].
apply
(
model
))(
input
)
def
test_valuechoice_access_functional
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
dropout_rate
=
nn
.
ValueChoice
([[
0.
,],
[
1.
,]])
def
forward
(
self
,
x
):
return
F
.
dropout
(
x
,
self
.
dropout_rate
()[
0
])
model
=
self
.
_convert_to_ir
(
Net
())
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
1
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
model1
=
mutator
.
apply
(
model
)
model2
=
mutator
.
apply
(
model
)
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
))
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertAlmostEqual
(
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
abs
().
sum
().
item
(),
0
)
def
test_valuechoice_access_functional_expression
(
self
):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
dropout_rate
=
nn
.
ValueChoice
([[
1.05
,],
[
1.1
,]])
def
forward
(
self
,
x
):
# if expression failed, the exception would be:
# ValueError: dropout probability has to be between 0 and 1, but got 1.05
return
F
.
dropout
(
x
,
self
.
dropout_rate
()[
0
]
-
.
1
)
model
=
self
.
_convert_to_ir
(
Net
())
mutators
=
process_inline_mutation
(
model
)
self
.
assertEqual
(
len
(
mutators
),
1
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
model1
=
mutator
.
apply
(
model
)
model2
=
mutator
.
apply
(
model
)
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
))
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertAlmostEqual
(
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
abs
().
sum
().
item
(),
0
)
test/ut/retiarii/test_strategy.py
View file @
6808708d
...
@@ -62,11 +62,11 @@ class Net(nn.Module):
...
@@ -62,11 +62,11 @@ class Net(nn.Module):
self
.
fc1
=
nn
.
LayerChoice
([
self
.
fc1
=
nn
.
LayerChoice
([
nn
.
Linear
(
4
*
4
*
50
,
hidden_size
,
bias
=
True
),
nn
.
Linear
(
4
*
4
*
50
,
hidden_size
,
bias
=
True
),
nn
.
Linear
(
4
*
4
*
50
,
hidden_size
,
bias
=
False
)
nn
.
Linear
(
4
*
4
*
50
,
hidden_size
,
bias
=
False
)
])
]
,
label
=
'fc1'
)
self
.
fc2
=
nn
.
LayerChoice
([
self
.
fc2
=
nn
.
LayerChoice
([
nn
.
Linear
(
hidden_size
,
10
,
bias
=
False
),
nn
.
Linear
(
hidden_size
,
10
,
bias
=
False
),
nn
.
Linear
(
hidden_size
,
10
,
bias
=
True
)
nn
.
Linear
(
hidden_size
,
10
,
bias
=
True
)
])
]
,
label
=
'fc2'
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
x
=
F
.
relu
(
self
.
conv1
(
x
))
x
=
F
.
relu
(
self
.
conv1
(
x
))
...
@@ -97,8 +97,8 @@ def test_grid_search():
...
@@ -97,8 +97,8 @@ def test_grid_search():
selection
=
set
()
selection
=
set
()
for
model
in
engine
.
models
:
for
model
in
engine
.
models
:
selection
.
add
((
selection
.
add
((
model
.
g
et_node_by_name
(
'_model__fc1'
)
.
operation
.
parameters
[
'bias'
],
model
.
g
raphs
[
'_model__fc1'
].
hidden_nodes
[
0
]
.
operation
.
parameters
[
'bias'
],
model
.
g
et_node_by_name
(
'_model__fc2'
)
.
operation
.
parameters
[
'bias'
]
model
.
g
raphs
[
'_model__fc2'
].
hidden_nodes
[
0
]
.
operation
.
parameters
[
'bias'
]
))
))
assert
len
(
selection
)
==
4
assert
len
(
selection
)
==
4
_reset_execution_engine
()
_reset_execution_engine
()
...
@@ -113,8 +113,8 @@ def test_random_search():
...
@@ -113,8 +113,8 @@ def test_random_search():
selection
=
set
()
selection
=
set
()
for
model
in
engine
.
models
:
for
model
in
engine
.
models
:
selection
.
add
((
selection
.
add
((
model
.
g
et_node_by_name
(
'_model__fc1'
)
.
operation
.
parameters
[
'bias'
],
model
.
g
raphs
[
'_model__fc1'
].
hidden_nodes
[
0
]
.
operation
.
parameters
[
'bias'
],
model
.
g
et_node_by_name
(
'_model__fc2'
)
.
operation
.
parameters
[
'bias'
]
model
.
g
raphs
[
'_model__fc2'
].
hidden_nodes
[
0
]
.
operation
.
parameters
[
'bias'
]
))
))
assert
len
(
selection
)
==
4
assert
len
(
selection
)
==
4
_reset_execution_engine
()
_reset_execution_engine
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment