Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
5df75c33
Unverified
Commit
5df75c33
authored
May 26, 2021
by
Yuge Zhang
Committed by
GitHub
May 26, 2021
Browse files
[Retiarii] New API: Repeat and Cell (#3481)
Co-authored-by:
quzha
<
Quanlu.Zhang@microsoft.com
>
parent
2f3c3951
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
311 additions
and
30 deletions
+311
-30
docs/en_US/NAS/retiarii/ApiReference.rst
docs/en_US/NAS/retiarii/ApiReference.rst
+6
-0
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+10
-0
nni/retiarii/converter/op_types.py
nni/retiarii/converter/op_types.py
+2
-0
nni/retiarii/nn/pytorch/__init__.py
nni/retiarii/nn/pytorch/__init__.py
+1
-0
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+13
-27
nni/retiarii/nn/pytorch/component.py
nni/retiarii/nn/pytorch/component.py
+147
-0
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+52
-1
nni/retiarii/nn/pytorch/utils.py
nni/retiarii/nn/pytorch/utils.py
+17
-0
test/ut/retiarii/test_highlevel_apis.py
test/ut/retiarii/test_highlevel_apis.py
+63
-2
No files found.
docs/en_US/NAS/retiarii/ApiReference.rst
View file @
5df75c33
...
...
@@ -18,6 +18,12 @@ Inline Mutation APIs
.. autoclass:: nni.retiarii.nn.pytorch.ChosenInputs
:members:
.. autoclass:: nni.retiarii.nn.pytorch.Repeat
:members:
.. autoclass:: nni.retiarii.nn.pytorch.Cell
:members:
Graph Mutation APIs
-------------------
...
...
nni/retiarii/converter/graph_gen.py
View file @
5df75c33
...
...
@@ -642,6 +642,16 @@ class GraphConverter:
ir_graph
.
_register
()
# add mutation signal for special modules
if
original_type_name
==
OpTypeName
.
Repeat
:
attrs
=
{
'mutation'
:
'repeat'
,
'label'
:
module
.
label
,
'min_depth'
:
module
.
min_depth
,
'max_depth'
:
module
.
max_depth
}
return
ir_graph
,
attrs
return
ir_graph
,
{}
...
...
nni/retiarii/converter/op_types.py
View file @
5df75c33
...
...
@@ -17,3 +17,5 @@ class OpTypeName(str, Enum):
ValueChoice
=
'ValueChoice'
Placeholder
=
'Placeholder'
MergedSlice
=
'MergedSlice'
Repeat
=
'Repeat'
Cell
=
'Cell'
nni/retiarii/nn/pytorch/__init__.py
View file @
5df75c33
from
.api
import
*
from
.component
import
*
from
.nn
import
*
nni/retiarii/nn/pytorch/api.py
View file @
5df75c33
...
...
@@ -10,26 +10,12 @@ import torch
import
torch.nn
as
nn
from
...serializer
import
Translatable
,
basic_unit
from
..
.utils
import
uid
,
get_current_context
from
.utils
import
generate_new_label
,
get_fixed_value
__all__
=
[
'LayerChoice'
,
'InputChoice'
,
'ValueChoice'
,
'Placeholder'
,
'ChosenInputs'
]
def
_generate_new_label
(
label
:
Optional
[
str
]):
if
label
is
None
:
return
'_mutation_'
+
str
(
uid
(
'mutation'
))
return
label
def
_get_fixed_value
(
label
:
str
):
ret
=
get_current_context
(
'fixed'
)
try
:
return
ret
[
_generate_new_label
(
label
)]
except
KeyError
:
raise
KeyError
(
f
'Fixed context with
{
label
}
not found. Existing values are:
{
ret
}
'
)
class
LayerChoice
(
nn
.
Module
):
"""
Layer choice selects one of the ``candidates``, then apply it on inputs and return results.
...
...
@@ -69,9 +55,9 @@ class LayerChoice(nn.Module):
``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
"""
def
__new__
(
cls
,
candidates
:
Union
[
Dict
[
str
,
nn
.
Module
],
List
[
nn
.
Module
]],
label
:
str
=
None
,
**
kwargs
):
def
__new__
(
cls
,
candidates
:
Union
[
Dict
[
str
,
nn
.
Module
],
List
[
nn
.
Module
]],
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
try
:
chosen
=
_
get_fixed_value
(
label
)
chosen
=
get_fixed_value
(
label
)
if
isinstance
(
candidates
,
list
):
return
candidates
[
int
(
chosen
)]
else
:
...
...
@@ -79,7 +65,7 @@ class LayerChoice(nn.Module):
except
AssertionError
:
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
candidates
:
Union
[
Dict
[
str
,
nn
.
Module
],
List
[
nn
.
Module
]],
label
:
str
=
None
,
**
kwargs
):
def
__init__
(
self
,
candidates
:
Union
[
Dict
[
str
,
nn
.
Module
],
List
[
nn
.
Module
]],
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
super
(
LayerChoice
,
self
).
__init__
()
if
'key'
in
kwargs
:
warnings
.
warn
(
f
'"key" is deprecated. Assuming label.'
)
...
...
@@ -89,7 +75,7 @@ class LayerChoice(nn.Module):
if
'reduction'
in
kwargs
:
warnings
.
warn
(
f
'"reduction" is deprecated. Ignoring...'
)
self
.
candidates
=
candidates
self
.
_label
=
_
generate_new_label
(
label
)
self
.
_label
=
generate_new_label
(
label
)
self
.
names
=
[]
if
isinstance
(
candidates
,
OrderedDict
):
...
...
@@ -187,13 +173,13 @@ class InputChoice(nn.Module):
Identifier of the input choice.
"""
def
__new__
(
cls
,
n_candidates
:
int
,
n_chosen
:
int
=
1
,
reduction
:
str
=
'sum'
,
label
:
str
=
None
,
**
kwargs
):
def
__new__
(
cls
,
n_candidates
:
int
,
n_chosen
:
int
=
1
,
reduction
:
str
=
'sum'
,
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
try
:
return
ChosenInputs
(
_
get_fixed_value
(
label
),
reduction
=
reduction
)
return
ChosenInputs
(
get_fixed_value
(
label
),
reduction
=
reduction
)
except
AssertionError
:
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
n_candidates
:
int
,
n_chosen
:
int
=
1
,
reduction
:
str
=
'sum'
,
label
:
str
=
None
,
**
kwargs
):
def
__init__
(
self
,
n_candidates
:
int
,
n_chosen
:
int
=
1
,
reduction
:
str
=
'sum'
,
label
:
Optional
[
str
]
=
None
,
**
kwargs
):
super
(
InputChoice
,
self
).
__init__
()
if
'key'
in
kwargs
:
warnings
.
warn
(
f
'"key" is deprecated. Assuming label.'
)
...
...
@@ -206,7 +192,7 @@ class InputChoice(nn.Module):
self
.
n_chosen
=
n_chosen
self
.
reduction
=
reduction
assert
self
.
reduction
in
[
'mean'
,
'concat'
,
'sum'
,
'none'
]
self
.
_label
=
_
generate_new_label
(
label
)
self
.
_label
=
generate_new_label
(
label
)
@
property
def
key
(
self
):
...
...
@@ -295,16 +281,16 @@ class ValueChoice(Translatable, nn.Module):
Identifier of the value choice.
"""
def
__new__
(
cls
,
candidates
:
List
[
Any
],
label
:
str
=
None
):
def
__new__
(
cls
,
candidates
:
List
[
Any
],
label
:
Optional
[
str
]
=
None
):
try
:
return
_
get_fixed_value
(
label
)
return
get_fixed_value
(
label
)
except
AssertionError
:
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
candidates
:
List
[
Any
],
label
:
str
=
None
):
def
__init__
(
self
,
candidates
:
List
[
Any
],
label
:
Optional
[
str
]
=
None
):
super
().
__init__
()
self
.
candidates
=
candidates
self
.
_label
=
_
generate_new_label
(
label
)
self
.
_label
=
generate_new_label
(
label
)
self
.
_accessor
=
[]
@
property
...
...
nni/retiarii/nn/pytorch/component.py
0 → 100644
View file @
5df75c33
import
copy
from
typing
import
Callable
,
List
,
Union
,
Tuple
,
Optional
import
torch
import
torch.nn
as
nn
from
.api
import
LayerChoice
,
InputChoice
from
.nn
import
ModuleList
from
.utils
import
generate_new_label
,
get_fixed_value
__all__
=
[
'Repeat'
,
'Cell'
]
class
Repeat
(
nn
.
Module
):
"""
Repeat a block by a variable number of times.
Parameters
----------
blocks : function, list of function, module or list of module
The block to be repeated. If not a list, it will be replicated into a list.
If a list, it should be of length ``max_depth``, the modules will be instantiated in order and a prefix will be taken.
If a function, it will be called to instantiate a module. Otherwise the module will be deep-copied.
depth : int or tuple of int
If one number, the block will be repeated by a fixed number of times. If a tuple, it should be (min, max),
meaning that the block will be repeated at least `min` times and at most `max` times.
"""
def
__new__
(
cls
,
blocks
:
Union
[
Callable
[[],
nn
.
Module
],
List
[
Callable
[[],
nn
.
Module
]],
nn
.
Module
,
List
[
nn
.
Module
]],
depth
:
Union
[
int
,
Tuple
[
int
,
int
]],
label
:
Optional
[
str
]
=
None
):
try
:
repeat
=
get_fixed_value
(
label
)
return
nn
.
Sequential
(
*
cls
.
_replicate_and_instantiate
(
blocks
,
repeat
))
except
AssertionError
:
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
blocks
:
Union
[
Callable
[[],
nn
.
Module
],
List
[
Callable
[[],
nn
.
Module
]],
nn
.
Module
,
List
[
nn
.
Module
]],
depth
:
Union
[
int
,
Tuple
[
int
,
int
]],
label
:
Optional
[
str
]
=
None
):
super
().
__init__
()
self
.
_label
=
generate_new_label
(
label
)
self
.
min_depth
=
depth
if
isinstance
(
depth
,
int
)
else
depth
[
0
]
self
.
max_depth
=
depth
if
isinstance
(
depth
,
int
)
else
depth
[
1
]
assert
self
.
max_depth
>=
self
.
min_depth
>
0
self
.
blocks
=
nn
.
ModuleList
(
self
.
_replicate_and_instantiate
(
blocks
,
self
.
max_depth
))
@
property
def
label
(
self
):
return
self
.
_label
def
forward
(
self
,
x
):
for
block
in
self
.
blocks
:
x
=
block
(
x
)
return
x
@
staticmethod
def
_replicate_and_instantiate
(
blocks
,
repeat
):
if
not
isinstance
(
blocks
,
list
):
if
isinstance
(
blocks
,
nn
.
Module
):
blocks
=
[
blocks
]
+
[
copy
.
deepcopy
(
blocks
)
for
_
in
range
(
repeat
-
1
)]
else
:
blocks
=
[
blocks
for
_
in
range
(
repeat
)]
assert
len
(
blocks
)
>
0
assert
repeat
<=
len
(
blocks
),
f
'Not enough blocks to be used.
{
repeat
}
expected, only found
{
len
(
blocks
)
}
.'
blocks
=
blocks
[:
repeat
]
if
not
isinstance
(
blocks
[
0
],
nn
.
Module
):
blocks
=
[
b
()
for
b
in
blocks
]
return
blocks
class
Cell
(
nn
.
Module
):
"""
Cell structure [1]_ [2]_ that is popularly used in NAS literature.
A cell consists of multiple "nodes". Each node is a sum of multiple operators. Each operator is chosen from
``op_candidates``, and takes one input from previous nodes and predecessors. Predecessor means the input of cell.
The output of cell is the concatenation of some of the nodes in the cell (currently all the nodes).
Parameters
----------
op_candidates : function or list of module
A list of modules to choose from, or a function that returns a list of modules.
num_nodes : int
Number of nodes in the cell.
num_ops_per_node: int
Number of operators in each node. The output of each node is the sum of all operators in the node. Default: 1.
num_predecessors : int
Number of inputs of the cell. The input to forward should be a list of tensors. Default: 1.
merge_op : str
Currently only ``all`` is supported, which has slight difference with that described in reference. Default: all.
label : str
Identifier of the cell. Cell sharing the same label will semantically share the same choice.
References
----------
.. [1] Barret Zoph, Quoc V. Le, "Neural Architecture Search with Reinforcement Learning". https://arxiv.org/abs/1611.01578
.. [2] Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le,
"Learning Transferable Architectures for Scalable Image Recognition". https://arxiv.org/abs/1707.07012
"""
# TODO:
# Support loose end concat (shape inference on the following cells)
# How to dynamically create convolution with stride as the first node
def
__init__
(
self
,
op_candidates
:
Union
[
Callable
,
List
[
nn
.
Module
]],
num_nodes
:
int
,
num_ops_per_node
:
int
=
1
,
num_predecessors
:
int
=
1
,
merge_op
:
str
=
'all'
,
label
:
str
=
None
):
super
().
__init__
()
self
.
_label
=
generate_new_label
(
label
)
self
.
ops
=
ModuleList
()
self
.
inputs
=
ModuleList
()
self
.
num_nodes
=
num_nodes
self
.
num_ops_per_node
=
num_ops_per_node
self
.
num_predecessors
=
num_predecessors
for
i
in
range
(
num_nodes
):
self
.
ops
.
append
(
ModuleList
())
self
.
inputs
.
append
(
ModuleList
())
for
k
in
range
(
num_ops_per_node
):
if
isinstance
(
op_candidates
,
list
):
assert
len
(
op_candidates
)
>
0
and
isinstance
(
op_candidates
[
0
],
nn
.
Module
)
ops
=
copy
.
deepcopy
(
op_candidates
)
else
:
ops
=
op_candidates
()
self
.
ops
[
-
1
].
append
(
LayerChoice
(
ops
,
label
=
f
'
{
self
.
label
}
__op_
{
i
}
_
{
k
}
'
))
self
.
inputs
[
-
1
].
append
(
InputChoice
(
i
+
num_predecessors
,
1
,
label
=
f
'
{
self
.
label
}
/input_
{
i
}
_
{
k
}
'
))
assert
merge_op
in
[
'all'
]
# TODO: loose_end
self
.
merge_op
=
merge_op
@
property
def
label
(
self
):
return
self
.
_label
def
forward
(
self
,
x
:
List
[
torch
.
Tensor
]):
states
=
x
for
ops
,
inps
in
zip
(
self
.
ops
,
self
.
inputs
):
current_state
=
[]
for
op
,
inp
in
zip
(
ops
,
inps
):
current_state
.
append
(
op
(
inp
(
states
)))
current_state
=
torch
.
sum
(
torch
.
stack
(
current_state
),
0
)
states
.
append
(
current_state
)
return
torch
.
cat
(
states
[
self
.
num_predecessors
:],
1
)
nni/retiarii/nn/pytorch/mutator.py
View file @
5df75c33
...
...
@@ -8,8 +8,9 @@ import torch.nn as nn
from
...mutator
import
Mutator
from
...graph
import
Cell
,
Graph
,
Model
,
ModelStatus
,
Node
from
...utils
import
uid
from
.api
import
LayerChoice
,
InputChoice
,
ValueChoice
,
Placeholder
from
.component
import
Repeat
from
...utils
import
uid
class
LayerChoiceMutator
(
Mutator
):
...
...
@@ -80,6 +81,42 @@ class ParameterChoiceMutator(Mutator):
target
.
update_operation
(
target
.
operation
.
type
,
{
**
target
.
operation
.
parameters
,
argname
:
chosen_value
})
class
RepeatMutator
(
Mutator
):
def
__init__
(
self
,
nodes
:
List
[
Node
]):
# nodes is a subgraph consisting of repeated blocks.
super
().
__init__
()
self
.
nodes
=
nodes
def
_retrieve_chain_from_graph
(
self
,
graph
:
Graph
)
->
List
[
Node
]:
u
=
graph
.
input_node
chain
=
[]
while
u
!=
graph
.
output_node
:
if
u
!=
graph
.
input_node
:
chain
.
append
(
u
)
assert
len
(
u
.
successors
)
==
1
,
f
'This graph is an illegal chain.
{
u
}
has output
{
u
.
successor
}
.'
u
=
u
.
successors
[
0
]
return
chain
def
mutate
(
self
,
model
):
min_depth
=
self
.
nodes
[
0
].
operation
.
parameters
[
'min_depth'
]
max_depth
=
self
.
nodes
[
0
].
operation
.
parameters
[
'max_depth'
]
if
min_depth
<
max_depth
:
chosen_depth
=
self
.
choice
(
list
(
range
(
min_depth
,
max_depth
+
1
)))
for
node
in
self
.
nodes
:
# the logic here is similar to layer choice. We find cell attached to each node.
target
:
Graph
=
model
.
graphs
[
node
.
operation
.
cell_name
]
chain
=
self
.
_retrieve_chain_from_graph
(
target
)
for
edge
in
chain
[
chosen_depth
-
1
].
outgoing_edges
:
edge
.
remove
()
target
.
add_edge
((
chain
[
chosen_depth
-
1
],
None
),
(
target
.
output_node
,
None
))
for
rm_node
in
chain
[
chosen_depth
:]:
for
edge
in
rm_node
.
outgoing_edges
:
edge
.
remove
()
rm_node
.
remove
()
# to delete the unused parameters.
model
.
get_node_by_name
(
node
.
name
).
update_operation
(
Cell
(
node
.
operation
.
cell_name
))
def
process_inline_mutation
(
model
:
Model
)
->
Optional
[
List
[
Mutator
]]:
applied_mutators
=
[]
...
...
@@ -120,6 +157,15 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
mutator
=
LayerChoiceMutator
(
node_list
)
applied_mutators
.
append
(
mutator
)
repeat_nodes
=
_group_by_label
(
filter
(
lambda
d
:
d
.
operation
.
parameters
.
get
(
'mutation'
)
==
'repeat'
,
model
.
get_nodes_by_type
(
'_cell'
)))
for
node_list
in
repeat_nodes
:
assert
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'max_depth'
],
node_list
))
and
\
_is_all_equal
(
map
(
lambda
node
:
node
.
operation
.
parameters
[
'min_depth'
],
node_list
)),
\
'Repeat with the same label must have the same number of candidates.'
mutator
=
RepeatMutator
(
node_list
)
applied_mutators
.
append
(
mutator
)
if
applied_mutators
:
return
applied_mutators
return
None
...
...
@@ -190,6 +236,11 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
if
isinstance
(
module
,
ValueChoice
):
node
=
graph
.
add_node
(
name
,
'ValueChoice'
,
{
'candidates'
:
module
.
candidates
})
node
.
label
=
module
.
label
if
isinstance
(
module
,
Repeat
)
and
module
.
min_depth
<=
module
.
max_depth
:
node
=
graph
.
add_node
(
name
,
'Repeat'
,
{
'candidates'
:
list
(
range
(
module
.
min_depth
,
module
.
max_depth
+
1
))
})
node
.
label
=
module
.
label
if
isinstance
(
module
,
Placeholder
):
raise
NotImplementedError
(
'Placeholder is not supported in python execution mode.'
)
...
...
nni/retiarii/nn/pytorch/utils.py
0 → 100644
View file @
5df75c33
from
typing
import
Optional
from
...utils
import
uid
,
get_current_context
def
generate_new_label
(
label
:
Optional
[
str
]):
if
label
is
None
:
return
'_mutation_'
+
str
(
uid
(
'mutation'
))
return
label
def
get_fixed_value
(
label
:
str
):
ret
=
get_current_context
(
'fixed'
)
try
:
return
ret
[
generate_new_label
(
label
)]
except
KeyError
:
raise
KeyError
(
f
'Fixed context with
{
label
}
not found. Existing values are:
{
ret
}
'
)
test/ut/retiarii/test_highlevel_apis.py
View file @
5df75c33
...
...
@@ -379,7 +379,7 @@ class GraphIR(unittest.TestCase):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
dropout_rate
=
nn
.
ValueChoice
([[
0.
,],
[
1.
,]])
self
.
dropout_rate
=
nn
.
ValueChoice
([[
0.
,
],
[
1.
,
]])
def
forward
(
self
,
x
):
return
F
.
dropout
(
x
,
self
.
dropout_rate
()[
0
])
...
...
@@ -398,7 +398,7 @@ class GraphIR(unittest.TestCase):
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
dropout_rate
=
nn
.
ValueChoice
([[
1.05
,],
[
1.1
,]])
self
.
dropout_rate
=
nn
.
ValueChoice
([[
1.05
,
],
[
1.1
,
]])
def
forward
(
self
,
x
):
# if expression failed, the exception would be:
...
...
@@ -414,6 +414,67 @@ class GraphIR(unittest.TestCase):
self
.
assertEqual
(
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
size
(),
torch
.
Size
([
1
,
3
,
3
,
3
]))
self
.
assertAlmostEqual
(
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
randn
(
1
,
3
,
3
,
3
)).
abs
().
sum
().
item
(),
0
)
def
test_repeat
(
self
):
class
AddOne
(
nn
.
Module
):
def
forward
(
self
,
x
):
return
x
+
1
@
self
.
get_serializer
()
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
block
=
nn
.
Repeat
(
AddOne
(),
(
3
,
5
))
def
forward
(
self
,
x
):
return
self
.
block
(
x
)
model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
self
.
assertEqual
(
len
(
mutators
),
1
)
mutator
=
mutators
[
0
].
bind_sampler
(
EnumerateSampler
())
model1
=
mutator
.
apply
(
model
)
model2
=
mutator
.
apply
(
model
)
model3
=
mutator
.
apply
(
model
)
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
model1
)(
torch
.
zeros
(
1
,
16
))
==
3
).
all
())
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
model2
)(
torch
.
zeros
(
1
,
16
))
==
4
).
all
())
self
.
assertTrue
((
self
.
_get_converted_pytorch_model
(
model3
)(
torch
.
zeros
(
1
,
16
))
==
5
).
all
())
def
test_cell
(
self
):
@
self
.
get_serializer
()
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
cell
=
nn
.
Cell
([
nn
.
Linear
(
16
,
16
),
nn
.
Linear
(
16
,
16
,
bias
=
False
)],
num_nodes
=
4
,
num_ops_per_node
=
2
,
num_predecessors
=
2
,
merge_op
=
'all'
)
def
forward
(
self
,
x
,
y
):
return
self
.
cell
([
x
,
y
])
raw_model
,
mutators
=
self
.
_get_model_with_mutators
(
Net
())
for
_
in
range
(
10
):
sampler
=
EnumerateSampler
()
model
=
raw_model
for
mutator
in
mutators
:
model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model
)
self
.
assertTrue
(
self
.
_get_converted_pytorch_model
(
model
)(
torch
.
randn
(
1
,
16
),
torch
.
randn
(
1
,
16
)).
size
()
==
torch
.
Size
([
1
,
64
]))
@
self
.
get_serializer
()
class
Net2
(
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
cell
=
nn
.
Cell
([
nn
.
Linear
(
16
,
16
),
nn
.
Linear
(
16
,
16
,
bias
=
False
)],
num_nodes
=
4
)
def
forward
(
self
,
x
):
return
self
.
cell
([
x
])
raw_model
,
mutators
=
self
.
_get_model_with_mutators
(
Net2
())
for
_
in
range
(
10
):
sampler
=
EnumerateSampler
()
model
=
raw_model
for
mutator
in
mutators
:
model
=
mutator
.
bind_sampler
(
sampler
).
apply
(
model
)
self
.
assertTrue
(
self
.
_get_converted_pytorch_model
(
model
)(
torch
.
randn
(
1
,
16
)).
size
()
==
torch
.
Size
([
1
,
64
]))
class
Python
(
GraphIR
):
def
_get_converted_pytorch_model
(
self
,
model_ir
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment