Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
6cb5916f
Unverified
Commit
6cb5916f
authored
Apr 24, 2020
by
Yuge Zhang
Committed by
GitHub
Apr 24, 2020
Browse files
Support OrderedDict for LayerChoice (#2336)
parent
319ff036
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
55 additions
and
25 deletions
+55
-25
examples/nas/classic_nas/mnist.py
examples/nas/classic_nas/mnist.py
+10
-7
examples/nas/darts/model.py
examples/nas/darts/model.py
+11
-11
examples/nas/spos/network.py
examples/nas/spos/network.py
+0
-1
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
+1
-1
src/sdk/pynni/nni/nas/pytorch/mutables.py
src/sdk/pynni/nni/nas/pytorch/mutables.py
+33
-5
No files found.
examples/nas/classic_nas/mnist.py
View file @
6cb5916f
...
...
@@ -8,6 +8,8 @@ https://github.com/pytorch/examples/blob/master/mnist/main.py
import
os
import
argparse
import
logging
from
collections
import
OrderedDict
import
nni
import
torch
import
torch.nn
as
nn
...
...
@@ -26,13 +28,15 @@ class Net(nn.Module):
def
__init__
(
self
,
hidden_size
):
super
(
Net
,
self
).
__init__
()
# two options of conv1
self
.
conv1
=
LayerChoice
([
nn
.
Conv2d
(
1
,
20
,
5
,
1
),
nn
.
Conv2d
(
1
,
20
,
3
,
1
)],
key
=
'first_conv'
)
self
.
conv1
=
LayerChoice
(
OrderedDict
([
(
"conv5x5"
,
nn
.
Conv2d
(
1
,
20
,
5
,
1
)),
(
"conv3x3"
,
nn
.
Conv2d
(
1
,
20
,
3
,
1
))
]),
key
=
'first_conv'
)
# two options of mid_conv
self
.
mid_conv
=
LayerChoice
([
nn
.
Conv2d
(
20
,
20
,
3
,
1
,
padding
=
1
),
nn
.
Conv2d
(
20
,
20
,
5
,
1
,
padding
=
2
)],
key
=
'mid_conv'
)
self
.
mid_conv
=
LayerChoice
([
nn
.
Conv2d
(
20
,
20
,
3
,
1
,
padding
=
1
),
nn
.
Conv2d
(
20
,
20
,
5
,
1
,
padding
=
2
)
],
key
=
'mid_conv'
)
self
.
conv2
=
nn
.
Conv2d
(
20
,
50
,
5
,
1
)
self
.
fc1
=
nn
.
Linear
(
4
*
4
*
50
,
hidden_size
)
self
.
fc2
=
nn
.
Linear
(
hidden_size
,
10
)
...
...
@@ -167,7 +171,6 @@ def get_params():
parser
.
add_argument
(
'--log_interval'
,
type
=
int
,
default
=
1000
,
metavar
=
'N'
,
help
=
'how many batches to wait before logging training status'
)
args
,
_
=
parser
.
parse_known_args
()
return
args
...
...
examples/nas/darts/model.py
View file @
6cb5916f
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
collections
import
OrderedDict
import
torch
import
torch.nn
as
nn
...
...
@@ -43,17 +45,15 @@ class Node(nn.Module):
stride
=
2
if
i
<
num_downsample_connect
else
1
choice_keys
.
append
(
"{}_p{}"
.
format
(
node_id
,
i
))
self
.
ops
.
append
(
mutables
.
LayerChoice
(
[
ops
.
PoolBN
(
'max'
,
channels
,
3
,
stride
,
1
,
affine
=
False
),
ops
.
PoolBN
(
'avg'
,
channels
,
3
,
stride
,
1
,
affine
=
False
),
nn
.
Identity
()
if
stride
==
1
else
ops
.
FactorizedReduce
(
channels
,
channels
,
affine
=
False
),
ops
.
SepConv
(
channels
,
channels
,
3
,
stride
,
1
,
affine
=
False
),
ops
.
SepConv
(
channels
,
channels
,
5
,
stride
,
2
,
affine
=
False
),
ops
.
DilConv
(
channels
,
channels
,
3
,
stride
,
2
,
2
,
affine
=
False
),
ops
.
DilConv
(
channels
,
channels
,
5
,
stride
,
4
,
2
,
affine
=
False
)
],
key
=
choice_keys
[
-
1
]))
mutables
.
LayerChoice
(
OrderedDict
([
(
"maxpool"
,
ops
.
PoolBN
(
'max'
,
channels
,
3
,
stride
,
1
,
affine
=
False
)),
(
"avgpool"
,
ops
.
PoolBN
(
'avg'
,
channels
,
3
,
stride
,
1
,
affine
=
False
)),
(
"skipconnect"
,
nn
.
Identity
()
if
stride
==
1
else
ops
.
FactorizedReduce
(
channels
,
channels
,
affine
=
False
)),
(
"sepconv3x3"
,
ops
.
SepConv
(
channels
,
channels
,
3
,
stride
,
1
,
affine
=
False
)),
(
"sepconv5x5"
,
ops
.
SepConv
(
channels
,
channels
,
5
,
stride
,
2
,
affine
=
False
)),
(
"dilconv3x3"
,
ops
.
DilConv
(
channels
,
channels
,
3
,
stride
,
2
,
2
,
affine
=
False
)),
(
"dilconv5x5"
,
ops
.
DilConv
(
channels
,
channels
,
5
,
stride
,
4
,
2
,
affine
=
False
))
]),
key
=
choice_keys
[
-
1
]))
self
.
drop_path
=
ops
.
DropPath
()
self
.
input_switch
=
mutables
.
InputChoice
(
choose_from
=
choice_keys
,
n_chosen
=
2
,
key
=
"{}_switch"
.
format
(
node_id
))
...
...
examples/nas/spos/network.py
View file @
6cb5916f
...
...
@@ -151,6 +151,5 @@ def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"):
for
k
,
v
in
checkpoint
[
"state_dict"
].
items
():
if
k
.
startswith
(
"module."
):
k
=
k
[
len
(
"module."
):]
k
=
re
.
sub
(
r
"^(features.\d+).(\d+)"
,
"
\\
1.choices.
\\
2"
,
k
)
result
[
k
]
=
v
return
result
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
View file @
6cb5916f
...
...
@@ -203,7 +203,7 @@ class ClassicMutator(Mutator):
# for now we only generate flattened search space
if
isinstance
(
mutable
,
LayerChoice
):
key
=
mutable
.
key
val
=
[
repr
(
choice
)
for
choice
in
mutable
.
choic
es
]
val
=
mutable
.
nam
es
search_space
[
key
]
=
{
"_type"
:
LAYER_CHOICE
,
"_value"
:
val
}
elif
isinstance
(
mutable
,
InputChoice
):
key
=
mutable
.
key
...
...
src/sdk/pynni/nni/nas/pytorch/mutables.py
View file @
6cb5916f
...
...
@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import
logging
from
collections
import
OrderedDict
import
torch.nn
as
nn
...
...
@@ -83,9 +84,6 @@ class Mutable(nn.Module):
"Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
"so that trainer can locate all your mutables. See NNI docs for more details."
.
format
(
self
))
def
__repr__
(
self
):
return
"{} ({})"
.
format
(
self
.
name
,
self
.
key
)
class
MutableScope
(
Mutable
):
"""
...
...
@@ -128,7 +126,7 @@ class LayerChoice(Mutable):
Parameters
----------
op_candidates : list of nn.Module
op_candidates : list of nn.Module
or OrderedDict
A module list to be selected from.
reduction : str
``mean``, ``concat``, ``sum`` or ``none``. Policy if multiples are selected.
...
...
@@ -143,12 +141,42 @@ class LayerChoice(Mutable):
----------
length : int
Number of ops to choose from.
names: list of str
Names of candidates.
Notes
-----
``op_candidates`` can be a list of modules or a ordered dict of named modules, for example,
.. code-block:: python
self.op_choice = LayerChoice(OrderedDict([
("conv3x3", nn.Conv2d(3, 16, 128)),
("conv5x5", nn.Conv2d(5, 16, 128)),
("conv7x7", nn.Conv2d(7, 16, 128))
]))
"""
def
__init__
(
self
,
op_candidates
,
reduction
=
"sum"
,
return_mask
=
False
,
key
=
None
):
super
().
__init__
(
key
=
key
)
self
.
length
=
len
(
op_candidates
)
self
.
choices
=
nn
.
ModuleList
(
op_candidates
)
self
.
choices
=
[]
self
.
names
=
[]
if
isinstance
(
op_candidates
,
OrderedDict
):
for
name
,
module
in
op_candidates
.
items
():
assert
name
not
in
[
"length"
,
"reduction"
,
"return_mask"
,
"_key"
,
"key"
,
"names"
],
\
"Please don't use a reserved name '{}' for your module."
.
format
(
name
)
self
.
add_module
(
name
,
module
)
self
.
choices
.
append
(
module
)
self
.
names
.
append
(
name
)
elif
isinstance
(
op_candidates
,
list
):
for
i
,
module
in
enumerate
(
op_candidates
):
self
.
add_module
(
str
(
i
),
module
)
self
.
choices
.
append
(
module
)
self
.
names
.
append
(
str
(
i
))
else
:
raise
TypeError
(
"Unsupported op_candidates type: {}"
.
format
(
type
(
op_candidates
)))
self
.
reduction
=
reduction
self
.
return_mask
=
return_mask
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment