OpenDAS / nni · Commits

Commit b49b38f8 (unverified), authored Feb 07, 2020 by Yuge Zhang, committed by GitHub on Feb 07, 2020

    Add unit tests for NAS (#1954)

Parent: 74250987
Showing 11 changed files with 334 additions and 26 deletions (+334, -26)
azure-pipelines.yml                                           +5   -5
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py          +9   -0
src/sdk/pynni/nni/nas/pytorch/darts/mutator.py               +19  -14
src/sdk/pynni/nni/nas/pytorch/fixed.py                        +6   -6
src/sdk/pynni/nni/nas/pytorch/utils.py                        +8   -0
src/sdk/pynni/tests/models/pytorch_models/__init__.py         +6   -0
src/sdk/pynni/tests/models/pytorch_models/mutable_scope.py   +95   -0
src/sdk/pynni/tests/models/pytorch_models/naive.py           +45   -0
src/sdk/pynni/tests/models/pytorch_models/nested.py          +34   -0
src/sdk/pynni/tests/test_nas.py                             +106   -0
test/pipelines-it-local-windows.yml                           +1   -1
azure-pipelines.yml  (view file @ b49b38f8)

@@ -26,8 +26,8 @@ jobs:
       yarn eslint
     displayName: 'Run eslint'
   - script: |
-      python3 -m pip install torch==0.4.1 --user
-      python3 -m pip install torchvision==0.2.1 --user
+      python3 -m pip install torch==1.2.0 --user
+      python3 -m pip install torchvision==0.4.0 --user
       python3 -m pip install tensorflow==1.13.1 --user
       python3 -m pip install keras==2.1.6 --user
       python3 -m pip install gym onnx --user
@@ -91,8 +91,8 @@ jobs:
       echo "##vso[task.setvariable variable=PATH]${HOME}/Library/Python/3.7/bin:${PATH}"
     displayName: 'Install nni toolkit via source code'
   - script: |
-      python3 -m pip install torch==0.4.1 --user
-      python3 -m pip install torchvision==0.2.1 --user
+      python3 -m pip install torch==1.2.0 --user
+      python3 -m pip install torchvision==0.4.0 --user
       python3 -m pip install tensorflow==1.13.1 --user
       ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)" < /dev/null 2> /dev/null
       brew install swig@3
@@ -131,7 +131,7 @@ jobs:
   - script: |
       python -m pip install scikit-learn==0.20.0 --user
       python -m pip install keras==2.1.6 --user
-      python -m pip install https://download.pytorch.org/whl/cu90/torch-0.4.1-cp36-cp36m-win_amd64.whl --user
+      python -m pip install torch===1.2.0 torchvision===0.4.1 -f https://download.pytorch.org/whl/torch_stable.html --user
       python -m pip install torchvision --user
       python -m pip install tensorflow==1.13.1 --user
     displayName: 'Install dependencies'
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py  (view file @ b49b38f8)

@@ -67,6 +67,13 @@ class ClassicMutator(Mutator):
         else:
             # get chosen arch from tuner
             self._chosen_arch = nni.get_next_parameter()
+            if self._chosen_arch is None:
+                if trial_env_vars.NNI_PLATFORM == "unittest":
+                    # happens if NNI_PLATFORM is intentionally set, e.g., in UT
+                    logger.warning("`NNI_PLATFORM` is set but `param` is None. Falling back to standalone mode.")
+                    self._chosen_arch = self._standalone_generate_chosen()
+                else:
+                    raise RuntimeError("Chosen architecture is None. This may be a platform error.")
         self.reset()

     def _sample_layer_choice(self, mutable, idx, value, search_space_item):
@@ -162,6 +169,8 @@ class ClassicMutator(Mutator):
             elif val["_type"] == INPUT_CHOICE:
                 choices = val["_value"]["candidates"]
                 n_chosen = val["_value"]["n_chosen"]
+                if n_chosen is None:
+                    n_chosen = len(choices)
                 chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}
             else:
                 raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
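The unittest fallback above lets a test environment keep `NNI_PLATFORM` set to "unittest" and still obtain a concrete architecture even though no tuner answers. A minimal sketch of how that path can be exercised follows; the model name is a placeholder for any search space containing mutables, and setting the variable before nni is imported is an assumption about how `trial_env_vars` is populated, not something stated in this commit.

    import os
    os.environ["NNI_PLATFORM"] = "unittest"  # assumed to be set before nni is imported

    import torch
    from nni.nas.pytorch.classic_nas import get_and_apply_next_architecture

    model = SomeSearchSpace()                # placeholder: any nn.Module with LayerChoice/InputChoice
    get_and_apply_next_architecture(model)   # get_next_parameter() returns None -> standalone fallback
    y = model(torch.randn(2, 3, 32, 32))     # forward pass runs with the default (standalone) choices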
src/sdk/pynni/nni/nas/pytorch/darts/mutator.py  (view file @ b49b38f8)

@@ -63,7 +63,8 @@ class DartsMutator(Mutator):
                 edges_max[mutable.key] = max_val
                 result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool()
         for mutable in self.mutables:
-            if isinstance(mutable, InputChoice) and mutable.n_chosen is not None:
+            if isinstance(mutable, InputChoice):
+                if mutable.n_chosen is not None:
                     weights = []
                     for src_key in mutable.choose_from:
                         if src_key not in edges_max:
@@ -74,7 +75,11 @@ class DartsMutator(Mutator):
                     selected_multihot = []
                     for i, src_key in enumerate(mutable.choose_from):
                         if i not in topk_edge_indices and src_key in result:
-                            result[src_key] = torch.zeros_like(result[src_key])  # clear this choice to optimize calc graph
+                            # If an edge is never selected, there is no need to calculate any op on this edge.
+                            # This is to eliminate redundant calculation.
+                            result[src_key] = torch.zeros_like(result[src_key])
                         selected_multihot.append(i in topk_edge_indices)
                     result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
+                else:
+                    result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
         return result
src/sdk/pynni/nni/nas/pytorch/fixed.py  (view file @ b49b38f8)

@@ -52,24 +52,24 @@ def _encode_tensor(data):
     return data


-def apply_fixed_architecture(model, fixed_arc_path):
+def apply_fixed_architecture(model, fixed_arc):
     """
-    Load architecture from `fixed_arc_path` and apply to model.
+    Load architecture from `fixed_arc` and apply to model.

     Parameters
     ----------
     model : torch.nn.Module
         Model with mutables.
-    fixed_arc_path : str
-        Path to the JSON that stores the architecture.
+    fixed_arc : str or dict
+        Path to the JSON that stores the architecture, or dict that stores the exported architecture.

     Returns
     -------
     FixedArchitecture
     """
-    if isinstance(fixed_arc_path, str):
-        with open(fixed_arc_path, "r") as f:
+    if isinstance(fixed_arc, str):
+        with open(fixed_arc) as f:
             fixed_arc = json.load(f)
     fixed_arc = _encode_tensor(fixed_arc)
     architecture = FixedArchitecture(model, fixed_arc)
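With the relaxed signature, `apply_fixed_architecture` accepts either a JSON path or the dict returned by `Mutator.export()`, which is what the new tests rely on. A hedged usage sketch follows, assuming nni and torch are installed; the tiny search space and the explicit `key` are illustrative and not part of the commit.

    import torch
    import torch.nn as nn
    from nni.nas.pytorch.mutables import LayerChoice
    from nni.nas.pytorch.random import RandomMutator
    from nni.nas.pytorch.fixed import apply_fixed_architecture

    class TinySpace(nn.Module):
        def __init__(self):
            super().__init__()
            # explicit key so two instances of the space share mutable names
            self.op = LayerChoice([nn.Conv2d(3, 8, 3, padding=1),
                                   nn.Conv2d(3, 8, 5, padding=2)], key="tiny_op")

        def forward(self, x):
            return self.op(x)

    search_model = TinySpace()
    mutator = RandomMutator(search_model)
    arc = mutator.export()                       # dict: mutable key -> chosen mask

    fixed_model = TinySpace()
    apply_fixed_architecture(fixed_model, arc)   # new: a dict is accepted directly
    # apply_fixed_architecture(fixed_model, "arch.json")  # a JSON path still works
    out = fixed_model(torch.randn(2, 3, 32, 32))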
src/sdk/pynni/nni/nas/pytorch/utils.py  (view file @ b49b38f8)

@@ -17,6 +17,14 @@ def global_mutable_counting():
     return _counter


+def _reset_global_mutable_counting():
+    """
+    Reset the global mutable counting to count from 1. Useful when defining multiple models with default keys.
+    """
+    global _counter
+    _counter = 0
+
+
 def to_device(obj, device):
     if torch.is_tensor(obj):
         return obj.to(device)
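Mutables created without an explicit `key` draw a default key from this global counter, so two copies of the same search space built back to back end up with different keys. The sketch below illustrates why the tests call the reset helper between model constructions; it is an illustration under that assumption, not code from the commit.

    import torch.nn as nn
    from nni.nas.pytorch.mutables import LayerChoice
    from nni.nas.pytorch.utils import _reset_global_mutable_counting

    def build_choice():
        # no explicit key: the default key embeds the global counter value
        return LayerChoice([nn.Conv2d(3, 8, 3, padding=1), nn.Conv2d(3, 8, 5, padding=2)])

    _reset_global_mutable_counting()
    first = build_choice()

    _reset_global_mutable_counting()
    second = build_choice()

    # identical default keys, so an architecture exported against `first`
    # can be applied to `second` (this is what default_mutator_test_pipeline does)
    assert first.key == second.key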
src/sdk/pynni/tests/models/pytorch_models/__init__.py  (new file, 0 → 100644, view file @ b49b38f8)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .mutable_scope import SpaceWithMutableScope
from .naive import NaiveSearchSpace
from .nested import NestedSpace
src/sdk/pynni/tests/models/pytorch_models/mutable_scope.py  (new file, 0 → 100644, view file @ b49b38f8)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F

from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope


class Cell(MutableScope):
    def __init__(self, cell_name, prev_labels, channels):
        super().__init__(cell_name)
        self.input_choice = InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True,
                                        key=cell_name + "_input")
        self.op_choice = LayerChoice([
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.Conv2d(channels, channels, 5, padding=2),
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.AvgPool2d(3, stride=1, padding=1),
            nn.Identity()
        ], key=cell_name + "_op")

    def forward(self, prev_layers):
        chosen_input, chosen_mask = self.input_choice(prev_layers)
        cell_out = self.op_choice(chosen_input)
        return cell_out, chosen_mask


class Node(MutableScope):
    def __init__(self, node_name, prev_node_names, channels):
        super().__init__(node_name)
        self.cell_x = Cell(node_name + "_x", prev_node_names, channels)
        self.cell_y = Cell(node_name + "_y", prev_node_names, channels)

    def forward(self, prev_layers):
        out_x, mask_x = self.cell_x(prev_layers)
        out_y, mask_y = self.cell_y(prev_layers)
        return out_x + out_y, mask_x | mask_y


class Layer(nn.Module):
    def __init__(self, num_nodes, channels):
        super().__init__()
        self.num_nodes = num_nodes
        self.nodes = nn.ModuleList()
        node_labels = [InputChoice.NO_KEY, InputChoice.NO_KEY]
        for i in range(num_nodes):
            node_labels.append("node_{}".format(i))
            self.nodes.append(Node(node_labels[-1], node_labels[:-1], channels))
        self.final_conv_w = nn.Parameter(torch.zeros(channels, self.num_nodes + 2, channels, 1, 1),
                                         requires_grad=True)
        self.bn = nn.BatchNorm2d(channels, affine=False)

    def forward(self, pprev, prev):
        prev_nodes_out = [pprev, prev]
        nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device)
        for i in range(self.num_nodes):
            node_out, mask = self.nodes[i](prev_nodes_out)
            nodes_used_mask[:mask.size(0)] |= mask.to(prev.device)  # NOTE: which device should we put mask on?
            prev_nodes_out.append(node_out)
        unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1)
        unused_nodes = F.relu(unused_nodes)
        conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :]
        conv_weight = conv_weight.view(conv_weight.size(0), -1, 1, 1)
        out = F.conv2d(unused_nodes, conv_weight)
        return prev, self.bn(out)


class SpaceWithMutableScope(nn.Module):
    def __init__(self, test_case, num_layers=4, num_nodes=5, channels=16, in_channels=3, num_classes=10):
        super().__init__()
        self.test_case = test_case
        self.num_layers = num_layers

        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(channels)
        )
        self.layers = nn.ModuleList()
        for _ in range(self.num_layers + 2):
            self.layers.append(Layer(num_nodes, channels))
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.dense = nn.Linear(channels, num_classes)

    def forward(self, x):
        prev = cur = self.stem(x)
        for layer in self.layers:
            prev, cur = layer(prev, cur)
        cur = self.gap(F.relu(cur)).view(x.size(0), -1)
        return self.dense(cur)
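This model only does something meaningful after a mutator has sampled each `LayerChoice`/`InputChoice`. The sketch below shows the minimal driving loop, mirroring `iterative_sample_and_forward` in test_nas.py further down; it assumes this file is importable from the current directory and that nni and torch are installed, and the `_Probe` stand-in for the unittest `TestCase` argument is illustrative.

    import torch
    from nni.nas.pytorch.random import RandomMutator
    from mutable_scope import SpaceWithMutableScope  # assumes cwd is this models directory

    class _Probe:
        # stand-in for the TestCase the constructor expects; this model never calls it
        def assertEqual(self, first, second):
            assert first == second

    model = SpaceWithMutableScope(_Probe())
    mutator = RandomMutator(model)
    for _ in range(3):
        mutator.reset()                            # sample a new architecture
        y = model(torch.randn(2, 3, 32, 32))       # batch of 2 so BatchNorm has statistics
        torch.sum(y).backward()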
src/sdk/pynni/tests/models/pytorch_models/naive.py  (new file, 0 → 100644, view file @ b49b38f8)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F

from nni.nas.pytorch.mutables import LayerChoice, InputChoice


class NaiveSearchSpace(nn.Module):
    def __init__(self, test_case):
        super().__init__()
        self.test_case = test_case
        self.conv1 = LayerChoice([nn.Conv2d(3, 6, 3, padding=1), nn.Conv2d(3, 6, 5, padding=2)])
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = LayerChoice([nn.Conv2d(6, 16, 3, padding=1), nn.Conv2d(6, 16, 5, padding=2)],
                                 return_mask=True)
        self.conv3 = nn.Conv2d(16, 16, 1)
        self.skipconnect = InputChoice(n_candidates=1)
        self.skipconnect2 = InputChoice(n_candidates=2, return_mask=True)
        self.bn = nn.BatchNorm2d(16)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        bs = x.size(0)
        x = self.pool(F.relu(self.conv1(x)))
        x0, mask = self.conv2(x)
        self.test_case.assertEqual(mask.size(), torch.Size([2]))
        x1 = F.relu(self.conv3(x0))
        _, mask = self.skipconnect2([x0, x1])
        x0 = self.skipconnect([x0])
        if x0 is not None:
            x1 += x0
        x = self.pool(self.bn(x1))
        self.test_case.assertEqual(mask.size(), torch.Size([2]))
        x = self.gap(x).view(bs, -1)
        x = self.fc(x)
        return x
src/sdk/pynni/tests/models/pytorch_models/nested.py  (new file, 0 → 100644, view file @ b49b38f8)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch.nn as nn
import torch.nn.functional as F

from nni.nas.pytorch.mutables import LayerChoice, InputChoice


class MutableOp(nn.Module):
    def __init__(self, kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(3, 120, kernel_size, padding=kernel_size // 2)
        self.nested_mutable = InputChoice(n_candidates=10)

    def forward(self, x):
        return self.conv(x)


class NestedSpace(nn.Module):
    # this doesn't pass tests
    def __init__(self, test_case):
        super().__init__()
        self.test_case = test_case
        self.conv1 = LayerChoice([MutableOp(3), MutableOp(5)])
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(120, 10)

    def forward(self, x):
        bs = x.size(0)
        x = F.relu(self.conv1(x))
        x = self.gap(x).view(bs, -1)
        x = self.fc(x)
        return x
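NestedSpace pins down what is not supported: a mutable living inside another mutable's candidate op. The sketch below mirrors `test_nested_space` in test_nas.py; it assumes this file is importable from the current directory and that nni is installed.

    from nni.nas.pytorch.random import RandomMutator
    from nested import NestedSpace  # assumes cwd is this models directory

    model = NestedSpace(test_case=None)   # constructing the module itself is fine
    try:
        RandomMutator(model)              # parsing the search space is expected to fail here
    except RuntimeError as exc:
        print("nested mutables rejected:", exc)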
src/sdk/pynni/tests/test_nas.py  (new file, 0 → 100644, view file @ b49b38f8)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import importlib
import os
import sys
from unittest import TestCase, main

import torch
import torch.nn as nn

from nni.nas.pytorch.classic_nas import get_and_apply_next_architecture
from nni.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.enas import EnasMutator
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.random import RandomMutator
from nni.nas.pytorch.utils import _reset_global_mutable_counting


class NasTestCase(TestCase):
    def setUp(self):
        self.default_input_size = [3, 32, 32]
        self.model_path = os.path.join(os.path.dirname(__file__), "models")
        sys.path.append(self.model_path)
        self.model_module = importlib.import_module("pytorch_models")
        self.default_cls = [self.model_module.NaiveSearchSpace, self.model_module.SpaceWithMutableScope]
        self.cuda_test = [0]
        if torch.cuda.is_available():
            self.cuda_test.append(1)
            if torch.cuda.device_count() > 1:
                self.cuda_test.append(torch.cuda.device_count())

    def tearDown(self):
        sys.path.remove(self.model_path)

    def iterative_sample_and_forward(self, model, mutator=None, input_size=None, n_iters=20,
                                     test_backward=True, use_cuda=False):
        if input_size is None:
            input_size = self.default_input_size
        # support pytorch only
        input_size = [8 if use_cuda else 2] + input_size  # at least 2 samples to enable batch norm
        for _ in range(n_iters):
            for param in model.parameters():
                param.grad = None
            if mutator is not None:
                mutator.reset()
            x = torch.randn(input_size)
            if use_cuda:
                x = x.cuda()
            y = torch.sum(model(x))
            if test_backward:
                y.backward()

    def default_mutator_test_pipeline(self, mutator_cls):
        for model_cls in self.default_cls:
            for cuda_test in self.cuda_test:
                _reset_global_mutable_counting()
                model = model_cls(self)
                mutator = mutator_cls(model)
                if cuda_test:
                    model.cuda()
                    mutator.cuda()
                if cuda_test > 1:
                    model = nn.DataParallel(model)
                self.iterative_sample_and_forward(model, mutator, use_cuda=cuda_test)
                _reset_global_mutable_counting()
                model_fixed = model_cls(self)
                if cuda_test:
                    model_fixed.cuda()
                if cuda_test > 1:
                    model_fixed = nn.DataParallel(model_fixed)
                with torch.no_grad():
                    arc = mutator.export()
                apply_fixed_architecture(model_fixed, arc)
                self.iterative_sample_and_forward(model_fixed, n_iters=1, use_cuda=cuda_test)

    def test_random_mutator(self):
        self.default_mutator_test_pipeline(RandomMutator)

    def test_enas_mutator(self):
        self.default_mutator_test_pipeline(EnasMutator)

    def test_darts_mutator(self):
        # DARTS doesn't support DataParallel. To be fixed.
        self.cuda_test = [t for t in self.cuda_test if t <= 1]
        self.default_mutator_test_pipeline(DartsMutator)

    def test_apply_twice(self):
        model = self.model_module.NaiveSearchSpace(self)
        with self.assertRaises(RuntimeError):
            for _ in range(2):
                RandomMutator(model)

    def test_nested_space(self):
        model = self.model_module.NestedSpace(self)
        with self.assertRaises(RuntimeError):
            RandomMutator(model)

    def test_classic_nas(self):
        for model_cls in self.default_cls:
            model = model_cls(self)
            get_and_apply_next_architecture(model)
            self.iterative_sample_and_forward(model)


if __name__ == '__main__':
    main()
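A hedged note on running this suite locally: from src/sdk/pynni/tests the file can be executed directly thanks to its `__main__` block, or driven through the standard unittest loader as sketched below (the invocation is an assumption, not spelled out in the commit). `test_classic_nas` exercises `get_and_apply_next_architecture`; the fallback added to `ClassicMutator` above is what keeps it working on CI, where `NNI_PLATFORM` is set to "unittest" but no tuner supplies a parameter.

    import unittest

    # equivalent to:  python3 -m unittest test_nas -v   (run from src/sdk/pynni/tests)
    suite = unittest.defaultTestLoader.loadTestsFromName("test_nas")
    unittest.TextTestRunner(verbosity=2).run(suite)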
test/pipelines-it-local-windows.yml  (view file @ b49b38f8)

@@ -8,7 +8,7 @@ jobs:
   - script: |
       python -m pip install scikit-learn==0.20.0 --user
       python -m pip install keras==2.1.6 --user
-      python -m pip install https://download.pytorch.org/whl/cu90/torch-0.4.1-cp36-cp36m-win_amd64.whl --user
+      python -m pip install torch===1.2.0 torchvision===0.4.1 -f https://download.pytorch.org/whl/torch_stable.html --user
       python -m pip install torchvision --user
       python -m pip install tensorflow-gpu==1.11.0 --user
     displayName: 'Install dependencies for integration tests'