Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
6c1fe5c8
Commit
6c1fe5c8
authored
Nov 26, 2019
by
Yuge Zhang
Committed by
QuanluZhang
Nov 26, 2019
Browse files
Fix minor issues in DARTS and add a simple CIFAR10 example (with random mutator) (#1776)
parent
9aed500d
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
125 additions
and
16 deletions
+125
-16
examples/nas/darts/model.py
examples/nas/darts/model.py
+2
-2
examples/nas/darts/ops.py
examples/nas/darts/ops.py
+3
-6
examples/nas/darts/retrain.py
examples/nas/darts/retrain.py
+1
-1
examples/nas/naive/.gitignore
examples/nas/naive/.gitignore
+1
-0
examples/nas/naive/train.py
examples/nas/naive/train.py
+72
-0
src/sdk/pynni/nni/nas/pytorch/darts/mutator.py
src/sdk/pynni/nni/nas/pytorch/darts/mutator.py
+12
-3
src/sdk/pynni/nni/nas/pytorch/mutables.py
src/sdk/pynni/nni/nas/pytorch/mutables.py
+4
-2
src/sdk/pynni/nni/nas/pytorch/mutator.py
src/sdk/pynni/nni/nas/pytorch/mutator.py
+4
-2
src/sdk/pynni/nni/nas/pytorch/random/__init__.py
src/sdk/pynni/nni/nas/pytorch/random/__init__.py
+1
-0
src/sdk/pynni/nni/nas/pytorch/random/mutator.py
src/sdk/pynni/nni/nas/pytorch/random/mutator.py
+25
-0
No files found.
examples/nas/darts/model.py
View file @
6c1fe5c8
...
...
@@ -51,7 +51,7 @@ class Node(nn.Module):
ops
.
DilConv
(
channels
,
channels
,
5
,
stride
,
4
,
2
,
affine
=
False
)
],
key
=
choice_keys
[
-
1
]))
self
.
drop_path
=
ops
.
DropPath
_
()
self
.
drop_path
=
ops
.
DropPath
()
self
.
input_switch
=
mutables
.
InputChoice
(
choose_from
=
choice_keys
,
n_chosen
=
2
,
key
=
"{}_switch"
.
format
(
node_id
))
def
forward
(
self
,
prev_nodes
):
...
...
@@ -153,5 +153,5 @@ class CNN(nn.Module):
def
drop_path_prob
(
self
,
p
):
for
module
in
self
.
modules
():
if
isinstance
(
module
,
ops
.
DropPath
_
):
if
isinstance
(
module
,
ops
.
DropPath
):
module
.
p
=
p
examples/nas/darts/ops.py
View file @
6c1fe5c8
...
...
@@ -2,10 +2,10 @@ import torch
import
torch.nn
as
nn
class
DropPath
_
(
nn
.
Module
):
class
DropPath
(
nn
.
Module
):
def
__init__
(
self
,
p
=
0.
):
"""
Drop
P
ath
is inplace module
.
Drop
p
ath
with probability
.
Parameters
----------
...
...
@@ -15,15 +15,12 @@ class DropPath_(nn.Module):
super
().
__init__
()
self
.
p
=
p
def
extra_repr
(
self
):
return
'p={}, inplace'
.
format
(
self
.
p
)
def
forward
(
self
,
x
):
if
self
.
training
and
self
.
p
>
0.
:
keep_prob
=
1.
-
self
.
p
# per data point mask
mask
=
torch
.
zeros
((
x
.
size
(
0
),
1
,
1
,
1
),
device
=
x
.
device
).
bernoulli_
(
keep_prob
)
x
.
div_
(
keep_prob
).
mul_
(
mask
)
return
x
/
keep_prob
*
mask
return
x
...
...
examples/nas/darts/retrain.py
View file @
6c1fe5c8
...
...
@@ -33,7 +33,7 @@ def train(config, train_loader, model, optimizer, criterion, epoch):
losses
=
AverageMeter
(
"losses"
)
cur_step
=
epoch
*
len
(
train_loader
)
cur_lr
=
optimizer
.
param_groups
[
0
][
'
lr
'
]
cur_lr
=
optimizer
.
param_groups
[
0
][
"
lr
"
]
logger
.
info
(
"Epoch %d LR %.6f"
,
epoch
,
cur_lr
)
writer
.
add_scalar
(
"lr"
,
cur_lr
,
global_step
=
cur_step
)
...
...
examples/nas/naive/.gitignore
0 → 100644
View file @
6c1fe5c8
checkpoint.json
examples/nas/naive/train.py
0 → 100644
View file @
6c1fe5c8
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.optim
as
optim
import
torchvision
import
torchvision.transforms
as
transforms
from
nni.nas.pytorch.mutables
import
LayerChoice
,
InputChoice
from
nni.nas.pytorch.darts
import
DartsTrainer
class
Net
(
nn
.
Module
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
conv1
=
LayerChoice
([
nn
.
Conv2d
(
3
,
6
,
3
,
padding
=
1
),
nn
.
Conv2d
(
3
,
6
,
5
,
padding
=
2
)])
self
.
pool
=
nn
.
MaxPool2d
(
2
,
2
)
self
.
conv2
=
LayerChoice
([
nn
.
Conv2d
(
6
,
16
,
3
,
padding
=
1
),
nn
.
Conv2d
(
6
,
16
,
5
,
padding
=
2
)])
self
.
conv3
=
nn
.
Conv2d
(
16
,
16
,
1
)
self
.
skipconnect
=
InputChoice
(
n_candidates
=
1
)
self
.
bn
=
nn
.
BatchNorm2d
(
16
)
self
.
gap
=
nn
.
AdaptiveAvgPool2d
(
4
)
self
.
fc1
=
nn
.
Linear
(
16
*
4
*
4
,
120
)
self
.
fc2
=
nn
.
Linear
(
120
,
84
)
self
.
fc3
=
nn
.
Linear
(
84
,
10
)
def
forward
(
self
,
x
):
bs
=
x
.
size
(
0
)
x
=
self
.
pool
(
F
.
relu
(
self
.
conv1
(
x
)))
x0
=
F
.
relu
(
self
.
conv2
(
x
))
x1
=
F
.
relu
(
self
.
conv3
(
x0
))
x0
=
self
.
skipconnect
([
x0
])
if
x0
is
not
None
:
x1
+=
x0
x
=
self
.
pool
(
self
.
bn
(
x1
))
x
=
self
.
gap
(
x
).
view
(
bs
,
-
1
)
x
=
F
.
relu
(
self
.
fc1
(
x
))
x
=
F
.
relu
(
self
.
fc2
(
x
))
x
=
self
.
fc3
(
x
)
return
x
def
accuracy
(
output
,
target
):
batch_size
=
target
.
size
(
0
)
_
,
predicted
=
torch
.
max
(
output
.
data
,
1
)
return
{
"acc1"
:
(
predicted
==
target
).
sum
().
item
()
/
batch_size
}
if
__name__
==
"__main__"
:
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
transforms
.
Normalize
((
0.5
,
0.5
,
0.5
),
(
0.5
,
0.5
,
0.5
))])
dataset_train
=
torchvision
.
datasets
.
CIFAR10
(
root
=
"./data"
,
train
=
True
,
download
=
True
,
transform
=
transform
)
dataset_valid
=
torchvision
.
datasets
.
CIFAR10
(
root
=
"./data"
,
train
=
False
,
download
=
True
,
transform
=
transform
)
net
=
Net
()
criterion
=
nn
.
CrossEntropyLoss
()
optimizer
=
optim
.
SGD
(
net
.
parameters
(),
lr
=
0.001
,
momentum
=
0.9
)
trainer
=
DartsTrainer
(
net
,
loss
=
criterion
,
metrics
=
accuracy
,
optimizer
=
optimizer
,
num_epochs
=
2
,
dataset_train
=
dataset_train
,
dataset_valid
=
dataset_valid
,
batch_size
=
64
,
log_frequency
=
10
)
trainer
.
train
()
trainer
.
export
(
"checkpoint.json"
)
src/sdk/pynni/nni/nas/pytorch/darts/mutator.py
View file @
6c1fe5c8
import
logging
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
...
...
@@ -5,6 +7,8 @@ import torch.nn.functional as F
from
nni.nas.pytorch.mutator
import
Mutator
from
nni.nas.pytorch.mutables
import
LayerChoice
,
InputChoice
_logger
=
logging
.
getLogger
(
__name__
)
class
DartsMutator
(
Mutator
):
def
__init__
(
self
,
model
):
...
...
@@ -36,9 +40,14 @@ class DartsMutator(Mutator):
edges_max
[
mutable
.
key
]
=
max_val
result
[
mutable
.
key
]
=
F
.
one_hot
(
index
,
num_classes
=
mutable
.
length
).
view
(
-
1
).
bool
()
for
mutable
in
self
.
mutables
:
if
isinstance
(
mutable
,
InputChoice
):
weights
=
torch
.
tensor
([
edges_max
.
get
(
src_key
,
0.
)
for
src_key
in
mutable
.
choose_from
])
# pylint: disable=not-callable
_
,
topk_edge_indices
=
torch
.
topk
(
weights
,
mutable
.
n_chosen
or
mutable
.
n_candidates
)
if
isinstance
(
mutable
,
InputChoice
)
and
mutable
.
n_chosen
is
not
None
:
weights
=
[]
for
src_key
in
mutable
.
choose_from
:
if
src_key
not
in
edges_max
:
_logger
.
warning
(
"InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs."
,
mutable
.
key
)
weights
.
append
(
edges_max
.
get
(
src_key
,
0.
))
weights
=
torch
.
tensor
(
weights
)
# pylint: disable=not-callable
_
,
topk_edge_indices
=
torch
.
topk
(
weights
,
mutable
.
n_chosen
)
selected_multihot
=
[]
for
i
,
src_key
in
enumerate
(
mutable
.
choose_from
):
if
i
not
in
topk_edge_indices
and
src_key
in
result
:
...
...
src/sdk/pynni/nni/nas/pytorch/mutables.py
View file @
6c1fe5c8
...
...
@@ -62,7 +62,8 @@ class Mutable(nn.Module):
def
_check_built
(
self
):
if
not
hasattr
(
self
,
"mutator"
):
raise
ValueError
(
"Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__"
"Mutator not set for {}. You might have forgotten to initialize and apply your mutator. "
"Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
"so that trainer can locate all your mutables. See NNI docs for more details."
.
format
(
self
))
def
__repr__
(
self
):
...
...
@@ -179,7 +180,8 @@ class InputChoice(Mutable):
optional_input_list
=
optional_inputs
if
isinstance
(
optional_inputs
,
dict
):
optional_input_list
=
[
optional_inputs
[
tag
]
for
tag
in
self
.
choose_from
]
assert
isinstance
(
optional_input_list
,
list
),
"Optional input list must be a list"
assert
isinstance
(
optional_input_list
,
list
),
\
"Optional input list must be a list, not a {}."
.
format
(
type
(
optional_input_list
))
assert
len
(
optional_inputs
)
==
self
.
n_candidates
,
\
"Length of the input list must be equal to number of candidates."
out
,
mask
=
self
.
mutator
.
on_forward_input_choice
(
self
,
optional_input_list
)
...
...
src/sdk/pynni/nni/nas/pytorch/mutator.py
View file @
6c1fe5c8
...
...
@@ -76,7 +76,8 @@ class Mutator(BaseMutator):
return
op
(
*
inputs
)
mask
=
self
.
_get_decision
(
mutable
)
assert
len
(
mask
)
==
len
(
mutable
.
choices
)
assert
len
(
mask
)
==
len
(
mutable
.
choices
),
\
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
len
(
mutable
.
choices
))
out
=
self
.
_select_with_mask
(
_map_fn
,
[(
choice
,
*
inputs
)
for
choice
in
mutable
.
choices
],
mask
)
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
...
...
@@ -98,7 +99,8 @@ class Mutator(BaseMutator):
tuple of torch.Tensor and torch.Tensor
"""
mask
=
self
.
_get_decision
(
mutable
)
assert
len
(
mask
)
==
mutable
.
n_candidates
assert
len
(
mask
)
==
mutable
.
n_candidates
,
\
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
mutable
.
n_candidates
)
out
=
self
.
_select_with_mask
(
lambda
x
:
x
,
[(
t
,)
for
t
in
tensor_list
],
mask
)
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
...
...
src/sdk/pynni/nni/nas/pytorch/random/__init__.py
0 → 100644
View file @
6c1fe5c8
from
.mutator
import
RandomMutator
\ No newline at end of file
src/sdk/pynni/nni/nas/pytorch/random/mutator.py
0 → 100644
View file @
6c1fe5c8
import
torch
import
torch.nn.functional
as
F
from
nni.nas.pytorch.mutator
import
Mutator
from
nni.nas.pytorch.mutables
import
LayerChoice
,
InputChoice
class
RandomMutator
(
Mutator
):
def
sample_search
(
self
):
result
=
dict
()
for
mutable
in
self
.
mutables
:
if
isinstance
(
mutable
,
LayerChoice
):
gen_index
=
torch
.
randint
(
high
=
mutable
.
length
,
size
=
(
1
,
))
result
[
mutable
.
key
]
=
F
.
one_hot
(
gen_index
,
num_classes
=
mutable
.
length
).
view
(
-
1
).
bool
()
elif
isinstance
(
mutable
,
InputChoice
):
if
mutable
.
n_chosen
is
None
:
result
[
mutable
.
key
]
=
torch
.
randint
(
high
=
2
,
size
=
(
mutable
.
n_candidates
,)).
view
(
-
1
).
bool
()
else
:
perm
=
torch
.
randperm
(
mutable
.
n_candidates
)
mask
=
[
i
in
perm
[:
mutable
.
n_chosen
]
for
i
in
range
(
mutable
.
n_candidates
)]
result
[
mutable
.
key
]
=
torch
.
tensor
(
mask
,
dtype
=
torch
.
bool
)
# pylint: disable=not-callable
return
result
def
sample_final
(
self
):
return
self
.
sample_search
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment