OpenDAS / nni · commit c80bda29 (unverified)

Support Repeat and Cell in One-shot NAS (#4835)

Authored May 23, 2022 by Yuge Zhang; committed by GitHub on May 23, 2022. Parent: c54a07df.

Showing 12 changed files with 722 additions and 137 deletions (+722 −137):
nni/retiarii/evaluator/pytorch/lightning.py                    +2    -1
nni/retiarii/nn/pytorch/cell.py                                +58   -33
nni/retiarii/nn/pytorch/component.py                           +3    -3
nni/retiarii/oneshot/pytorch/base_lightning.py                 +7    -4
nni/retiarii/oneshot/pytorch/differentiable.py                 +4    -11
nni/retiarii/oneshot/pytorch/sampling.py                       +6    -1
nni/retiarii/oneshot/pytorch/supermodule/differentiable.py     +207  -5
nni/retiarii/oneshot/pytorch/supermodule/sampling.py           +203  -3
test/ut/retiarii/models.py                                     +77   -0
test/ut/retiarii/test_highlevel_apis.py                        +9    -73
test/ut/retiarii/test_oneshot.py                               +28   -1
test/ut/retiarii/test_oneshot_supermodules.py                  +118  -2
nni/retiarii/evaluator/pytorch/lightning.py

@@ -247,7 +247,8 @@ class _SupervisedLearningModule(LightningModule):
         return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)  # type: ignore

     def on_validation_epoch_end(self):
-        if self.running_mode == 'multi':
+        if not self.trainer.sanity_checking and self.running_mode == 'multi':
+            # Don't report metric when sanity checking
             nni.report_intermediate_result(self._get_validation_metrics())

     def on_fit_end(self):
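For context, a minimal sketch of the guard pattern this hunk introduces (not NNI's actual module; `compute_metric` is a hypothetical helper): PyTorch Lightning runs a few sanity-check validation batches before training, and reporting their metrics to NNI would pollute the intermediate results.

import nni
import pytorch_lightning as pl

class ReportingModule(pl.LightningModule):
    def on_validation_epoch_end(self):
        # trainer.sanity_checking is True during Lightning's pre-fit sanity validation passes
        if self.trainer.sanity_checking:
            return  # metrics from a handful of throw-away batches: don't report them
        nni.report_intermediate_result(self.compute_metric())  # hypothetical metric helper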
nni/retiarii/nn/pytorch/cell.py

@@ -30,9 +30,51 @@ class _DefaultPostprocessor(nn.Module):
         return this_cell


-_cell_op_factory_type = Callable[[int, int, Optional[int]], nn.Module]
+CellOpFactory = Callable[[int, int, Optional[int]], nn.Module]
+
+
+def create_cell_op_candidates(op_candidates, node_index, op_index, chosen) -> Tuple[Dict[str, nn.Module], bool]:
+    has_factory = False
+
+    # convert the complex type into the type that is acceptable to LayerChoice
+    def convert_single_op(op):
+        nonlocal has_factory
+        if isinstance(op, nn.Module):
+            return copy.deepcopy(op)
+        elif callable(op):
+            # Yes! It's using factory to create operations now.
+            has_factory = True
+            # FIXME: I don't know how to check whether we are in graph engine.
+            return op(node_index, op_index, chosen)
+        else:
+            raise TypeError(f'Unrecognized type {type(op)} for op {op}')
+
+    if isinstance(op_candidates, list):
+        res = {str(i): convert_single_op(op) for i, op in enumerate(op_candidates)}
+    elif isinstance(op_candidates, dict):
+        res = {key: convert_single_op(op) for key, op in op_candidates.items()}
+    elif callable(op_candidates):
+        warnings.warn(f'Directly passing a callable into Cell is deprecated. Please consider migrating to list or dict.',
+                      DeprecationWarning)
+        res = op_candidates()
+        has_factory = True
+    else:
+        raise TypeError(f'Unrecognized type {type(op_candidates)} for {op_candidates}')
+
+    return res, has_factory
+
+
+def preprocess_cell_inputs(num_predecessors: int, *inputs: Union[List[torch.Tensor], torch.Tensor]) -> List[torch.Tensor]:
+    if len(inputs) == 1 and isinstance(inputs[0], list):
+        processed_inputs = list(inputs[0])  # shallow copy
+    else:
+        processed_inputs = cast(List[torch.Tensor], list(inputs))
+    assert len(processed_inputs) == num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
+    return processed_inputs


 class Cell(nn.Module):
     """
     Cell structure that is popularly used in NAS literature.

@@ -108,6 +150,9 @@ class Cell(nn.Module):
         The index are enumerated for all nodes including predecessors from 0.
         When first created, the input index is ``None``, meaning unknown.
         Note that in graph execution engine, support of function in ``op_candidates`` is limited.
+        Please also note that, to make :class:`Cell` work with one-shot strategy,
+        ``op_candidates``, in case it's a callable, should not depend on the second input argument,
+        i.e., ``op_index`` in current node.
     num_nodes : int
         Number of nodes in the cell.
     num_ops_per_node: int

@@ -191,15 +236,19 @@ class Cell(nn.Module):
     When ``merge_op`` is ``loose_end``, ``output_node_indices`` is useful to compute the shape of this cell's output,
     because the output shape depends on the connection in the cell, and which nodes are "loose ends" depends on mutation.
+    op_candidates_factory : CellOpFactory or None
+        If the operations are created with a factory (callable), this is to be set with the factory.
+        One-shot algorithms will use this to make each node a cartesian product of
+        operations and inputs.
     """

     def __init__(self,
                  op_candidates: Union[
                      Callable[[], List[nn.Module]],
                      List[nn.Module],
-                     List[_cell_op_factory_type],
+                     List[CellOpFactory],
                      Dict[str, nn.Module],
-                     Dict[str, _cell_op_factory_type]
+                     Dict[str, CellOpFactory]
                  ],
                  num_nodes: int,
                  num_ops_per_node: int = 1,

@@ -232,6 +281,8 @@ class Cell(nn.Module):
         self.concat_dim = concat_dim
+        self.op_candidates_factory: Union[List[CellOpFactory], Dict[str, CellOpFactory], None] = None  # set later
+
         # fill-in the missing modules
         self._create_modules(op_candidates)

@@ -253,7 +304,9 @@ class Cell(nn.Module):
                 # this is needed because op_candidates can be very complex
                 # the type annoation and docs for details
-                ops = self._convert_op_candidates(op_candidates, i, k, chosen)
+                ops, has_factory = create_cell_op_candidates(op_candidates, i, k, chosen)
+                if has_factory:
+                    self.op_candidates_factory = op_candidates
                 # though it's layer choice and input choice here, in fixed mode, the chosen module will be created.
                 cast(ModuleList, self.ops[-1]).append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))

@@ -279,12 +332,7 @@ class Cell(nn.Module):
         By default, it's the output of ``merge_op``, which is a contenation (on ``concat_dim``)
         of some of (possibly all) the nodes' outputs in the cell.
         """
-        processed_inputs: List[torch.Tensor]
-        if len(inputs) == 1 and isinstance(inputs[0], list):
-            processed_inputs = list(inputs[0])  # shallow copy
-        else:
-            processed_inputs = cast(List[torch.Tensor], list(inputs))
-        assert len(processed_inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
+        processed_inputs: List[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
         states: List[torch.Tensor] = self.preprocessor(processed_inputs)
         for ops, inps in zip(
             cast(Sequence[Sequence[LayerChoice]], self.ops),

@@ -301,26 +349,3 @@ class Cell(nn.Module):
         else:
             this_cell = torch.cat([states[k] for k in self.output_node_indices], self.concat_dim)
         return self.postprocessor(this_cell, processed_inputs)
-
-    @staticmethod
-    def _convert_op_candidates(op_candidates, node_index, op_index, chosen) -> Union[Dict[str, nn.Module], List[nn.Module]]:
-        # convert the complex type into the type that is acceptable to LayerChoice
-        def convert_single_op(op):
-            if isinstance(op, nn.Module):
-                return copy.deepcopy(op)
-            elif callable(op):
-                # FIXME: I don't know how to check whether we are in graph engine.
-                return op(node_index, op_index, chosen)
-            else:
-                raise TypeError(f'Unrecognized type {type(op)} for op {op}')
-
-        if isinstance(op_candidates, list):
-            return [convert_single_op(op) for op in op_candidates]
-        elif isinstance(op_candidates, dict):
-            return {key: convert_single_op(op) for key, op in op_candidates.items()}
-        elif callable(op_candidates):
-            warnings.warn(f'Directly passing a callable into Cell is deprecated. Please consider migrating to list or dict.',
-                          DeprecationWarning)
-            return op_candidates()
-        else:
-            raise TypeError(f'Unrecognized type {type(op_candidates)} for {op_candidates}')
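A hedged usage sketch of the factory form of `op_candidates` that `create_cell_op_candidates` now detects (it mirrors the `CellOpFactory` test model added later in this commit); each factory receives `(node_index, op_index, input_index)` and may adapt the operation to the chosen input:

import nni.retiarii.nn.pytorch as nn

cell = nn.Cell(
    {
        # toy setup: input index 0 comes from a 3-channel predecessor, so the factory
        # picks a different input width depending on which input index was chosen
        'linear': lambda node, op, inp: nn.Linear(3 if inp == 0 else 16, 16),
        'linear_nobias': lambda node, op, inp: nn.Linear(3 if inp == 0 else 16, 16, bias=False),
    },
    num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all',
)
# Because the candidates are callables, the Cell records them in `op_candidates_factory`;
# PathSamplingCell / DifferentiableMixedCell later use that factory to rebuild the op for
# every (node, input) pair of the super-net.
assert cell.op_candidates_factory is not None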
nni/retiarii/nn/pytorch/component.py

@@ -106,7 +106,7 @@ class Repeat(Mutable):
                     'In repeat, `depth` is already a ValueChoice, but `label` is still set. It will be ignored.',
                     RuntimeWarning
                 )
-            self.depth_choice = depth
+            self.depth_choice: Union[int, ChoiceOf[int]] = depth
             all_values = list(self.depth_choice.all_options())
             self.min_depth = min(all_values)
             self.max_depth = max(all_values)

@@ -117,12 +117,12 @@ class Repeat(Mutable):
         elif isinstance(depth, tuple):
             self.min_depth = depth if isinstance(depth, int) else depth[0]
             self.max_depth = depth if isinstance(depth, int) else depth[1]
-            self.depth_choice = ValueChoice(list(range(self.min_depth, self.max_depth + 1)), label=label)
+            self.depth_choice: Union[int, ChoiceOf[int]] = ValueChoice(list(range(self.min_depth, self.max_depth + 1)), label=label)
             self._label = self.depth_choice.label
         elif isinstance(depth, int):
             self.min_depth = self.max_depth = depth
-            self.depth_choice = depth
+            self.depth_choice: Union[int, ChoiceOf[int]] = depth
         else:
             raise TypeError(f'Unsupported "depth" type: {type(depth)}')
         assert self.max_depth >= self.min_depth >= 0 and self.max_depth >= 1, f'Depth of {self.min_depth} to {self.max_depth} is invalid.'
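For reference, a small sketch of the three `depth` forms that the annotated `depth_choice` covers (as usual for Retiarii modules, this is meant to run inside a `@model_wrapper` model):

import nni.retiarii.nn.pytorch as nn
from nni.retiarii.nn.pytorch import ValueChoice

fixed = nn.Repeat(nn.Linear(16, 16), 3)             # depth_choice stays the plain int 3
ranged = nn.Repeat(nn.Linear(16, 16), (1, 3))       # depth_choice becomes ValueChoice([1, 2, 3])
chosen = nn.Repeat(nn.Linear(16, 16), ValueChoice([1, 2, 3]))  # an explicit ValueChoice is used as-is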
nni/retiarii/oneshot/pytorch/base_lightning.py

@@ -59,7 +59,8 @@ def traverse_and_mutate_submodules(
     module_list = []

     def apply(m):
-        for name, child in m.named_children():
+        # Need to call list() here because the loop body might replace some children in-place.
+        for name, child in list(m.named_children()):
             # post-order DFS
             if not topdown:
                 apply(child)

@@ -94,6 +95,8 @@ def traverse_and_mutate_submodules(
                 break

         if isinstance(mutate_result, BaseSuperNetModule):
+            # Replace child with the mutate result, and DFS this one
+            child = mutate_result
             module_list.append(mutate_result)

             # pre-order DFS

@@ -112,9 +115,9 @@ def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any], mutate_k
     primitive_list = (
         nas_nn.LayerChoice,
         nas_nn.InputChoice,
-        nas_nn.ValueChoice,
         nas_nn.Repeat,
         nas_nn.NasBench101Cell,
+        # nas_nn.ValueChoice,       # could be false positive
         # nas_nn.Cell,              # later
         # nas_nn.NasBench201Cell,   # forward = supernet
     )
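A standalone sketch (plain PyTorch, not NNI code) of the pattern the first hunk adopts: iterate over a snapshot of `named_children()` so the loop body can freely swap children, or register and remove submodules, without mutating the dictionary that is being iterated.

import torch.nn as nn

class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(4, 4)
        self.b = nn.ReLU()

m = Outer()
# Snapshot first, then mutate: changing the child registry while iterating the live
# generator risks "dictionary changed size during iteration" if the loop body adds
# or removes submodules.
for name, child in list(m.named_children()):
    if isinstance(child, nn.Linear):
        setattr(m, name, nn.Identity())
print(m)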
nni/retiarii/oneshot/pytorch/differentiable.py

@@ -12,7 +12,8 @@ import torch.optim as optim
 from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
 from .supermodule.differentiable import (
     DifferentiableMixedLayer, DifferentiableMixedInput,
-    MixedOpDifferentiablePolicy, GumbelSoftmax
+    MixedOpDifferentiablePolicy, GumbelSoftmax,
+    DifferentiableMixedCell, DifferentiableMixedRepeat
 )
 from .supermodule.proxyless import ProxylessMixedInput, ProxylessMixedLayer
 from .supermodule.operation import NATIVE_MIXED_OPERATIONS

@@ -52,6 +53,8 @@ class DartsLightningModule(BaseOneShotLightningModule):
         hooks = [
             DifferentiableMixedLayer.mutate,
             DifferentiableMixedInput.mutate,
+            DifferentiableMixedCell.mutate,
+            DifferentiableMixedRepeat.mutate,
         ]
         hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
         hooks.append(no_default_hook)

@@ -182,16 +185,6 @@ class GumbelDartsLightningModule(DartsLightningModule):
         Learning rate for architecture optimizer. Default: 3.0e-4
     """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)

-    def default_mutation_hooks(self) -> list[MutationHook]:
-        """Replace modules with gumbel-differentiable versions"""
-        hooks = [
-            DifferentiableMixedLayer.mutate,
-            DifferentiableMixedInput.mutate,
-        ]
-        hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
-        hooks.append(no_default_hook)
-        return hooks
-
     def mutate_kwargs(self):
         """Use gumbel softmax."""
         return {
nni/retiarii/oneshot/pytorch/sampling.py

@@ -12,7 +12,10 @@ import torch.nn as nn
 import torch.optim as optim

 from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
-from .supermodule.sampling import PathSamplingInput, PathSamplingLayer, MixedOpPathSamplingPolicy
+from .supermodule.sampling import (
+    PathSamplingInput, PathSamplingLayer, MixedOpPathSamplingPolicy,
+    PathSamplingCell, PathSamplingRepeat
+)
 from .supermodule.operation import NATIVE_MIXED_OPERATIONS
 from .enas import ReinforceController, ReinforceField

@@ -43,6 +46,8 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
         hooks = [
             PathSamplingLayer.mutate,
             PathSamplingInput.mutate,
+            PathSamplingRepeat.mutate,
+            PathSamplingCell.mutate,
         ]
         hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
         hooks.append(no_default_hook)
nni/retiarii/oneshot/pytorch/supermodule/differentiable.py

@@ -4,20 +4,26 @@
 from __future__ import annotations

 import functools
+import logging
 import warnings
-from typing import Any, cast
+from typing import Any, Dict, Sequence, List, Tuple, cast

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from nni.common.hpo_utils import ParameterSpec
-from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
+from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ChoiceOf, Repeat
+from nni.retiarii.nn.pytorch.api import ValueChoiceX
+from nni.retiarii.nn.pytorch.cell import preprocess_cell_inputs

 from .base import BaseSuperNetModule
 from .operation import MixedOperation, MixedOperationSamplingPolicy
-from ._valuechoice_utils import traverse_all_options
+from .sampling import PathSamplingCell
+from ._valuechoice_utils import traverse_all_options, dedup_inner_choices
+
+_logger = logging.getLogger(__name__)


 class GumbelSoftmax(nn.Softmax):

@@ -284,10 +290,10 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
         return {}

     def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
-        """Export is also random for each leaf value choice."""
+        """Export is argmax for each leaf value choice."""
         result = {}
         for name, spec in operation.search_space_spec().items():
-            if name in result:
+            if name in memo:
                 continue
             chosen_index = int(torch.argmax(cast(dict, operation._arch_alpha)[name]).item())
             result[name] = spec.values[chosen_index]

@@ -300,3 +306,199 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
             }
             return dict(traverse_all_options(operation.mutable_arguments[name], weights=weights))
         return operation.init_arguments[name]
+
+
+class DifferentiableMixedRepeat(BaseSuperNetModule):
+    """
+    Implementaion of Repeat in a differentiable supernet.
+    Result is a weighted sum of possible prefixes, sliced by possible depths.
+    """
+
+    _arch_parameter_names: list[str] = ['_arch_alpha']
+
+    def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int], softmax: nn.Module, memo: dict[str, Any]):
+        super().__init__()
+        self.blocks = blocks
+        self.depth = depth
+        self._softmax = softmax
+        self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
+        self._arch_alpha = nn.ParameterDict()
+        for name, spec in self._space_spec.items():
+            if name in memo:
+                alpha = memo[name]
+                if len(alpha) != spec.size:
+                    raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
+            else:
+                alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
+            self._arch_alpha[name] = alpha
+
+    def resample(self, memo):
+        """Do nothing."""
+        return {}
+
+    def export(self, memo):
+        """Choose argmax for each leaf value choice."""
+        result = {}
+        for name, spec in self._space_spec.items():
+            if name in memo:
+                continue
+            chosen_index = int(torch.argmax(self._arch_alpha[name]).item())
+            result[name] = spec.values[chosen_index]
+        return result
+
+    def search_space_spec(self):
+        return self._space_spec
+
+    @classmethod
+    def mutate(cls, module, name, memo, mutate_kwargs):
+        if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
+            # Only interesting when depth is mutable
+            softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
+            return cls(cast(List[nn.Module], module.blocks), module.depth_choice, softmax, memo)
+
+    def parameters(self, *args, **kwargs):
+        for _, p in self.named_parameters(*args, **kwargs):
+            yield p
+
+    def named_parameters(self, *args, **kwargs):
+        arch = kwargs.pop('arch', False)
+        for name, p in super().named_parameters(*args, **kwargs):
+            if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
+                if arch:
+                    yield name, p
+            else:
+                if not arch:
+                    yield name, p
+
+    def forward(self, x):
+        weights: dict[str, torch.Tensor] = {
+            label: self._softmax(alpha) for label, alpha in self._arch_alpha.items()
+        }
+        depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth, weights=weights)))
+
+        res: torch.Tensor | None = None
+        for i, block in enumerate(self.blocks, start=1):  # start=1 because depths are 1, 2, 3, 4...
+            x = block(x)
+            if i in depth_weights:
+                if res is None:
+                    res = depth_weights[i] * x
+                else:
+                    res = res + depth_weights[i] * x
+
+        return res
+
+
+class DifferentiableMixedCell(PathSamplingCell):
+    """Implementation of Cell under differentiable context.
+
+    An architecture parameter is created on each edge of the full-connected graph.
+    """
+
+    # TODO: It inherits :class:`PathSamplingCell` to reduce some duplicated code.
+    # Possibly need another refactor here.
+
+    def __init__(self, op_factory, num_nodes, num_ops_per_node, num_predecessors,
+                 preprocessor, postprocessor, concat_dim, memo, mutate_kwargs, label):
+        super().__init__(op_factory, num_nodes, num_ops_per_node, num_predecessors,
+                         preprocessor, postprocessor, concat_dim, memo, mutate_kwargs, label)
+        self._arch_alpha = nn.ParameterDict()
+        for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
+            for j in range(i):
+                edge_label = f'{label}/{i}_{j}'
+                op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j]
+                if edge_label in memo:
+                    alpha = memo[edge_label]
+                    if len(alpha) != len(op):
+                        raise ValueError(f'Architecture parameter size of same label {edge_label} conflict: '
+                                         f'{len(alpha)} vs. {len(op)}')
+                else:
+                    alpha = nn.Parameter(torch.randn(len(op)) * 1E-3)
+                self._arch_alpha[edge_label] = alpha
+
+        self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
+
+    def resample(self, memo):
+        """Differentiable doesn't need to resample."""
+        return {}
+
+    def export(self, memo):
+        """Tricky export.
+
+        Reference: https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model_search.py#L135
+        We don't avoid selecting operations like ``none`` here, because it looks like a different search space.
+        """
+        exported = {}
+        for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
+            # Tuple of (weight, input_index, op_name)
+            all_weights: list[tuple[float, int, str]] = []
+            for j in range(i):
+                for k, name in enumerate(self.op_names):
+                    all_weights.append((
+                        float(self._arch_alpha[f'{self.label}/{i}_{j}'][k].item()),
+                        j, name,
+                    ))
+
+            all_weights.sort(reverse=True)
+            # We first prefer inputs from different input_index.
+            # If we have got no other choices, we start to accept duplicates.
+            # Therefore we gather first occurrences of distinct input_index to the front.
+            first_occurrence_index: list[int] = [
+                all_weights.index(                                  # The index of
+                    next(filter(lambda t: t[1] == j, all_weights))  # First occurence of j
+                )
+                for j in range(i)                                   # For j < i
+            ]
+            first_occurrence_index.sort()                           # Keep them ordered too.
+
+            all_weights = [all_weights[k] for k in first_occurrence_index] + \
+                [w for j, w in enumerate(all_weights) if j not in first_occurrence_index]
+
+            _logger.info('Sorted weights in differentiable cell export (node %d): %s', i, all_weights)
+
+            for k in range(self.num_ops_per_node):
+                # all_weights could be too short in case ``num_ops_per_node`` is too large.
+                _, j, op_name = all_weights[k % len(all_weights)]
+                exported[f'{self.label}/op_{i}_{k}'] = op_name
+                exported[f'{self.label}/input_{i}_{k}'] = j
+
+        return exported
+
+    def forward(self, *inputs: list[torch.Tensor] | torch.Tensor) -> tuple[torch.Tensor, ...] | torch.Tensor:
+        processed_inputs: list[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
+        states: list[torch.Tensor] = self.preprocessor(processed_inputs)
+        for i, ops in enumerate(cast(Sequence[Sequence[Dict[str, nn.Module]]], self.ops), start=self.num_predecessors):
+            current_state = []
+
+            for j in range(i):  # for every previous tensors
+                op_results = torch.stack([op(states[j]) for op in ops[j].values()])
+                alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
+                edge_sum = torch.sum(op_results * self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']).view(*alpha_shape), 0)
+                current_state.append(edge_sum)
+
+            states.append(sum(current_state))  # type: ignore
+
+        # Always merge all
+        this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
+        return self.postprocessor(this_cell, processed_inputs)
+
+    def parameters(self, *args, **kwargs):
+        for _, p in self.named_parameters(*args, **kwargs):
+            yield p
+
+    def named_parameters(self, *args, **kwargs):
+        arch = kwargs.pop('arch', False)
+        for name, p in super().named_parameters(*args, **kwargs):
+            if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
+                if arch:
+                    yield name, p
+            else:
+                if not arch:
+                    yield name, p
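A toy, self-contained illustration (plain PyTorch tensors, not the NNI classes above) of the weighted prefix sum that `DifferentiableMixedRepeat.forward` computes: run every block, and accumulate the intermediate output at each candidate depth, weighted by the softmaxed architecture parameter for that depth.

import torch
import torch.nn as nn

blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
alpha = nn.Parameter(torch.randn(3) * 1e-3)      # one logit per candidate depth (1, 2, 3)
depth_weights = torch.softmax(alpha, dim=-1)

x = torch.randn(2, 8)
res = None
for depth, block in enumerate(blocks, start=1):
    x = block(x)
    term = depth_weights[depth - 1] * x          # weight the output of the depth-`depth` prefix
    res = term if res is None else res + term
print(res.shape)  # torch.Size([2, 8])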
nni/retiarii/oneshot/pytorch/supermodule/sampling.py

@@ -3,17 +3,20 @@
 from __future__ import annotations

+import copy
 import random
-from typing import Any
+from typing import Any, List, Dict, Sequence, cast

 import torch
 import torch.nn as nn

 from nni.common.hpo_utils import ParameterSpec
-from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
+from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, Repeat, ChoiceOf, Cell
+from nni.retiarii.nn.pytorch.api import ValueChoiceX
+from nni.retiarii.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs

 from .base import BaseSuperNetModule
-from ._valuechoice_utils import evaluate_value_choice_with_dict
+from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices
 from .operation import MixedOperationSamplingPolicy, MixedOperation

@@ -198,3 +201,200 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
         if name in operation.mutable_arguments:
             return self._sampled[name]
         return operation.init_arguments[name]
+
+
+class PathSamplingRepeat(BaseSuperNetModule):
+    """
+    Implementaion of Repeat in a path-sampling supernet.
+    Samples one / some of the prefixes of the repeated blocks.
+
+    Attributes
+    ----------
+    _sampled : int or list of int
+        Sampled depth.
+    """
+
+    def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int]):
+        super().__init__()
+        self.blocks = blocks
+        self.depth = depth
+        self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
+        self._sampled: list[int] | int | None = None
+
+    def resample(self, memo):
+        """Since depth is based on ValueChoice, we only need to randomly sample every leaf value choices."""
+        result = {}
+        for label in self._space_spec:
+            if label in memo:
+                result[label] = memo[label]
+            else:
+                result[label] = random.choice(self._space_spec[label].values)
+
+        self._sampled = evaluate_value_choice_with_dict(self.depth, result)
+
+        return result
+
+    def export(self, memo):
+        """Random choose one if every choice not in memo."""
+        result = {}
+        for label in self._space_spec:
+            if label not in memo:
+                result[label] = random.choice(self._space_spec[label].values)
+        return result
+
+    def search_space_spec(self):
+        return self._space_spec
+
+    @classmethod
+    def mutate(cls, module, name, memo, mutate_kwargs):
+        if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
+            # Only interesting when depth is mutable
+            return cls(cast(List[nn.Module], module.blocks), module.depth_choice)
+
+    def forward(self, x):
+        if self._sampled is None:
+            raise RuntimeError('At least one depth needs to be sampled before fprop.')
+        if isinstance(self._sampled, list):
+            res = []
+            for i, block in enumerate(self.blocks):
+                x = block(x)
+                if i in self._sampled:
+                    res.append(x)
+            return sum(res)
+        else:
+            for block in self.blocks[:self._sampled]:
+                x = block(x)
+            return x
+
+
+class PathSamplingCell(BaseSuperNetModule):
+    """The implementation of super-net cell follows `DARTS <https://github.com/quark0/darts>`__.
+
+    When ``factory_used`` is true, it reconstructs the cell for every possible combination of operation and input index,
+    because for different input index, the cell factory could instantiate different operations (e.g., with different stride).
+    On export, we first have best (operation, input) pairs, the select the best ``num_ops_per_node``.
+
+    ``loose_end`` is not supported yet, because it will cause more problems (e.g., shape mismatch).
+    We assumes ``loose_end`` to be ``all`` regardless of its configuration.
+
+    A supernet cell can't slim its own weight to fit into a sub network, which is also a known issue.
+    """
+
+    def __init__(
+        self,
+        op_factory: list[CellOpFactory] | dict[str, CellOpFactory],
+        num_nodes: int,
+        num_ops_per_node: int,
+        num_predecessors: int,
+        preprocessor: Any,
+        postprocessor: Any,
+        concat_dim: int,
+        memo: dict,            # although not used here, useful in subclass
+        mutate_kwargs: dict,   # same as memo
+        label: str,
+    ):
+        super().__init__()
+        self.num_nodes = num_nodes
+        self.num_ops_per_node = num_ops_per_node
+        self.num_predecessors = num_predecessors
+        self.preprocessor = preprocessor
+        self.ops = nn.ModuleList()
+        self.postprocessor = postprocessor
+        self.concat_dim = concat_dim
+        self.op_names: list[str] = cast(List[str], None)
+        self.output_node_indices = list(range(self.num_predecessors, self.num_nodes + self.num_predecessors))
+
+        # Create a fully-connected graph.
+        # Each edge is a ModuleDict with op candidates.
+        # Can not reuse LayerChoice here, because the spec, resample, export all need to be customized.
+        # InputChoice is implicit in this graph.
+        for i in self.output_node_indices:
+            self.ops.append(nn.ModuleList())
+            for k in range(i + self.num_predecessors):
+                # Second argument in (i, **0**, k) is always 0.
+                # One-shot strategy can't handle the cases where op spec is dependent on `op_index`.
+                ops, _ = create_cell_op_candidates(op_factory, i, 0, k)
+                self.op_names = list(ops.keys())
+                cast(nn.ModuleList, self.ops[-1]).append(nn.ModuleDict(ops))
+
+        self.label = label
+
+        self._sampled: dict[str, str | int] = {}
+
+    def search_space_spec(self) -> dict[str, ParameterSpec]:
+        # TODO: Recreating the space here.
+        # The spec should be moved to definition of Cell itself.
+        space_spec = {}
+        for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
+            for k in range(self.num_ops_per_node):
+                op_label = f'{self.label}/op_{i}_{k}'
+                input_label = f'{self.label}/input_{i}_{k}'
+                space_spec[op_label] = ParameterSpec(op_label, 'choice', self.op_names, (op_label,), True, size=len(self.op_names))
+                space_spec[input_label] = ParameterSpec(input_label, 'choice', list(range(i)), (input_label, ), True, size=i)
+        return space_spec
+
+    def resample(self, memo):
+        """Random choose one path if label is not found in memo."""
+        self._sampled = {}
+        new_sampled = {}
+        for label, param_spec in self.search_space_spec().items():
+            if label in memo:
+                assert not isinstance(memo[label], list), 'Multi-path sampling is currently unsupported on cell.'
+                self._sampled[label] = memo[label]
+            else:
+                self._sampled[label] = new_sampled[label] = random.choice(param_spec.values)
+        return new_sampled
+
+    def export(self, memo):
+        """Randomly choose one to export."""
+        return self.resample(memo)
+
+    def forward(self, *inputs: list[torch.Tensor] | torch.Tensor) -> tuple[torch.Tensor, ...] | torch.Tensor:
+        processed_inputs: List[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
+        states: List[torch.Tensor] = self.preprocessor(processed_inputs)
+        for i, ops in enumerate(cast(Sequence[Sequence[Dict[str, nn.Module]]], self.ops), start=self.num_predecessors):
+            current_state = []
+
+            for k in range(self.num_ops_per_node):
+                # Select op list based on the input chosen
+                input_index = self._sampled[f'{self.label}/input_{i}_{k}']
+                op_candidates = ops[cast(int, input_index)]
+                # Select op from op list based on the op chosen
+                op_index = self._sampled[f'{self.label}/op_{i}_{k}']
+                op = op_candidates[cast(str, op_index)]
+                current_state.append(op(states[cast(int, input_index)]))
+
+            states.append(sum(current_state))  # type: ignore
+
+        # Always merge all
+        this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
+        return self.postprocessor(this_cell, processed_inputs)
+
+    @classmethod
+    def mutate(cls, module, name, memo, mutate_kwargs):
+        if isinstance(module, Cell):
+            op_factory = None  # not all the cells need to be replaced
+            if module.op_candidates_factory is not None:
+                op_factory = module.op_candidates_factory
+                assert isinstance(op_factory, list) or isinstance(op_factory, dict), \
+                    'Only support op_factory of type list or dict.'
+            elif module.merge_op == 'loose_end':
+                op_candidates_lc = module.ops[-1][-1]  # type: ignore
+                assert isinstance(op_candidates_lc, LayerChoice)
+                op_factory = {  # create a factory
+                    name: lambda _, __, ___: copy.deepcopy(op_candidates_lc[name])
+                    for name in op_candidates_lc.names
+                }
+            if op_factory is not None:
+                return cls(
+                    op_factory, module.num_nodes, module.num_ops_per_node, module.num_predecessors,
+                    module.preprocessor, module.postprocessor, module.concat_dim,
+                    memo, mutate_kwargs, module.label
+                )
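A hedged usage sketch of `PathSamplingRepeat`, close to the unit test added later in this commit: `resample` draws a depth for the underlying `ValueChoice`, and `forward` then runs only that many leading blocks.

import torch
import torch.nn as nn
from nni.retiarii.nn.pytorch import ValueChoice
from nni.retiarii.oneshot.pytorch.supermodule.sampling import PathSamplingRepeat

op = PathSamplingRepeat(
    [nn.Linear(16, 16), nn.Linear(16, 8), nn.Linear(8, 4)],
    ValueChoice([1, 2, 3], label='depth'))
sample = op.resample({})         # e.g. {'depth': 2}
out = op(torch.randn(2, 16))     # output width is 16, 8 or 4 depending on the sampled depth
print(sample, out.shape)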
test/ut/retiarii/models.py (new file, 0 → 100644)

from typing import List, Tuple

import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper


@model_wrapper
class CellSimple(nn.Module):
    def __init__(self):
        super().__init__()
        self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
                            num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')

    def forward(self, x, y):
        return self.cell(x, y)


@model_wrapper
class CellDefaultArgs(nn.Module):
    def __init__(self):
        super().__init__()
        self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)], num_nodes=4)

    def forward(self, x):
        return self.cell(x)


class CellPreprocessor(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 16)

    def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
        return [self.linear(x[0]), x[1]]


class CellPostprocessor(nn.Module):
    def forward(self, this: torch.Tensor, prev: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        return prev[-1], this


@model_wrapper
class CellCustomProcessor(nn.Module):
    def __init__(self):
        super().__init__()
        self.cell = nn.Cell({
            'first': nn.Linear(16, 16),
            'second': nn.Linear(16, 16, bias=False)
        }, num_nodes=4, num_ops_per_node=2, num_predecessors=2,
            preprocessor=CellPreprocessor(), postprocessor=CellPostprocessor(), merge_op='all')

    def forward(self, x, y):
        return self.cell([x, y])


@model_wrapper
class CellLooseEnd(nn.Module):
    def __init__(self):
        super().__init__()
        self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
                            num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='loose_end')

    def forward(self, x, y):
        return self.cell([x, y])


@model_wrapper
class CellOpFactory(nn.Module):
    def __init__(self):
        super().__init__()
        self.cell = nn.Cell({
            'first': lambda _, __, chosen: nn.Linear(3 if chosen == 0 else 16, 16),
            'second': lambda _, __, chosen: nn.Linear(3 if chosen == 0 else 16, 16, bias=False)
        }, num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')

    def forward(self, x, y):
        return self.cell([x, y])
test/ut/retiarii/test_highlevel_apis.py

@@ -23,6 +23,10 @@ from nni.retiarii.nn.pytorch.mutator import process_evaluator_mutations, process
 from nni.retiarii.serializer import model_wrapper
 from nni.retiarii.utils import ContextStack, NoContextError, original_state_dict_hooks

+from .models import (
+    CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory
+)
+

 class EnumerateSampler(Sampler):
     def __init__(self):

@@ -924,17 +928,7 @@ class Python(GraphIR):
         model = Net()

     def test_cell(self):
-        @model_wrapper
-        class Net(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
-                                    num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
-
-            def forward(self, x, y):
-                return self.cell(x, y)
-
-        raw_model, mutators = self._get_model_with_mutators(Net())
+        raw_model, mutators = self._get_model_with_mutators(CellSimple())

         for _ in range(10):
             sampler = EnumerateSampler()
             model = raw_model

@@ -943,16 +937,7 @@ class Python(GraphIR):
             self.assertTrue(self._get_converted_pytorch_model(model)(
                 torch.randn(1, 16), torch.randn(1, 16)).size() == torch.Size([1, 64]))

-        @model_wrapper
-        class Net2(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)], num_nodes=4)
-
-            def forward(self, x):
-                return self.cell(x)
-
-        raw_model, mutators = self._get_model_with_mutators(Net2())
+        raw_model, mutators = self._get_model_with_mutators(CellDefaultArgs())

         for _ in range(10):
             sampler = EnumerateSampler()
             model = raw_model

@@ -961,34 +946,7 @@ class Python(GraphIR):
             self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))

     def test_cell_predecessors(self):
-        from typing import List, Tuple
-
-        class Preprocessor(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = nn.Linear(3, 16)
-
-            def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
-                return [self.linear(x[0]), x[1]]
-
-        class Postprocessor(nn.Module):
-            def forward(self, this: torch.Tensor, prev: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
-                return prev[-1], this
-
-        @model_wrapper
-        class Net(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.cell = nn.Cell({
-                    'first': nn.Linear(16, 16),
-                    'second': nn.Linear(16, 16, bias=False)
-                }, num_nodes=4, num_ops_per_node=2, num_predecessors=2,
-                    preprocessor=Preprocessor(), postprocessor=Postprocessor(), merge_op='all')
-
-            def forward(self, x, y):
-                return self.cell([x, y])
-
-        raw_model, mutators = self._get_model_with_mutators(Net())
+        raw_model, mutators = self._get_model_with_mutators(CellCustomProcessor())

         for _ in range(10):
             sampler = EnumerateSampler()
             model = raw_model

@@ -1000,17 +958,7 @@ class Python(GraphIR):
             self.assertTrue(result[1].size() == torch.Size([1, 64]))

     def test_cell_loose_end(self):
-        @model_wrapper
-        class Net(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
-                                    num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='loose_end')
-
-            def forward(self, x, y):
-                return self.cell([x, y])
-
-        raw_model, mutators = self._get_model_with_mutators(Net())
+        raw_model, mutators = self._get_model_with_mutators(CellLooseEnd())

         any_not_all = False
         for _ in range(10):
             sampler = EnumerateSampler()

@@ -1026,19 +974,7 @@ class Python(GraphIR):
         self.assertTrue(any_not_all)

     def test_cell_complex(self):
-        @model_wrapper
-        class Net(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.cell = nn.Cell({
-                    'first': lambda _, __, chosen: nn.Linear(3 if chosen == 0 else 16, 16),
-                    'second': lambda _, __, chosen: nn.Linear(3 if chosen == 0 else 16, 16, bias=False)
-                }, num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
-
-            def forward(self, x, y):
-                return self.cell([x, y])
-
-        raw_model, mutators = self._get_model_with_mutators(Net())
+        raw_model, mutators = self._get_model_with_mutators(CellOpFactory())

         for _ in range(10):
             sampler = EnumerateSampler()
             model = raw_model
test/ut/retiarii/test_oneshot.py

@@ -137,6 +137,31 @@ class RepeatNet(nn.Module):
         return F.log_softmax(x, dim=1)


+@model_wrapper
+class CellNet(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.stem = nn.Conv2d(1, 5, 7, stride=4)
+        self.cells = nn.Repeat(
+            lambda index: nn.Cell({
+                'conv1': lambda _, __, inp: nn.Conv2d(
+                    (5 if index == 0 else 3 * 4) if inp is not None and inp < 1 else 4, 4, 1),
+                'conv2': lambda _, __, inp: nn.Conv2d(
+                    (5 if index == 0 else 3 * 4) if inp is not None and inp < 1 else 4, 4, 3, padding=1),
+            }, 3, merge_op='loose_end'),
+            (1, 3)
+        )
+        self.fc = nn.Linear(3 * 4, 10)
+
+    def forward(self, x):
+        x = self.stem(x)
+        x = self.cells(x)
+        x = torch.mean(x, (2, 3))
+        x = self.fc(x)
+        return F.log_softmax(x, dim=1)
+
+
 @basic_unit
 class MyOp(nn.Module):
     def __init__(self, some_ch):

@@ -183,6 +208,8 @@ def _mnist_net(type_, evaluator_kwargs):
         base_model = ValueChoiceConvNet()
     elif type_ == 'repeat':
         base_model = RepeatNet()
+    elif type_ == 'cell':
+        base_model = CellNet()
     elif type_ == 'custom_op':
         base_model = CustomOpValueChoiceNet()
     else:

@@ -246,7 +273,7 @@ def _test_strategy(strategy_, support_value_choice=True, multi_gpu=False):
         (_mnist_net('simple', evaluator_kwargs), True),
         (_mnist_net('simple_value_choice', evaluator_kwargs), support_value_choice),
         (_mnist_net('value_choice', evaluator_kwargs), support_value_choice),
-        (_mnist_net('repeat', evaluator_kwargs), False),  # no strategy supports repeat currently
+        (_mnist_net('repeat', evaluator_kwargs), support_value_choice),  # no strategy supports repeat currently
         (_mnist_net('custom_op', evaluator_kwargs), False),  # this is definitely a NO
         (_multihead_attention_net(evaluator_kwargs), support_value_choice),
     ]
test/ut/retiarii/test_oneshot_supermodules.py

@@ -4,17 +4,23 @@ import numpy as np
 import torch
 import torch.nn as nn

 from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, Linear, MultiheadAttention
+from nni.retiarii.oneshot.pytorch.base_lightning import traverse_and_mutate_submodules
 from nni.retiarii.oneshot.pytorch.supermodule.differentiable import (
-    MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax
+    MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax,
+    DifferentiableMixedRepeat, DifferentiableMixedCell
 )
 from nni.retiarii.oneshot.pytorch.supermodule.sampling import (
-    MixedOpPathSamplingPolicy, PathSamplingLayer, PathSamplingInput
+    MixedOpPathSamplingPolicy, PathSamplingLayer, PathSamplingInput, PathSamplingRepeat, PathSamplingCell
 )
 from nni.retiarii.oneshot.pytorch.supermodule.operation import MixedConv2d, NATIVE_MIXED_OPERATIONS
 from nni.retiarii.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLayer, ProxylessMixedInput
 from nni.retiarii.oneshot.pytorch.supermodule._operation_utils import Slicable as S, MaybeWeighted as W
 from nni.retiarii.oneshot.pytorch.supermodule._valuechoice_utils import *

+from .models import (
+    CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory
+)
+

 def test_slice():
     weight = np.ones((3, 7, 24, 23))

@@ -246,3 +252,113 @@ def test_proxyless_layer_input():
     assert input.resample({})['ddd'] in list(range(5))
     assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2])
     assert input.export({})['ddd'] in list(range(5))
+
+
+def test_pathsampling_repeat():
+    op = PathSamplingRepeat([nn.Linear(16, 16), nn.Linear(16, 8), nn.Linear(8, 4)], ValueChoice([1, 2, 3], label='ccc'))
+    sample = op.resample({})
+    assert sample['ccc'] in [1, 2, 3]
+    for i in range(1, 4):
+        op.resample({'ccc': i})
+        out = op(torch.randn(2, 16))
+        assert out.shape[1] == [16, 8, 4][i - 1]
+
+    op = PathSamplingRepeat([nn.Linear(i + 1, i + 2) for i in range(7)], 2 * ValueChoice([1, 2, 3], label='ddd') + 1)
+    sample = op.resample({})
+    assert sample['ddd'] in [1, 2, 3]
+    for i in range(1, 4):
+        op.resample({'ddd': i})
+        out = op(torch.randn(2, 1))
+        assert out.shape[1] == (2 * i + 1) + 1
+
+
+def test_differentiable_repeat():
+    op = DifferentiableMixedRepeat(
+        [nn.Linear(8 if i == 0 else 16, 16) for i in range(4)],
+        ValueChoice([0, 1], label='ccc') * 2 + 1,
+        GumbelSoftmax(-1),
+        {}
+    )
+    op.resample({})
+    assert op(torch.randn(2, 8)).size() == torch.Size([2, 16])
+    sample = op.export({})
+    assert 'ccc' in sample and sample['ccc'] in [0, 1]
+
+
+def test_pathsampling_cell():
+    for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]:
+        model = cell_cls()
+        nas_modules = traverse_and_mutate_submodules(model, [
+            PathSamplingLayer.mutate,
+            PathSamplingInput.mutate,
+            PathSamplingCell.mutate,
+        ], {})
+        result = {}
+        for module in nas_modules:
+            result.update(module.resample(memo=result))
+        assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
+        result = {}
+        for module in nas_modules:
+            result.update(module.export(memo=result))
+        assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
+
+        if cell_cls in [CellLooseEnd, CellOpFactory]:
+            assert isinstance(model.cell, PathSamplingCell)
+        else:
+            assert not isinstance(model.cell, PathSamplingCell)
+
+        inputs = {
+            CellSimple: (torch.randn(2, 16), torch.randn(2, 16)),
+            CellDefaultArgs: (torch.randn(2, 16),),
+            CellCustomProcessor: (torch.randn(2, 3), torch.randn(2, 16)),
+            CellLooseEnd: (torch.randn(2, 16), torch.randn(2, 16)),
+            CellOpFactory: (torch.randn(2, 3), torch.randn(2, 16)),
+        }[cell_cls]
+
+        output = model(*inputs)
+        if cell_cls == CellCustomProcessor:
+            assert isinstance(output, tuple) and len(output) == 2 and \
+                output[1].shape == torch.Size([2, 16 * model.cell.num_nodes])
+        else:
+            # no loose-end support for now
+            assert output.shape == torch.Size([2, 16 * model.cell.num_nodes])
+
+
+def test_differentiable_cell():
+    for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]:
+        model = cell_cls()
+        nas_modules = traverse_and_mutate_submodules(model, [
+            DifferentiableMixedLayer.mutate,
+            DifferentiableMixedInput.mutate,
+            DifferentiableMixedCell.mutate,
+        ], {})
+        result = {}
+        for module in nas_modules:
+            result.update(module.export(memo=result))
+        assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
+
+        ctrl_params = []
+        for m in nas_modules:
+            ctrl_params += list(m.parameters(arch=True))
+        if cell_cls in [CellLooseEnd, CellOpFactory]:
+            assert len(ctrl_params) == model.cell.num_nodes * (model.cell.num_nodes + 3) // 2
+            assert isinstance(model.cell, DifferentiableMixedCell)
+        else:
+            assert not isinstance(model.cell, DifferentiableMixedCell)
+
+        inputs = {
+            CellSimple: (torch.randn(2, 16), torch.randn(2, 16)),
+            CellDefaultArgs: (torch.randn(2, 16),),
+            CellCustomProcessor: (torch.randn(2, 3), torch.randn(2, 16)),
+            CellLooseEnd: (torch.randn(2, 16), torch.randn(2, 16)),
+            CellOpFactory: (torch.randn(2, 3), torch.randn(2, 16)),
+        }[cell_cls]
+
+        output = model(*inputs)
+        if cell_cls == CellCustomProcessor:
+            assert isinstance(output, tuple) and len(output) == 2 and \
+                output[1].shape == torch.Size([2, 16 * model.cell.num_nodes])
+        else:
+            # no loose-end support for now
+            assert output.shape == torch.Size([2, 16 * model.cell.num_nodes])