Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
6cb5916f
Unverified
Commit
6cb5916f
authored
Apr 24, 2020
by
Yuge Zhang
Committed by
GitHub
Apr 24, 2020
Browse files
Support OrderedDict for LayerChoice (#2336)
parent
319ff036
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
55 additions
and
25 deletions
+55
-25
examples/nas/classic_nas/mnist.py
examples/nas/classic_nas/mnist.py
+10
-7
examples/nas/darts/model.py
examples/nas/darts/model.py
+11
-11
examples/nas/spos/network.py
examples/nas/spos/network.py
+0
-1
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
+1
-1
src/sdk/pynni/nni/nas/pytorch/mutables.py
src/sdk/pynni/nni/nas/pytorch/mutables.py
+33
-5
No files found.
examples/nas/classic_nas/mnist.py
View file @
6cb5916f
...
@@ -8,6 +8,8 @@ https://github.com/pytorch/examples/blob/master/mnist/main.py
...
@@ -8,6 +8,8 @@ https://github.com/pytorch/examples/blob/master/mnist/main.py
import
os
import
os
import
argparse
import
argparse
import
logging
import
logging
from
collections
import
OrderedDict
import
nni
import
nni
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -26,13 +28,15 @@ class Net(nn.Module):
...
@@ -26,13 +28,15 @@ class Net(nn.Module):
def
__init__
(
self
,
hidden_size
):
def
__init__
(
self
,
hidden_size
):
super
(
Net
,
self
).
__init__
()
super
(
Net
,
self
).
__init__
()
# two options of conv1
# two options of conv1
self
.
conv1
=
LayerChoice
([
nn
.
Conv2d
(
1
,
20
,
5
,
1
),
self
.
conv1
=
LayerChoice
(
OrderedDict
([
nn
.
Conv2d
(
1
,
20
,
3
,
1
)],
(
"conv5x5"
,
nn
.
Conv2d
(
1
,
20
,
5
,
1
)),
key
=
'first_conv'
)
(
"conv3x3"
,
nn
.
Conv2d
(
1
,
20
,
3
,
1
))
]),
key
=
'first_conv'
)
# two options of mid_conv
# two options of mid_conv
self
.
mid_conv
=
LayerChoice
([
nn
.
Conv2d
(
20
,
20
,
3
,
1
,
padding
=
1
),
self
.
mid_conv
=
LayerChoice
([
nn
.
Conv2d
(
20
,
20
,
5
,
1
,
padding
=
2
)],
nn
.
Conv2d
(
20
,
20
,
3
,
1
,
padding
=
1
),
key
=
'mid_conv'
)
nn
.
Conv2d
(
20
,
20
,
5
,
1
,
padding
=
2
)
],
key
=
'mid_conv'
)
self
.
conv2
=
nn
.
Conv2d
(
20
,
50
,
5
,
1
)
self
.
conv2
=
nn
.
Conv2d
(
20
,
50
,
5
,
1
)
self
.
fc1
=
nn
.
Linear
(
4
*
4
*
50
,
hidden_size
)
self
.
fc1
=
nn
.
Linear
(
4
*
4
*
50
,
hidden_size
)
self
.
fc2
=
nn
.
Linear
(
hidden_size
,
10
)
self
.
fc2
=
nn
.
Linear
(
hidden_size
,
10
)
...
@@ -167,7 +171,6 @@ def get_params():
...
@@ -167,7 +171,6 @@ def get_params():
parser
.
add_argument
(
'--log_interval'
,
type
=
int
,
default
=
1000
,
metavar
=
'N'
,
parser
.
add_argument
(
'--log_interval'
,
type
=
int
,
default
=
1000
,
metavar
=
'N'
,
help
=
'how many batches to wait before logging training status'
)
help
=
'how many batches to wait before logging training status'
)
args
,
_
=
parser
.
parse_known_args
()
args
,
_
=
parser
.
parse_known_args
()
return
args
return
args
...
...
examples/nas/darts/model.py
View file @
6cb5916f
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
collections
import
OrderedDict
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -43,17 +45,15 @@ class Node(nn.Module):
...
@@ -43,17 +45,15 @@ class Node(nn.Module):
stride
=
2
if
i
<
num_downsample_connect
else
1
stride
=
2
if
i
<
num_downsample_connect
else
1
choice_keys
.
append
(
"{}_p{}"
.
format
(
node_id
,
i
))
choice_keys
.
append
(
"{}_p{}"
.
format
(
node_id
,
i
))
self
.
ops
.
append
(
self
.
ops
.
append
(
mutables
.
LayerChoice
(
mutables
.
LayerChoice
(
OrderedDict
([
[
(
"maxpool"
,
ops
.
PoolBN
(
'max'
,
channels
,
3
,
stride
,
1
,
affine
=
False
)),
ops
.
PoolBN
(
'max'
,
channels
,
3
,
stride
,
1
,
affine
=
False
),
(
"avgpool"
,
ops
.
PoolBN
(
'avg'
,
channels
,
3
,
stride
,
1
,
affine
=
False
)),
ops
.
PoolBN
(
'avg'
,
channels
,
3
,
stride
,
1
,
affine
=
False
),
(
"skipconnect"
,
nn
.
Identity
()
if
stride
==
1
else
ops
.
FactorizedReduce
(
channels
,
channels
,
affine
=
False
)),
nn
.
Identity
()
if
stride
==
1
else
ops
.
FactorizedReduce
(
channels
,
channels
,
affine
=
False
),
(
"sepconv3x3"
,
ops
.
SepConv
(
channels
,
channels
,
3
,
stride
,
1
,
affine
=
False
)),
ops
.
SepConv
(
channels
,
channels
,
3
,
stride
,
1
,
affine
=
False
),
(
"sepconv5x5"
,
ops
.
SepConv
(
channels
,
channels
,
5
,
stride
,
2
,
affine
=
False
)),
ops
.
SepConv
(
channels
,
channels
,
5
,
stride
,
2
,
affine
=
False
),
(
"dilconv3x3"
,
ops
.
DilConv
(
channels
,
channels
,
3
,
stride
,
2
,
2
,
affine
=
False
)),
ops
.
DilConv
(
channels
,
channels
,
3
,
stride
,
2
,
2
,
affine
=
False
),
(
"dilconv5x5"
,
ops
.
DilConv
(
channels
,
channels
,
5
,
stride
,
4
,
2
,
affine
=
False
))
ops
.
DilConv
(
channels
,
channels
,
5
,
stride
,
4
,
2
,
affine
=
False
)
]),
key
=
choice_keys
[
-
1
]))
],
key
=
choice_keys
[
-
1
]))
self
.
drop_path
=
ops
.
DropPath
()
self
.
drop_path
=
ops
.
DropPath
()
self
.
input_switch
=
mutables
.
InputChoice
(
choose_from
=
choice_keys
,
n_chosen
=
2
,
key
=
"{}_switch"
.
format
(
node_id
))
self
.
input_switch
=
mutables
.
InputChoice
(
choose_from
=
choice_keys
,
n_chosen
=
2
,
key
=
"{}_switch"
.
format
(
node_id
))
...
...
examples/nas/spos/network.py
View file @
6cb5916f
...
@@ -151,6 +151,5 @@ def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"):
...
@@ -151,6 +151,5 @@ def load_and_parse_state_dict(filepath="./data/checkpoint-150000.pth.tar"):
for
k
,
v
in
checkpoint
[
"state_dict"
].
items
():
for
k
,
v
in
checkpoint
[
"state_dict"
].
items
():
if
k
.
startswith
(
"module."
):
if
k
.
startswith
(
"module."
):
k
=
k
[
len
(
"module."
):]
k
=
k
[
len
(
"module."
):]
k
=
re
.
sub
(
r
"^(features.\d+).(\d+)"
,
"
\\
1.choices.
\\
2"
,
k
)
result
[
k
]
=
v
result
[
k
]
=
v
return
result
return
result
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
View file @
6cb5916f
...
@@ -203,7 +203,7 @@ class ClassicMutator(Mutator):
...
@@ -203,7 +203,7 @@ class ClassicMutator(Mutator):
# for now we only generate flattened search space
# for now we only generate flattened search space
if
isinstance
(
mutable
,
LayerChoice
):
if
isinstance
(
mutable
,
LayerChoice
):
key
=
mutable
.
key
key
=
mutable
.
key
val
=
[
repr
(
choice
)
for
choice
in
mutable
.
choic
es
]
val
=
mutable
.
nam
es
search_space
[
key
]
=
{
"_type"
:
LAYER_CHOICE
,
"_value"
:
val
}
search_space
[
key
]
=
{
"_type"
:
LAYER_CHOICE
,
"_value"
:
val
}
elif
isinstance
(
mutable
,
InputChoice
):
elif
isinstance
(
mutable
,
InputChoice
):
key
=
mutable
.
key
key
=
mutable
.
key
...
...
src/sdk/pynni/nni/nas/pytorch/mutables.py
View file @
6cb5916f
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
logging
import
logging
from
collections
import
OrderedDict
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -83,9 +84,6 @@ class Mutable(nn.Module):
...
@@ -83,9 +84,6 @@ class Mutable(nn.Module):
"Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
"Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
"so that trainer can locate all your mutables. See NNI docs for more details."
.
format
(
self
))
"so that trainer can locate all your mutables. See NNI docs for more details."
.
format
(
self
))
def
__repr__
(
self
):
return
"{} ({})"
.
format
(
self
.
name
,
self
.
key
)
class
MutableScope
(
Mutable
):
class
MutableScope
(
Mutable
):
"""
"""
...
@@ -128,7 +126,7 @@ class LayerChoice(Mutable):
...
@@ -128,7 +126,7 @@ class LayerChoice(Mutable):
Parameters
Parameters
----------
----------
op_candidates : list of nn.Module
op_candidates : list of nn.Module
or OrderedDict
A module list to be selected from.
A module list to be selected from.
reduction : str
reduction : str
``mean``, ``concat``, ``sum`` or ``none``. Policy if multiples are selected.
``mean``, ``concat``, ``sum`` or ``none``. Policy if multiples are selected.
...
@@ -143,12 +141,42 @@ class LayerChoice(Mutable):
...
@@ -143,12 +141,42 @@ class LayerChoice(Mutable):
----------
----------
length : int
length : int
Number of ops to choose from.
Number of ops to choose from.
names: list of str
Names of candidates.
Notes
-----
``op_candidates`` can be a list of modules or a ordered dict of named modules, for example,
.. code-block:: python
self.op_choice = LayerChoice(OrderedDict([
("conv3x3", nn.Conv2d(3, 16, 128)),
("conv5x5", nn.Conv2d(5, 16, 128)),
("conv7x7", nn.Conv2d(7, 16, 128))
]))
"""
"""
def
__init__
(
self
,
op_candidates
,
reduction
=
"sum"
,
return_mask
=
False
,
key
=
None
):
def
__init__
(
self
,
op_candidates
,
reduction
=
"sum"
,
return_mask
=
False
,
key
=
None
):
super
().
__init__
(
key
=
key
)
super
().
__init__
(
key
=
key
)
self
.
length
=
len
(
op_candidates
)
self
.
length
=
len
(
op_candidates
)
self
.
choices
=
nn
.
ModuleList
(
op_candidates
)
self
.
choices
=
[]
self
.
names
=
[]
if
isinstance
(
op_candidates
,
OrderedDict
):
for
name
,
module
in
op_candidates
.
items
():
assert
name
not
in
[
"length"
,
"reduction"
,
"return_mask"
,
"_key"
,
"key"
,
"names"
],
\
"Please don't use a reserved name '{}' for your module."
.
format
(
name
)
self
.
add_module
(
name
,
module
)
self
.
choices
.
append
(
module
)
self
.
names
.
append
(
name
)
elif
isinstance
(
op_candidates
,
list
):
for
i
,
module
in
enumerate
(
op_candidates
):
self
.
add_module
(
str
(
i
),
module
)
self
.
choices
.
append
(
module
)
self
.
names
.
append
(
str
(
i
))
else
:
raise
TypeError
(
"Unsupported op_candidates type: {}"
.
format
(
type
(
op_candidates
)))
self
.
reduction
=
reduction
self
.
reduction
=
reduction
self
.
return_mask
=
return_mask
self
.
return_mask
=
return_mask
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment