OpenDAS / nni

Commit 5c96b82c

[NAS] fix bug on pdarts (#1797)

Authored Nov 28, 2019 by Chi Song; committed by chicm-ms on Dec 02, 2019.
Parent: e9cba778

Showing 5 changed files with 84 additions and 59 deletions (+84, -59).
- examples/nas/pdarts/search.py (+8, -4)
- src/sdk/pynni/nni/nas/pytorch/darts/trainer.py (+2, -1)
- src/sdk/pynni/nni/nas/pytorch/mutator.py (+1, -1)
- src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py (+55, -44)
- src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py (+18, -9)
examples/nas/pdarts/search.py

```diff
@@ -27,11 +27,14 @@ if __name__ == "__main__":
     parser = ArgumentParser("pdarts")
+    parser.add_argument('--add_layers', action='append', default=[0, 6, 12], help='add layers')
+    parser.add_argument('--dropped_ops', action='append', default=[3, 2, 1], help='drop ops')
     parser.add_argument("--nodes", default=4, type=int)
-    parser.add_argument("--layers", default=5, type=int)
+    parser.add_argument("--init_layers", default=5, type=int)
     parser.add_argument("--batch-size", default=64, type=int)
     parser.add_argument("--log-frequency", default=1, type=int)
     parser.add_argument("--epochs", default=50, type=int)
+    parser.add_argument("--unrolled", default=False, action="store_true")
     args = parser.parse_args()
 
     logger.info("loading data")
@@ -48,15 +51,16 @@ if __name__ == "__main__":
     logger.info("initializing trainer")
     trainer = PdartsTrainer(model_creator,
-                            layers=args.layers,
+                            init_layers=args.init_layers,
                             metrics=lambda output, target: accuracy(output, target, topk=(1,)),
-                            pdarts_num_layers=[0, 6, 12],
-                            pdarts_num_to_drop=[3, 2, 2],
+                            pdarts_num_layers=args.add_layers,
+                            pdarts_num_to_drop=args.dropped_ops,
                             num_epochs=args.epochs,
                             dataset_train=dataset_train,
                             dataset_valid=dataset_valid,
                             batch_size=args.batch_size,
                             log_frequency=args.log_frequency,
+                            unrolled=args.unrolled,
                             callbacks=[ArchitectureCheckpoint("./checkpoints")])
     logger.info("training")
     trainer.train()
```
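A side note on the two new `append`-type flags: with `action='append'` and a non-empty list default, Python's argparse appends user-supplied values to the default list rather than replacing it, and (without `type=int`) the appended values are strings. That is standard-library behaviour, not something this commit changes; a minimal sketch with a hypothetical invocation:

```python
# Minimal sketch of argparse's append-with-default behaviour
# (standard library semantics, not part of this commit).
from argparse import ArgumentParser

parser = ArgumentParser("pdarts")
parser.add_argument('--add_layers', action='append',
                    default=[0, 6, 12], help='add layers')

# Hypothetical invocation: the flag value is appended to the default
# list, and arrives as a string since no type=int is given.
args = parser.parse_args(['--add_layers', '3'])
print(args.add_layers)  # [0, 6, 12, '3'] -- appended, not replaced
```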
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py

```diff
@@ -18,10 +18,11 @@ class DartsTrainer(Trainer):
     def __init__(self, model, loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                  mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
-                 callbacks=None, arc_learning_rate=3.0E-4, unrolled=True):
+                 callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
         super().__init__(model, mutator if mutator is not None else DartsMutator(model),
                          loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                          batch_size, workers, device, log_frequency, callbacks)
         self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999),
                                            weight_decay=1.0E-3)
+        self.unrolled = unrolled
```
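The default for `unrolled` flips from True to False, so DartsTrainer now uses the cheaper first-order approximation of the architecture gradient unless callers opt in, and the flag is stored on the instance so later training steps can consult it. A hedged usage sketch for opting back in (`model`, `criterion`, `optim`, and the datasets are placeholders from a typical DARTS example, not defined in this diff):

```python
# Hypothetical usage: opting back into the unrolled (second-order)
# architecture gradient now that the default is False.
trainer = DartsTrainer(model, loss=criterion, metrics=metrics,
                       optimizer=optim, num_epochs=50,
                       dataset_train=dataset_train,
                       dataset_valid=dataset_valid,
                       unrolled=True)
```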
src/sdk/pynni/nni/nas/pytorch/mutator.py

```diff
@@ -111,7 +111,7 @@ class Mutator(BaseMutator):
         if "BoolTensor" in mask.type():
             out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
         elif "FloatTensor" in mask.type():
-            out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)]
+            out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m]
         else:
             raise ValueError("Unrecognized mask")
         return out
```
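This one-line change makes the float-mask branch consistent with the bool-mask branch: previously `map_fn(*cand)` was evaluated for every candidate even when its mask weight was zero; with `if m`, zero-weight candidates are skipped entirely. A self-contained toy sketch of the difference (plain lambdas stand in for the real candidate ops; this is not the actual Mutator API):

```python
import torch

# Toy stand-ins for candidate ops; the real code calls map_fn(*cand).
candidates = [lambda: "op_a", lambda: "op_b", lambda: "op_c"]
mask = torch.FloatTensor([0.7, 0.0, 0.3])

# Before the fix: every candidate is evaluated, zero-weight ones included.
before = [fn() for fn, m in zip(candidates, mask)]
# After the fix: `if m` skips candidates whose weight is exactly zero
# (a 0-dim zero tensor is falsy).
after = [fn() for fn, m in zip(candidates, mask) if m]

print(before)  # ['op_a', 'op_b', 'op_c']
print(after)   # ['op_a', 'op_c']
```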
src/sdk/pynni/nni/nas/pytorch/pdarts/mutator.py

```diff
@@ -4,13 +4,18 @@
 import copy
 
 import numpy as np
-import torch.nn.functional as F
+import torch
 from torch import nn
 
 from nni.nas.pytorch.darts import DartsMutator
 from nni.nas.pytorch.mutables import LayerChoice
 
 
 class PdartsMutator(DartsMutator):
+    """
+    It works with PdartsTrainer to calculate ops weights,
+    and drop weights in different PDARTS epochs.
+    """
 
     def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches={}):
         self.pdarts_epoch_index = pdarts_epoch_index
@@ -22,60 +27,66 @@ class PdartsMutator(DartsMutator):
         super(PdartsMutator, self).__init__(model)
 
+        # this loop go through mutables with different keys,
+        # it's mainly to update length of choices.
         for mutable in self.mutables:
             if isinstance(mutable, LayerChoice):
 
                 switches = self.switches.get(mutable.key, [True for j in range(mutable.length)])
-                for index in range(len(switches)-1, -1, -1):
-                    if switches[index] == False:
-                        del(mutable.choices[index])
-                        mutable.length -= 1
+                choices = self.choices[mutable.key]
+
+                operations_count = np.sum(switches)
+                # +1 and -1 are caused by zero operation in darts network
+                # the zero operation is not in choices list in network, but its weight are in,
+                # so it needs one more weights and switch for zero.
+                self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(operations_count + 1))
 
                 self.switches[mutable.key] = switches
 
-    def drop_paths(self):
-        for key in self.switches:
-            prob = F.softmax(self.choices[key], dim=-1).data.cpu().numpy()
-            switches = self.switches[key]
+        # update LayerChoice instances in model,
+        # it's physically remove dropped choices operations.
+        for module in self.model.modules():
+            if isinstance(module, LayerChoice):
+                switches = self.switches.get(module.key)
+                choices = self.choices[module.key]
+                if len(module.choices) > len(choices):
+                    # from last to first, so that it won't effect previous indexes after removed one.
+                    for index in range(len(switches)-1, -1, -1):
+                        if switches[index] == False:
+                            del(module.choices[index])
+                            module.length -= 1
+
+    def sample_final(self):
+        results = super().sample_final()
+        for mutable in self.mutables:
+            if isinstance(mutable, LayerChoice):
+                # As some operations are dropped physically,
+                # so it needs to fill back false to track dropped operations.
+                trained_result = results[mutable.key]
+                trained_index = 0
+                switches = self.switches[mutable.key]
+                result = torch.Tensor(switches).bool()
+                for index in range(len(result)):
+                    if result[index]:
+                        result[index] = trained_result[trained_index]
+                        trained_index += 1
+                results[mutable.key] = result
+        return results
+
+    def drop_paths(self):
+        """
+        This method is called when a PDARTS epoch is finished.
+        It prepares switches for next epoch.
+        candidate operations with False switch will be doppped in next epoch.
+        """
+        all_switches = copy.deepcopy(self.switches)
+        for key in all_switches:
+            switches = all_switches[key]
             idxs = []
             for j in range(len(switches)):
                 if switches[j]:
                     idxs.append(j)
-            if self.pdarts_epoch_index == len(self.pdarts_num_to_drop) - 1:
-                # for the last stage, drop all Zero operations
-                drop = self.get_min_k_no_zero(prob, idxs, self.pdarts_num_to_drop[self.pdarts_epoch_index])
-            else:
-                drop = self.get_min_k(prob, self.pdarts_num_to_drop[self.pdarts_epoch_index])
+            sorted_weights = self.choices[key].data.cpu().numpy()[:-1]
+            drop = np.argsort(sorted_weights)[:self.pdarts_num_to_drop[self.pdarts_epoch_index]]
             for idx in drop:
                 switches[idxs[idx]] = False
-        return self.switches
-
-    def get_min_k(self, input_in, k):
-        index = []
-        for _ in range(k):
-            idx = np.argmin(input)
-            index.append(idx)
-        return index
-
-    def get_min_k_no_zero(self, w_in, idxs, k):
-        w = copy.deepcopy(w_in)
-        index = []
-        if 0 in idxs:
-            zf = True
-        else:
-            zf = False
-        if zf:
-            w = w[1:]
-            index.append(0)
-            k = k - 1
-        for _ in range(k):
-            idx = np.argmin(w)
-            w[idx] = 1
-            if zf:
-                idx = idx + 1
-            index.append(idx)
-        return index
+        return all_switches
```
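To make the new bookkeeping concrete, here is a pure-Python sketch (hypothetical weights, not the real classes) of the two moves above: `drop_paths` marks the lowest-weight surviving ops False for the next PDARTS epoch, and `sample_final` expands a selection over the survivors back to the full-length op list. The real code also keeps one extra weight for the zero operation, which the `[:-1]` slice removes before sorting.

```python
import numpy as np

switches = [True, True, True, True]       # one flag per original candidate op
weights = np.array([0.1, 0.4, 0.2, 0.3])  # weights of the surviving ops

# drop_paths: mark the k lowest-weight survivors False (k=2 here)
idxs = [i for i, s in enumerate(switches) if s]
for idx in np.argsort(weights)[:2]:
    switches[idxs[idx]] = False
print(switches)  # [False, True, False, True]

# sample_final: fill a choice over survivors back to full length,
# with False in the slots of physically dropped ops
trained_result = iter([True, False])      # selection among the 2 survivors
result = [next(trained_result) if s else False for s in switches]
print(result)  # [False, True, False, False]
```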
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py

```diff
@@ -14,14 +14,22 @@ logger = logging.getLogger(__name__)
 
 class PdartsTrainer(BaseTrainer):
-    def __init__(self, model_creator, layers, metrics,
+    """
+    This trainer implements the PDARTS algorithm.
+    PDARTS bases on DARTS algorithm, and provides a network growth approach to find deeper and better network.
+    This class relies on pdarts_num_layers and pdarts_num_to_drop parameters to control how network grows.
+    pdarts_num_layers means how many layers more than first epoch.
+    pdarts_num_to_drop means how many candidate operations should be dropped in each epoch.
+    So that the grew network can in similar size.
+    """
+    def __init__(self, model_creator, init_layers, metrics,
                  num_epochs, dataset_train, dataset_valid,
-                 pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2],
-                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None):
+                 pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 1],
+                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, unrolled=False):
         super(PdartsTrainer, self).__init__()
         self.model_creator = model_creator
-        self.layers = layers
+        self.init_layers = init_layers
         self.pdarts_num_layers = pdarts_num_layers
         self.pdarts_num_to_drop = pdarts_num_to_drop
         self.pdarts_epoch = len(pdarts_num_to_drop)
@@ -33,16 +41,17 @@ class PdartsTrainer(BaseTrainer):
             "batch_size": batch_size,
             "workers": workers,
             "device": device,
-            "log_frequency": log_frequency
+            "log_frequency": log_frequency,
+            "unrolled": unrolled
         }
         self.callbacks = callbacks if callbacks is not None else []
 
     def train(self):
-        layers = self.layers
         switches = None
         for epoch in range(self.pdarts_epoch):
 
-            layers = self.layers + self.pdarts_num_layers[epoch]
+            layers = self.init_layers + self.pdarts_num_layers[epoch]
             model, criterion, optim, lr_scheduler = self.model_creator(layers)
             self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)
@@ -66,7 +75,7 @@ class PdartsTrainer(BaseTrainer):
                 callback.on_epoch_end(epoch)
 
     def validate(self):
-        self.model.validate()
+        self.trainer.validate()
 
     def export(self, file):
         mutator_export = self.mutator.export()
```
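Putting the renamed `init_layers` parameter and the new defaults together, the growth schedule works out as below. This is a hypothetical walk-through; the 8 initial candidate ops per edge is an assumption for illustration, not stated in this diff.

```python
# Sketch of the PDARTS growth schedule with the new defaults:
# the network gets deeper each PDARTS epoch while the candidate op
# set shrinks, keeping the search roughly constant in size.
init_layers = 5
pdarts_num_layers = [0, 6, 12]
pdarts_num_to_drop = [3, 2, 1]
ops = 8  # assumed number of candidate ops per edge at the start

for epoch in range(len(pdarts_num_to_drop)):
    layers = init_layers + pdarts_num_layers[epoch]
    print(f"PDARTS epoch {epoch}: {layers} layers, {ops} candidate ops per edge")
    ops -= pdarts_num_to_drop[epoch]

# PDARTS epoch 0: 5 layers, 8 candidate ops per edge
# PDARTS epoch 1: 11 layers, 5 candidate ops per edge
# PDARTS epoch 2: 17 layers, 3 candidate ops per edge
```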