OpenDAS / nni / Commits / 2815fb1f

Unverified commit 2815fb1f, authored Jun 13, 2022 by Yuge Zhang, committed via GitHub on Jun 13, 2022. Parent: 80beca52.

Make models in search space hub work with one-shot (#4921)

Showing 14 changed files with 696 additions and 69 deletions (+696 -69).
nni/retiarii/hub/pytorch/nasnet.py                               +110  -1
nni/retiarii/hub/pytorch/proxylessnas.py                         +4    -4
nni/retiarii/oneshot/pytorch/__init__.py                         +0    -1
nni/retiarii/oneshot/pytorch/differentiable.py                   +2    -2
nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py     +1    -1
nni/retiarii/oneshot/pytorch/supermodule/_valuechoice_utils.py   +106  -2
nni/retiarii/oneshot/pytorch/supermodule/differentiable.py       +47   -17
nni/retiarii/oneshot/pytorch/supermodule/operation.py            +50   -8
nni/retiarii/oneshot/pytorch/supermodule/sampling.py             +37   -27
test/ut/retiarii/test_oneshot.py                                 +1    -1
test/ut/retiarii/test_oneshot_supermodules.py                    +91   -2
test/ut/retiarii/test_space_hub.py                               +1    -1
test/ut/retiarii/test_space_hub_oneshot.py (new file)            +242  -0
test/vso_tools/pack_testdata.py                                  +4    -2
nni/retiarii/hub/pytorch/nasnet.py

@@ -21,6 +21,9 @@ import torch
 import nni.retiarii.nn.pytorch as nn
 from nni.retiarii import model_wrapper
 
+from nni.retiarii.oneshot.pytorch.supermodule.sampling import PathSamplingRepeat
+from nni.retiarii.oneshot.pytorch.supermodule.differentiable import DifferentiableMixedRepeat
+
 from .utils.fixed import FixedFactory
 from .utils.pretrained import load_pretrained_weight

@@ -348,6 +351,100 @@ class CellBuilder:
         return cell
 
+
+class NDSStage(nn.Repeat):
+    """This class defines NDSStage, a special type of Repeat, for isinstance check, and shape alignment.
+
+    In NDS, we can't simply use Repeat to stack the blocks,
+    because the output shape of each stacked block can be different.
+    This is a problem for one-shot strategy because they assume every possible candidate
+    should return values of the same shape.
+
+    Therefore, we need :class:`NDSStagePathSampling` and :class:`NDSStageDifferentiable`
+    to manually align the shapes -- specifically, to transform the first block in each stage.
+
+    This is not required though, when depth is not changing, or the mutable depth causes no problem
+    (e.g., when the minimum depth is large enough).
+
+    .. attention::
+
+       Assumption: Loose end is treated as all in ``merge_op`` (the case in one-shot),
+       which enforces reduction cell and normal cells in the same stage to have the exact same output shape.
+    """
+
+    estimated_out_channels_prev: int
+    """Output channels of cells in last stage."""
+
+    estimated_out_channels: int
+    """Output channels of this stage. It's **estimated** because it assumes ``all`` as ``merge_op``."""
+
+    downsampling: bool
+    """This stage has downsampling"""
+
+    def first_cell_transformation_factory(self) -> Optional[nn.Module]:
+        """To make the "previous cell" in first cell's output have the same shape as cells in this stage."""
+        if self.downsampling:
+            return FactorizedReduce(self.estimated_out_channels_prev, self.estimated_out_channels)
+        elif self.estimated_out_channels_prev is not self.estimated_out_channels:
+            # Can't use != here, ValueChoice doesn't support
+            return ReLUConvBN(self.estimated_out_channels_prev, self.estimated_out_channels, 1, 1, 0)
+        return None
+
+
+class NDSStagePathSampling(PathSamplingRepeat):
+    """The path-sampling implementation (for one-shot) of each NDS stage if depth is mutating."""
+
+    @classmethod
+    def mutate(cls, module, name, memo, mutate_kwargs):
+        if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.api.ValueChoiceX):
+            return cls(
+                module.first_cell_transformation_factory(),
+                cast(List[nn.Module], module.blocks),
+                module.depth_choice
+            )
+
+    def __init__(self, first_cell_transformation: Optional[nn.Module], *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.first_cell_transformation = first_cell_transformation
+
+    def reduction(self, items: List[Tuple[torch.Tensor, torch.Tensor]], sampled: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
+        if 1 not in sampled or self.first_cell_transformation is None:
+            return super().reduction(items, sampled)
+        # items[0] must be the result of first cell
+        assert len(items[0]) == 2
+        # Only apply the transformation on "prev" output.
+        items[0] = (self.first_cell_transformation(items[0][0]), items[0][1])
+        return super().reduction(items, sampled)
+
+
+class NDSStageDifferentiable(DifferentiableMixedRepeat):
+    """The differentiable implementation (for one-shot) of each NDS stage if depth is mutating."""
+
+    @classmethod
+    def mutate(cls, module, name, memo, mutate_kwargs):
+        if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.api.ValueChoiceX):
+            # Only interesting when depth is mutable
+            softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
+            return cls(
+                module.first_cell_transformation_factory(),
+                cast(List[nn.Module], module.blocks),
+                module.depth_choice,
+                softmax,
+                memo
+            )
+
+    def __init__(self, first_cell_transformation: Optional[nn.Module], *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.first_cell_transformation = first_cell_transformation
+
+    def reduction(
+        self, items: List[Tuple[torch.Tensor, torch.Tensor]], weights: List[float], depths: List[int]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if 1 not in depths or self.first_cell_transformation is None:
+            return super().reduction(items, weights, depths)
+        # Same as NDSStagePathSampling
+        assert len(items[0]) == 2
+        items[0] = (self.first_cell_transformation(items[0][0]), items[0][1])
+        return super().reduction(items, weights, depths)
+
+
 _INIT_PARAMETER_DOCS = """
 Parameters

@@ -437,6 +534,8 @@ class NDS(nn.Module):
             C_pprev = C_prev = 3 * C
             C_curr = C
             last_cell_reduce = False
         else:
             raise ValueError(f'Unsupported dataset: {dataset}')
 
         self.stages = nn.ModuleList()
         for stage_idx in range(3):

@@ -448,9 +547,19 @@ class NDS(nn.Module):
             # C_out is usually `C * num_nodes_per_cell` because of concat operator.
             cell_builder = CellBuilder(op_candidates, C_pprev, C_prev, C_curr, num_nodes_per_cell,
                                        merge_op, stage_idx > 0, last_cell_reduce)
-            stage = nn.Repeat(cell_builder, num_cells_per_stage[stage_idx])
+            stage: Union[NDSStage, nn.Sequential] = NDSStage(cell_builder, num_cells_per_stage[stage_idx])
+
+            if isinstance(stage, NDSStage):
+                stage.estimated_out_channels_prev = cast(int, C_prev)
+                stage.estimated_out_channels = cast(int, C_curr * num_nodes_per_cell)
+                stage.downsampling = stage_idx > 0
+
             self.stages.append(stage)
 
+            # NOTE: output_node_indices will be computed on-the-fly in trial code.
+            # When constructing model space, it's just all the nodes in the cell,
+            # which happens to be the case of one-shot supernet.
+
             # C_pprev is output channel number of last second cell among all the cells already built.
             if len(stage) > 1:
                 # Contains more than one cell
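The two stage-alignment classes above are not applied automatically; they are passed as mutation hooks to a one-shot strategy, which is exactly how the new test_space_hub_oneshot.py in this commit wires them up. A condensed sketch (space arguments illustrative; per the test, sampling-based strategies take the path-sampling hook and gradient-based ones take the differentiable hook):

    import nni.retiarii.strategy as stg
    import nni.retiarii.hub.pytorch as ss
    from nni.retiarii.hub.pytorch.nasnet import NDSStagePathSampling, NDSStageDifferentiable

    # A DARTS space with mutable depth, where stage shapes need alignment.
    space = ss.DARTS(width=16, num_cells=(4, 8), dataset='cifar')

    # ENAS / RandomOneShot would use NDSStagePathSampling.mutate instead.
    strategy = stg.DARTS(mutation_hooks=[NDSStageDifferentiable.mutate])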
nni/retiarii/hub/pytorch/proxylessnas.py

@@ -98,7 +98,6 @@ class ConvBNReLU(nn.Sequential):
         ]
         super().__init__(*simplify_sequential(blocks))
-        self.out_channels = out_channels
 
 
 class DepthwiseSeparableConv(nn.Sequential):

@@ -133,7 +132,8 @@ class DepthwiseSeparableConv(nn.Sequential):
             ConvBNReLU(in_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.Identity)
         ]
         super().__init__(*simplify_sequential(blocks))
-        self.has_skip = stride == 1 and in_channels == out_channels
+        # NOTE: "is" is used here instead of "==" to avoid creating a new value choice.
+        self.has_skip = stride == 1 and in_channels is out_channels
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.has_skip:

@@ -177,8 +177,8 @@ class InvertedResidual(nn.Sequential):
         hidden_ch = cast(int, make_divisible(in_channels * expand_ratio, 8))
 
-        # NOTE: this equivalence check should also work for ValueChoice
-        self.has_skip = stride == 1 and in_channels == out_channels
+        # NOTE: this equivalence check (==) does NOT work for ValueChoice, need to use "is"
+        self.has_skip = stride == 1 and in_channels is out_channels
 
         layers: List[nn.Module] = [
             # point-wise convolution
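Background for the `is` checks above: with mutable channel counts, the skip decision must be a plain Python bool at module-construction time, and per the new comments, `==` on a ValueChoice does not reduce to one. Identity comparison does: it is true exactly when both sides are literally the same choice object, i.e. the two channel counts are guaranteed equal for every sampled value. A small illustration (labels hypothetical):

    import nni.retiarii.nn.pytorch as nn

    ch = nn.ValueChoice([16, 32], label='width')
    in_channels = out_channels = ch       # tied to the same choice object

    has_skip = in_channels is out_channels      # True: always equal when sampled

    other = nn.ValueChoice([16, 32], label='other_width')
    has_skip = in_channels is other             # False: values may differ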
nni/retiarii/oneshot/pytorch/__init__.py

@@ -7,4 +7,3 @@ from .proxyless import ProxylessTrainer
 from .random import SinglePathTrainer, RandomTrainer
 from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
 from .sampling import EnasLightningModule, RandomSamplingLightningModule
-from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader
nni/retiarii/oneshot/pytorch/differentiable.py

@@ -60,7 +60,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
     )
 
     __doc__ = _darts_note.format(
-        module_notes='The DARTS Module should be trained with :class:`nni.retiarii.oneshot.utils.InterleavedTrainValDataLoader`.',
+        module_notes='The DARTS Module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
         module_params=BaseOneShotLightningModule._inner_module_note,
     )

@@ -161,7 +161,7 @@ class ProxylessLightningModule(DartsLightningModule):
     """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
 
     __doc__ = _proxyless_note.format(
-        module_notes='This module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.',
+        module_notes='This module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
         module_params=BaseOneShotLightningModule._inner_module_note,
     )
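Both docstrings now point to `pytorch_lightning.trainer.supporters.CombinedLoader` in place of the removed `InterleavedTrainValDataLoader`. For reference, a sketch of pairing train and val loaders with the Lightning 1.x API (dataset construction is illustrative, and the 'max_size_cycle' mode is an assumption about the intended usage):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning.trainer.supporters import CombinedLoader

    train = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=2)
    val = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=2)

    # Each iteration yields {'train': train_batch, 'val': val_batch};
    # 'max_size_cycle' cycles the shorter loader to match the longer one.
    loader = CombinedLoader({'train': train, 'val': val}, 'max_size_cycle')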
nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py

@@ -115,7 +115,7 @@ def _slice_weight(weight: T, slice_: multidim_slice | list[tuple[multidim_slice,
         # this saves an op on computational graph, which will hopefully make training faster
 
         # Use a dummy array to check this. Otherwise it would be too complex.
-        dummy_arr = np.zeros(weight.shape, dtype=np.bool)  # type: ignore
+        dummy_arr = np.zeros(weight.shape, dtype=bool)  # type: ignore
         no_effect = cast(Any, _do_slice(dummy_arr, slice_)).shape == dummy_arr.shape
 
         if no_effect:
nni/retiarii/oneshot/pytorch/supermodule/_valuechoice_utils.py

@@ -7,7 +7,7 @@ in the way that is most convenient to one-shot algorithms."""
 from __future__ import annotations
 
 import itertools
-from typing import Any, TypeVar, List, cast
+from typing import Any, TypeVar, List, cast, Mapping, Sequence, Optional, Iterable
 
 import numpy as np
 import torch

@@ -20,7 +20,13 @@ Choice = Any
 T = TypeVar('T')
 
-__all__ = ['dedup_inner_choices', 'evaluate_value_choice_with_dict', 'traverse_all_options']
+__all__ = [
+    'dedup_inner_choices',
+    'evaluate_value_choice_with_dict',
+    'traverse_all_options',
+    'weighted_sum',
+    'evaluate_constant',
+]
 
 
 def dedup_inner_choices(value_choices: list[ValueChoiceX]) -> dict[str, ParameterSpec]:

@@ -138,3 +144,101 @@ def traverse_all_options(
         return sorted(result.keys())  # type: ignore
     else:
         return sorted(result.items())  # type: ignore
+
+
+def evaluate_constant(expr: Any) -> Any:
+    """Evaluate a value choice expression to a constant. Raise ValueError if it's not a constant."""
+    all_options = traverse_all_options(expr)
+    if len(all_options) > 1:
+        raise ValueError(f'{expr} is not evaluated to a constant. All possible values are: {all_options}')
+    res = all_options[0]
+    return res
+
+
+def weighted_sum(items: list[T], weights: Sequence[float | None] = cast(Sequence[Optional[float]], None)) -> T:
+    """Return a weighted sum of items.
+
+    Items can be list of tensors, numpy arrays, or nested lists / dicts.
+
+    If ``weights`` is None, this is simply an unweighted sum.
+    """
+    if weights is None:
+        weights = [None] * len(items)
+
+    assert len(items) == len(weights) > 0
+    elem = items[0]
+    unsupported_msg = f'Unsupported element type in weighted sum: {type(elem)}. Value is: {elem}'
+
+    if isinstance(elem, str):
+        # Need to check this first. Otherwise it goes into sequence and causes infinite recursion.
+        raise TypeError(unsupported_msg)
+
+    try:
+        if isinstance(elem, (torch.Tensor, np.ndarray, float, int, np.number)):
+            if weights[0] is None:
+                res = elem
+            else:
+                res = elem * weights[0]
+            for it, weight in zip(items[1:], weights[1:]):
+                if type(it) != type(elem):
+                    raise TypeError(f'Expect type {type(elem)} but found {type(it)}. Can not be summed')
+                if weight is None:
+                    res = res + it  # type: ignore
+                else:
+                    res = res + it * weight  # type: ignore
+            return cast(T, res)
+
+        if isinstance(elem, Mapping):
+            for item in items:
+                if not isinstance(item, Mapping):
+                    raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
+                if set(item) != set(elem):
+                    raise KeyError(f'Expect keys {list(elem)} but found {list(item)}')
+            return cast(T, {
+                key: weighted_sum(cast(List[dict], [cast(Mapping, d)[key] for d in items]), weights)
+                for key in elem
+            })
+        if isinstance(elem, Sequence):
+            for item in items:
+                if not isinstance(item, Sequence):
+                    raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
+                if len(item) != len(elem):
+                    raise ValueError(f'Expect length {len(item)} but found {len(elem)}')
+            transposed = cast(Iterable[list], zip(*items))  # type: ignore
+            return cast(T, [weighted_sum(column, weights) for column in transposed])
+    except (TypeError, ValueError, RuntimeError, KeyError):
+        raise ValueError(
+            'Error when summing items. Value format / shape does not match. See full traceback for details.' +
+            ''.join([f'\n{idx}: {_summarize_elem_format(it)}' for idx, it in enumerate(items)])
+        )
+
+    # Dealing with all unexpected types.
+    raise TypeError(unsupported_msg)
+
+
+def _summarize_elem_format(elem: Any) -> Any:
+    # Get a summary of one elem
+    # Helps generate human-readable error messages
+    class _repr_object:
+        # empty object is only repr
+        def __init__(self, representation):
+            self.representation = representation
+
+        def __repr__(self):
+            return self.representation
+
+    if isinstance(elem, torch.Tensor):
+        return _repr_object('torch.Tensor(' + ', '.join(map(str, elem.shape)) + ')')
+    if isinstance(elem, np.ndarray):
+        return _repr_object('np.array(' + ', '.join(map(str, elem.shape)) + ')')
+    if isinstance(elem, Mapping):
+        return {key: _summarize_elem_format(value) for key, value in elem.items()}
+    if isinstance(elem, Sequence):
+        return [_summarize_elem_format(value) for value in elem]
+
+    # fallback to original, for cases like float, int, ...
+    return elem
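The new `weighted_sum` recurses into dicts and sequences and sums leaf-by-leaf, while `evaluate_constant` folds an expression whose options all collapse to one value. A condensed usage sketch, mirroring the new unit tests added later in this commit:

    import torch
    from nni.retiarii.oneshot.pytorch.supermodule._valuechoice_utils import weighted_sum

    weights = [0.1, 0.2, 0.7]
    weighted_sum([1, 2, 3], weights)   # 2.6
    weighted_sum([1, 2, 3])            # 6: omitting weights gives a plain sum

    # Nested structures keep their shape; each leaf is weighted independently.
    items = [{'logits': torch.full((2, 3), float(i))} for i in (1, 2, 3)]
    out = weighted_sum(items, weights)  # {'logits': tensor filled with 2.6}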
nni/retiarii/oneshot/pytorch/supermodule/differentiable.py

@@ -21,14 +21,14 @@ from nni.retiarii.nn.pytorch.cell import preprocess_cell_inputs
 from .base import BaseSuperNetModule
 from .operation import MixedOperation, MixedOperationSamplingPolicy
 from .sampling import PathSamplingCell
-from ._valuechoice_utils import traverse_all_options, dedup_inner_choices
+from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, weighted_sum
 
 _logger = logging.getLogger(__name__)
 
 __all__ = [
     'DifferentiableMixedLayer',
     'DifferentiableMixedInput',
     'DifferentiableMixedRepeat',
     'DifferentiableMixedCell',
-    'MixedOpDifferentiablePolicy'
+    'MixedOpDifferentiablePolicy',
 ]

@@ -77,7 +77,11 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
 
     _arch_parameter_names: list[str] = ['_arch_alpha']
 
-    def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
+    def __init__(self,
+                 paths: list[tuple[str, nn.Module]],
+                 alpha: torch.Tensor,
+                 softmax: nn.Module,
+                 label: str):
         super().__init__()
         self.op_names = []
         if len(alpha) != len(paths):

@@ -118,11 +122,15 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
             softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
             return cls(list(module.named_children()), alpha, softmax, module.label)
 
+    def reduction(self, items: list[Any], weights: list[float]) -> Any:
+        """Override this for customized reduction."""
+        # Use weighted_sum to handle complex cases where sequential output is not a single tensor
+        return weighted_sum(items, weights)
+
     def forward(self, *args, **kwargs):
         """The forward of mixed layer accepts same arguments as its sub-layer."""
-        op_results = torch.stack([getattr(self, op)(*args, **kwargs) for op in self.op_names])
-        alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
-        return torch.sum(op_results * self._softmax(self._arch_alpha).view(*alpha_shape), 0)
+        all_op_results = [getattr(self, op)(*args, **kwargs) for op in self.op_names]
+        return self.reduction(all_op_results, self._softmax(self._arch_alpha))
 
     def parameters(self, *args, **kwargs):
         """Parameters excluding architecture parameters."""

@@ -167,7 +175,12 @@ class DifferentiableMixedInput(BaseSuperNetModule):
 
     _arch_parameter_names: list[str] = ['_arch_alpha']
 
-    def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
+    def __init__(self,
+                 n_candidates: int,
+                 n_chosen: int | None,
+                 alpha: torch.Tensor,
+                 softmax: nn.Module,
+                 label: str):
         super().__init__()
         self.n_candidates = n_candidates
         if len(alpha) != n_candidates:

@@ -217,11 +230,14 @@ class DifferentiableMixedInput(BaseSuperNetModule):
             softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
             return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label)
 
+    def reduction(self, items: list[Any], weights: list[float]) -> Any:
+        """Override this for customized reduction."""
+        # Use weighted_sum to handle complex cases where sequential output is not a single tensor
+        return weighted_sum(items, weights)
+
     def forward(self, inputs):
         """Forward takes a list of input candidates."""
-        inputs = torch.stack(inputs)
-        alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
-        return torch.sum(inputs * self._softmax(self._arch_alpha).view(*alpha_shape), 0)
+        return self.reduction(inputs, self._softmax(self._arch_alpha))
 
     def parameters(self, *args, **kwargs):
         """Parameters excluding architecture parameters."""

@@ -318,11 +334,18 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
     """
     Implementation of Repeat in a differentiable supernet.
 
     Result is a weighted sum of possible prefixes, sliced by possible depths.
+
+    If the output is not a single tensor, it will be summed at every independent dimension.
+    See :func:`weighted_sum` for details.
     """
 
     _arch_parameter_names: list[str] = ['_arch_alpha']
 
-    def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int], softmax: nn.Module, memo: dict[str, Any]):
+    def __init__(self,
+                 blocks: list[nn.Module],
+                 depth: ChoiceOf[int],
+                 softmax: nn.Module,
+                 memo: dict[str, Any]):
         super().__init__()
         self.blocks = blocks
         self.depth = depth

@@ -377,21 +400,28 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
             if not arch:
                 yield name, p
 
+    def reduction(self, items: list[Any], weights: list[float], depths: list[int]) -> Any:
+        """Override this for customized reduction."""
+        # Use weighted_sum to handle complex cases where sequential output is not a single tensor
+        return weighted_sum(items, weights)
+
     def forward(self, x):
         weights: dict[str, torch.Tensor] = {
             label: self._softmax(alpha) for label, alpha in self._arch_alpha.items()
         }
         depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth, weights=weights)))
-        res: torch.Tensor | None = None
+        res: list[torch.Tensor] = []
+        weight_list: list[float] = []
+        depths: list[int] = []
         for i, block in enumerate(self.blocks, start=1):  # start=1 because depths are 1, 2, 3, 4...
             x = block(x)
             if i in depth_weights:
-                if res is None:
-                    res = depth_weights[i] * x
-                else:
-                    res = res + depth_weights[i] * x
-        return res
+                weight_list.append(depth_weights[i])
+                res.append(x)
+                depths.append(i)
+        return self.reduction(res, weight_list, depths)
 
 
 class DifferentiableMixedCell(PathSamplingCell):
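`reduction` is now the single override point for combining candidate outputs, and the forward passes no longer assume the outputs stack into one tensor. As a minimal hypothetical sketch (not part of this commit), a subclass could replace the soft weighted sum with a hard pick of the strongest path:

    from typing import Any, List
    from nni.retiarii.oneshot.pytorch.supermodule.differentiable import DifferentiableMixedLayer

    class ArgmaxMixedLayer(DifferentiableMixedLayer):
        """Hypothetical variant: forward only the highest-weighted path's output."""
        def reduction(self, items: List[Any], weights: Any) -> Any:
            # weights is the softmax over _arch_alpha, one entry per path
            best = max(range(len(items)), key=lambda i: float(weights[i]))
            return items[best]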
nni/retiarii/oneshot/pytorch/supermodule/operation.py

@@ -10,7 +10,8 @@ from __future__ import annotations
 
 import inspect
 import itertools
-from typing import Any, Type, TypeVar, cast, Union, Tuple
+import warnings
+from typing import Any, Type, TypeVar, cast, Union, Tuple, List
 
 import torch
 import torch.nn as nn

@@ -23,7 +24,7 @@ from nni.common.serializer import is_traceable
 from nni.retiarii.nn.pytorch.api import ValueChoiceX
 
 from .base import BaseSuperNetModule
-from ._valuechoice_utils import traverse_all_options, dedup_inner_choices
+from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, evaluate_constant
 from ._operation_utils import Slicable as _S, MaybeWeighted as _W, int_or_int_dict, scalar_or_scalar_dict
 
 T = TypeVar('T')

@@ -268,14 +269,18 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
     - ``in_channels``
     - ``out_channels``
-    - ``groups`` (only supported in path sampling)
+    - ``groups``
     - ``stride`` (only supported in path sampling)
     - ``kernel_size``
-    - ``padding`` (only supported in path sampling)
+    - ``padding``
     - ``dilation`` (only supported in path sampling)
 
     ``padding`` will be the "max" padding in differentiable mode.
 
+    Mutable ``groups`` is NOT supported in most cases of differentiable mode.
+    However, we do support one special case when the group number is proportional to ``in_channels`` and ``out_channels``.
+    This is often the case of depth-wise convolutions.
+
     For channels, prefix will be sliced.
     For kernels, we take the small kernel from the center and round it to floor (left top). For example ::

@@ -315,6 +320,18 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
             return max(all_sizes)
+        elif name == 'groups':
+            if 'in_channels' in self.mutable_arguments:
+                # If the ratio is constant, we don't need to try the maximum groups.
+                try:
+                    constant = evaluate_constant(self.mutable_arguments['in_channels'] / value_choice)
+                    return max(cast(List[float], traverse_all_options(value_choice))) // int(constant)
+                except ValueError:
+                    warnings.warn(
+                        'Both input channels and groups are ValueChoice in a convolution, and their relative ratio is not a constant. '
+                        'This can be problematic for most one-shot algorithms. Please check whether this is your intention.',
+                        RuntimeWarning
+                    )
+            # minimum groups, maximum kernel
+            return min(traverse_all_options(value_choice))

@@ -328,11 +345,11 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
                           stride: _int_or_tuple,
                           padding: scalar_or_scalar_dict[_int_or_tuple],
                           dilation: int,
-                          groups: int,
+                          groups: int_or_int_dict,
                           inputs: torch.Tensor) -> torch.Tensor:
 
-        if any(isinstance(arg, dict) for arg in [stride, dilation, groups]):
-            raise ValueError(_diff_not_compatible_error.format('stride, dilation and groups', 'Conv2d'))
+        if any(isinstance(arg, dict) for arg in [stride, dilation]):
+            raise ValueError(_diff_not_compatible_error.format('stride, dilation', 'Conv2d'))
 
         in_channels_ = _W(in_channels)
         out_channels_ = _W(out_channels)

@@ -340,7 +357,32 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
         # slice prefix
         # For groups > 1, we use groups to slice input weights
         weight = _S(self.weight)[:out_channels_]
-        weight = _S(weight)[:, :in_channels_ // groups]
+
+        if not isinstance(groups, dict):
+            weight = _S(weight)[:, :in_channels_ // groups]
+        else:
+            assert 'groups' in self.mutable_arguments
+            err_message = 'For differentiable one-shot strategy, when groups is a ValueChoice, ' \
+                'in_channels and out_channels should also be a ValueChoice. ' \
+                'Also, the ratios of in_channels divided by groups, and out_channels divided by groups ' \
+                'should be constants.'
+            if 'in_channels' not in self.mutable_arguments or 'out_channels' not in self.mutable_arguments:
+                raise ValueError(err_message)
+            try:
+                in_channels_per_group = evaluate_constant(self.mutable_arguments['in_channels'] / self.mutable_arguments['groups'])
+            except ValueError:
+                raise ValueError(err_message)
+            if in_channels_per_group != int(in_channels_per_group):
+                raise ValueError(f'Input channels per group is found to be a non-integer: {in_channels_per_group}')
+            if inputs.size(1) % in_channels_per_group != 0:
+                raise RuntimeError(
+                    f'Input channels must be divisible by in_channels_per_group, but the input shape is {inputs.size()}, '
+                    f'while in_channels_per_group = {in_channels_per_group}'
+                )
+
+            # Compute sliced weights and groups (as an integer)
+            weight = _S(weight)[:, :int(in_channels_per_group)]
+            groups = inputs.size(1) // int(in_channels_per_group)
 
         # slice center
         if isinstance(kernel_size, dict):
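The one differentiable case the new `groups` handling supports is when the channel/group ratio folds to a constant, which `evaluate_constant` verifies; this is exactly the depth-wise pattern exercised by the new unit tests in this commit. A sketch (import path per the retiarii API; values illustrative):

    from nni.retiarii.nn.pytorch import ValueChoice, Conv2d

    width = ValueChoice([3, 6, 9], label='width')
    # groups tied to the same choice as the channels: in_channels / groups == 1
    # for every option, so the ratio is constant and differentiable mode works.
    conv = Conv2d(width, width, kernel_size=1, groups=width)

    # An independent groups choice (non-constant ratio) raises ValueError in
    # differentiable mode, per the checks added above.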
nni/retiarii/oneshot/pytorch/supermodule/sampling.py

@@ -16,7 +16,7 @@ from nni.retiarii.nn.pytorch.api import ValueChoiceX
 from nni.retiarii.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs
 
 from .base import BaseSuperNetModule
-from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices
+from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices, weighted_sum
 from .operation import MixedOperationSamplingPolicy, MixedOperation
 
 __all__ = [

@@ -72,6 +72,10 @@ class PathSamplingLayer(BaseSuperNetModule):
         if isinstance(module, LayerChoice):
             return cls(list(module.named_children()), module.label)
 
+    def reduction(self, items: list[Any], sampled: list[Any]):
+        """Override this to implement customized reduction."""
+        return weighted_sum(items)
+
     def forward(self, *args, **kwargs):
         if self._sampled is None:
             raise RuntimeError('At least one path needs to be sampled before fprop.')

@@ -79,10 +83,7 @@ class PathSamplingLayer(BaseSuperNetModule):
         # str(samp) is needed here because samp can sometimes be integers, but attr are always str
         res = [getattr(self, str(samp))(*args, **kwargs) for samp in sampled]
-        if len(res) == 1:
-            return res[0]
-        else:
-            return sum(res)
+        return self.reduction(res, sampled)
 
 
 class PathSamplingInput(BaseSuperNetModule):

@@ -95,11 +96,11 @@ class PathSamplingInput(BaseSuperNetModule):
     Sampled input indices.
     """
 
-    def __init__(self, n_candidates: int, n_chosen: int, reduction: str, label: str):
+    def __init__(self, n_candidates: int, n_chosen: int, reduction_type: str, label: str):
         super().__init__()
         self.n_candidates = n_candidates
         self.n_chosen = n_chosen
-        self.reduction = reduction
+        self.reduction_type = reduction_type
         self._sampled: list[int] | int | None = None
         self.label = label

@@ -144,6 +145,19 @@ class PathSamplingInput(BaseSuperNetModule):
             raise ValueError('n_chosen is None is not supported yet.')
         return cls(module.n_candidates, module.n_chosen, module.reduction, module.label)
 
+    def reduction(self, items: list[Any], sampled: list[Any]) -> Any:
+        """Override this to implement customized reduction."""
+        if len(items) == 1:
+            return items[0]
+        else:
+            if self.reduction_type == 'sum':
+                return sum(items)
+            elif self.reduction_type == 'mean':
+                return sum(items) / len(items)
+            elif self.reduction_type == 'concat':
+                return torch.cat(items, 1)
+            raise ValueError(f'Unsupported reduction type: {self.reduction_type}')
+
     def forward(self, input_tensors):
         if self._sampled is None:
             raise RuntimeError('At least one path needs to be sampled before fprop.')

@@ -151,15 +165,7 @@ class PathSamplingInput(BaseSuperNetModule):
             raise ValueError(f'Expect {self.n_candidates} input tensors, found {len(input_tensors)}.')
         sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
         res = [input_tensors[samp] for samp in sampled]
-        if len(res) == 1:
-            return res[0]
-        else:
-            if self.reduction == 'sum':
-                return sum(res)
-            elif self.reduction == 'mean':
-                return sum(res) / len(res)
-            elif self.reduction == 'concat':
-                return torch.cat(res, 1)
+        return self.reduction(res, sampled)
 
 
 class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):

@@ -202,6 +208,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
         return result
 
     def forward_argument(self, operation: MixedOperation, name: str) -> Any:
+        # NOTE: we don't support sampling a list here.
         if self._sampled is None:
             raise ValueError('Need to call resample() before running forward')
         if name in operation.mutable_arguments:

@@ -257,20 +264,23 @@ class PathSamplingRepeat(BaseSuperNetModule):
             # Only interesting when depth is mutable
             return cls(cast(List[nn.Module], module.blocks), module.depth_choice)
 
+    def reduction(self, items: list[Any], sampled: list[Any]):
+        """Override this to implement customized reduction."""
+        return weighted_sum(items)
+
     def forward(self, x):
         if self._sampled is None:
             raise RuntimeError('At least one depth needs to be sampled before fprop.')
-        if isinstance(self._sampled, list):
-            res = []
-            for i, block in enumerate(self.blocks):
-                x = block(x)
-                if i in self._sampled:
-                    res.append(x)
-            return sum(res)
-        else:
-            for block in self.blocks[:self._sampled]:
-                x = block(x)
-            return x
+        sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
+        res = []
+        for cur_depth, block in enumerate(self.blocks, start=1):
+            x = block(x)
+            if cur_depth in sampled:
+                res.append(x)
+            if not any(d > cur_depth for d in sampled):
+                break
+        return self.reduction(res, sampled)
 
 
 class PathSamplingCell(BaseSuperNetModule):
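For intuition on the rewritten `PathSamplingRepeat.forward`: depths are 1-based, every sampled depth's intermediate output is collected, and the loop stops as soon as no deeper sample remains. A hypothetical trace with four blocks and sampled depths [1, 3]:

    # cur_depth=1: x = blocks[0](x); 1 in sampled -> collect x; 3 > 1, continue
    # cur_depth=2: x = blocks[1](x); not sampled; 3 > 2, continue
    # cur_depth=3: x = blocks[2](x); 3 in sampled -> collect x; nothing deeper -> break
    # blocks[3] never runs; reduction() then sums the two collected outputs.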
test/ut/retiarii/test_oneshot.py

@@ -215,7 +215,7 @@ def _mnist_net(type_, evaluator_kwargs):
         base_model = CustomOpValueChoiceNet()
     else:
         raise ValueError(f'Unsupported type: {type_}')
 
     transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
     train_dataset = nni.trace(MNIST)('data/mnist', train=True, download=True, transform=transform)
     # Multi-GPU combined dataloader will break this subset sampler. Expected though.
test/ut/retiarii/test_oneshot_supermodules.py

@@ -78,6 +78,46 @@ def test_valuechoice_utils():
     for value, weight in ans.items():
         assert abs(weight - weights[value]) < 1e-6
 
+    assert evaluate_constant(ValueChoice([3, 4, 6], label='x') - ValueChoice([3, 4, 6], label='x')) == 0
+
+    with pytest.raises(ValueError):
+        evaluate_constant(ValueChoice([3, 4, 6]) - ValueChoice([3, 4, 6]))
+
+    assert evaluate_constant(ValueChoice([3, 4, 6], label='x') * 2 / ValueChoice([3, 4, 6], label='x')) == 2
+
+
+def test_weighted_sum():
+    weights = [0.1, 0.2, 0.7]
+    items = [1, 2, 3]
+    assert abs(weighted_sum(items, weights) - 2.6) < 1e-6
+
+    assert weighted_sum(items) == 6
+
+    with pytest.raises(TypeError, match='Unsupported'):
+        weighted_sum(['a', 'b', 'c'], weights)
+
+    assert abs(weighted_sum(np.arange(3), weights).item() - 1.6) < 1e-6
+
+    items = [torch.full((2, 3, 5), i) for i in items]
+    assert abs(weighted_sum(items, weights).flatten()[0].item() - 2.6) < 1e-6
+
+    items = [torch.randn(2, 3, i) for i in [1, 2, 3]]
+    with pytest.raises(ValueError, match=r'does not match.*\n.*torch\.Tensor\(2, 3, 1\)'):
+        weighted_sum(items, weights)
+
+    items = [(1, 2), (3, 4), (5, 6)]
+    res = weighted_sum(items, weights)
+    assert len(res) == 2 and abs(res[0] - 4.2) < 1e-6 and abs(res[1] - 5.2) < 1e-6
+
+    items = [(1, 2), (3, 4), (5, 6, 7)]
+    with pytest.raises(ValueError):
+        weighted_sum(items, weights)
+
+    items = [{"a": i, "b": np.full((2, 3, 5), i)} for i in [1, 2, 3]]
+    res = weighted_sum(items, weights)
+    assert res['b'].shape == (2, 3, 5)
+    assert abs(res['b'][0][0][0] - res['a']) < 1e-6
+    assert abs(res['a'] - 2.6) < 1e-6
+
+
 def test_pathsampling_valuechoice():
     orig_conv = Conv2d(3, ValueChoice([3, 5, 7], label='123'), kernel_size=3)

@@ -147,6 +187,26 @@ def test_mixed_conv2d():
     conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 6, 9], label='in'))
     assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10)).size() == torch.Size([2, 6, 10, 10])
 
+    # groups, invalid case
+    conv = Conv2d(ValueChoice([9, 6, 3], label='in'), ValueChoice([9, 6, 3], label='in'), 1, groups=9)
+    with pytest.raises(RuntimeError):
+        assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10))
+
+    # groups, differentiable
+    conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='out'), 1, groups=ValueChoice([3, 6, 9], label='in'))
+    _mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
+
+    conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 6, 9], label='in'))
+    _mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
+
+    with pytest.raises(ValueError):
+        conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 9], label='groups'))
+        _mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
+
+    with pytest.raises(RuntimeError):
+        conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 6, 9], label='in') // 3)
+        _mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 10, 3, 3))
+
     # make sure kernel is sliced correctly
     conv = Conv2d(1, 1, ValueChoice([1, 3], label='k'), bias=False)
     conv = MixedConv2d.mutate(conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})

@@ -238,13 +298,18 @@ def test_differentiable_layer_input():
     assert op.export({})['eee'] in ['a', 'b']
     assert len(list(op.parameters())) == 3
 
+    with pytest.raises(ValueError):
+        op = DifferentiableMixedLayer([('a', Linear(2, 3)), ('b', Linear(2, 4))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
+        op(torch.randn(4, 2))
+
     input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
     assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2
     assert len(input.export({})['ddd']) == 2
 
 
 def test_proxyless_layer_input():
-    op = ProxylessMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
+    op = ProxylessMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))],
+                             nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
 
     assert op.resample({})['eee'] in ['a', 'b']
     assert op(torch.randn(4, 2)).size(-1) == 3
     assert op.export({})['eee'] in ['a', 'b']

@@ -286,6 +351,31 @@ def test_differentiable_repeat():
     sample = op.export({})
     assert 'ccc' in sample and sample['ccc'] in [0, 1]
 
+    class TupleModule(nn.Module):
+        def __init__(self, num):
+            super().__init__()
+            self.num = num
+
+        def forward(self, *args, **kwargs):
+            return torch.full((2, 3), self.num), torch.full((3, 5), self.num), {'a': 7, 'b': [self.num] * 11}
+
+    class CustomSoftmax(nn.Softmax):
+        def forward(self, *args, **kwargs):
+            return [0.3, 0.3, 0.4]
+
+    op = DifferentiableMixedRepeat(
+        [TupleModule(i + 1) for i in range(4)],
+        ValueChoice([1, 2, 4], label='ccc'),
+        CustomSoftmax(),
+        {}
+    )
+    op.resample({})
+    res = op(None)
+    assert len(res) == 3
+    assert res[0].shape == (2, 3) and res[0][0][0].item() == 2.5
+    assert res[2]['a'] == 7
+    assert len(res[2]['b']) == 11 and res[2]['b'][-1] == 2.5
+
+
 def test_pathsampling_cell():
     for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]:

@@ -363,4 +453,3 @@ def test_differentiable_cell():
     else:
         # no loose-end support for now
         assert output.shape == torch.Size([2, 16 * model.cell.num_nodes])
test/ut/retiarii/test_space_hub.py

@@ -95,7 +95,7 @@ def test_nasbench101():
 def test_nasbench201():
-    ss = searchspace.NasBench101()
+    ss = searchspace.NasBench201()
     _test_searchspace_on_dataset(ss)
test/ut/retiarii/test_space_hub_oneshot.py (new file, mode 100644)

import logging

import pytest
import numpy as np
import torch

import nni
import nni.retiarii.hub.pytorch as ss
import nni.retiarii.evaluator.pytorch as pl
import nni.retiarii.strategy as stg
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.hub.pytorch.nasnet import NDSStagePathSampling, NDSStageDifferentiable
from torch.utils.data import Subset
from torchvision import transforms
from torchvision.datasets import CIFAR10, ImageNet

pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason='Too slow without CUDA.')


def _hub_factory(alias):
    if alias == 'nasbench101':
        return ss.NasBench101()
    if alias == 'nasbench201':
        return ss.NasBench201()

    if alias == 'mobilenetv3':
        return ss.MobileNetV3Space()
    if alias == 'mobilenetv3_small':
        return ss.MobileNetV3Space(
            width_multipliers=(0.75, 1, 1.5),
            expand_ratios=(4, 6)
        )

    if alias == 'proxylessnas':
        return ss.ProxylessNAS()
    if alias == 'shufflenet':
        return ss.ShuffleNetSpace()
    if alias == 'autoformer':
        return ss.AutoformerSpace()

    if '_smalldepth' in alias:
        num_cells = (4, 8)
    elif '_depth' in alias:
        num_cells = (8, 12)
    else:
        num_cells = 8
    if '_width' in alias:
        width = (8, 16)
    else:
        width = 16
    if '_imagenet' in alias:
        dataset = 'imagenet'
    else:
        dataset = 'cifar'

    if alias.startswith('nasnet'):
        return ss.NASNet(width=width, num_cells=num_cells, dataset=dataset)
    if alias.startswith('enas'):
        return ss.ENAS(width=width, num_cells=num_cells, dataset=dataset)
    if alias.startswith('amoeba'):
        return ss.AmoebaNet(width=width, num_cells=num_cells, dataset=dataset)
    if alias.startswith('pnas'):
        return ss.PNAS(width=width, num_cells=num_cells, dataset=dataset)
    if alias.startswith('darts'):
        return ss.DARTS(width=width, num_cells=num_cells, dataset=dataset)

    raise ValueError(f'Unrecognized space: {alias}')


def _strategy_factory(alias, space_type):
    # Some search space needs extra hooks
    extra_mutation_hooks = []

    nds_need_shape_alignment = '_smalldepth' in space_type
    if nds_need_shape_alignment:
        if alias in ['enas', 'random']:
            extra_mutation_hooks.append(NDSStagePathSampling.mutate)
        else:
            extra_mutation_hooks.append(NDSStageDifferentiable.mutate)

    if alias == 'darts':
        return stg.DARTS(mutation_hooks=extra_mutation_hooks)
    if alias == 'gumbel':
        return stg.GumbelDARTS(mutation_hooks=extra_mutation_hooks)
    if alias == 'proxyless':
        return stg.Proxyless()
    if alias == 'enas':
        return stg.ENAS(mutation_hooks=extra_mutation_hooks, reward_metric_name='val_acc')
    if alias == 'random':
        return stg.RandomOneShot(mutation_hooks=extra_mutation_hooks)

    raise ValueError(f'Unrecognized strategy: {alias}')


def _dataset_factory(dataset_type, subset=20):
    if dataset_type == 'cifar10':
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        train_dataset = nni.trace(CIFAR10)(
            '../data/cifar10',
            train=True,
            transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
            ])
        )
        valid_dataset = nni.trace(CIFAR10)(
            '../data/cifar10',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                normalize,
            ])
        )
    elif dataset_type == 'imagenet':
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_dataset = nni.trace(ImageNet)(
            '../data/imagenet',
            split='val',  # no train data available in tests
            transform=transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ])
        )
        valid_dataset = nni.trace(ImageNet)(
            '../data/imagenet',
            split='val',
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])
        )
    else:
        raise ValueError(f'Unsupported dataset type: {dataset_type}')

    if subset:
        train_dataset = Subset(train_dataset, np.random.permutation(len(train_dataset))[:subset])
        valid_dataset = Subset(valid_dataset, np.random.permutation(len(valid_dataset))[:subset])

    return train_dataset, valid_dataset


@pytest.mark.parametrize('space_type', [
    # 'nasbench101',
    'nasbench201',
    'mobilenetv3',
    'mobilenetv3_small',
    'proxylessnas',
    'shufflenet',
    # 'autoformer',
    'nasnet',
    'enas',
    'amoeba',
    'pnas',
    'darts',
    'darts_smalldepth',
    'darts_depth',
    'darts_width',
    'darts_width_smalldepth',
    'darts_width_depth',
    'darts_imagenet',
    'darts_width_smalldepth_imagenet',
    'enas_smalldepth',
    'enas_depth',
    'enas_width',
    'enas_width_smalldepth',
    'enas_width_depth',
    'enas_imagenet',
    'enas_width_smalldepth_imagenet',
    'pnas_width_smalldepth',
    'amoeba_width_smalldepth',
])
@pytest.mark.parametrize('strategy_type', [
    'darts',
    'gumbel',
    'proxyless',
    'enas',
    'random'
])
def test_hub_oneshot(space_type, strategy_type):
    NDS_SPACES = ['amoeba', 'darts', 'pnas', 'enas', 'nasnet']
    if strategy_type == 'proxyless':
        if 'width' in space_type or 'depth' in space_type or \
                any(space_type.startswith(prefix) for prefix in NDS_SPACES + ['proxylessnas', 'mobilenetv3']):
            pytest.skip('The space has used unsupported APIs.')
    if strategy_type in ['darts', 'gumbel'] and space_type == 'mobilenetv3':
        pytest.skip('Skip as it consumes too much memory.')

    model_space = _hub_factory(space_type)

    dataset_type = 'cifar10'
    if 'imagenet' in space_type or space_type in ['mobilenetv3', 'proxylessnas', 'shufflenet', 'autoformer']:
        dataset_type = 'imagenet'

    subset_size = 4
    if strategy_type in ['darts', 'gumbel'] and \
            any(space_type.startswith(prefix) for prefix in NDS_SPACES) and '_' in space_type:
        subset_size = 2

    train_dataset, valid_dataset = _dataset_factory(dataset_type, subset=subset_size)

    train_loader = pl.DataLoader(train_dataset, batch_size=2, num_workers=2, shuffle=True)
    valid_loader = pl.DataLoader(valid_dataset, batch_size=2, num_workers=2, shuffle=False)

    evaluator = pl.Classification(
        train_dataloaders=train_loader,
        val_dataloaders=valid_loader,
        max_epochs=1,
        export_onnx=False,
        gpus=1 if torch.cuda.is_available() else 0,  # 0 for my debug
        logger=False,  # disable logging and checkpoint to avoid too much log
        enable_checkpointing=False,
        enable_model_summary=False,
        # profiler='advanced'
    )

    # To test on final model:
    # model = type(model_space).load_searched_model('darts-v2')
    # evaluator.fit(model)

    strategy = _strategy_factory(strategy_type, space_type)

    config = RetiariiExeConfig()
    config.execution_engine = 'oneshot'
    experiment = RetiariiExperiment(model_space, evaluator, strategy=strategy)
    experiment.run(config)


_original_loglevel = None


def setup_module(module):
    global _original_loglevel
    _original_loglevel = logging.getLogger("pytorch_lightning").level
    logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)


def teardown_module(module):
    logging.getLogger("pytorch_lightning").setLevel(_original_loglevel)
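The grid above covers every space/strategy pair, with unsupported combinations skipped inside the test body. One cell of the grid can be exercised directly with the helpers defined in this file, or selected via pytest's `-k` expression (e.g. `-k 'darts_smalldepth and random'`), assuming the datasets noted above are in place:

    # Run a single combination without the pytest parametrization machinery.
    test_hub_oneshot('darts_smalldepth', 'random')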
test/vso_tools/pack_testdata.py

@@ -50,14 +50,16 @@ def prepare_imagenet_subset(data_dir: Path, imagenet_dir: Path):
     # Target root dir
     subset_dir = data_dir / 'imagenet'
     shutil.rmtree(subset_dir, ignore_errors=True)
     subset_dir.mkdir(parents=True)
     shutil.copyfile(imagenet_dir / 'meta.bin', subset_dir / 'meta.bin')
     copied_count = 0
     for category_id, imgs in images.items():
         random_state.shuffle(imgs)
         for img in imgs[:len(imgs) // 10]:
             folder_name = Path(img).parent.name
             file_name = Path(img).name
-            (subset_dir / folder_name).mkdir(exist_ok=True, parents=True)
-            shutil.copyfile(img, subset_dir / folder_name / file_name)
+            (subset_dir / 'val' / folder_name).mkdir(exist_ok=True, parents=True)
+            shutil.copyfile(img, subset_dir / 'val' / folder_name / file_name)
             copied_count += 1
     print(f'Generated a subset of {copied_count} images.')