OpenDAS / nni · Commit 58873c46 (unverified)
Authored Oct 19, 2020 by Duong Nhu; committed by GitHub on Oct 19, 2020
Parameterized training options for EsTrainer implementation in tensorflow (#2953)
parent d5036857

Showing 1 changed file with 98 additions and 53 deletions.

src/sdk/pynni/nni/nas/tensorflow/enas/trainer.py  +98 -53  · View file @ 58873c46
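In brief: ten training options that were hard-coded as module-level constants (log_frequency, entropy_weight, skip_weight, baseline_decay, child_steps, mutator_lr, mutator_steps, mutator_steps_aggregate, aux_weight, test_arc_per_epoch) become keyword arguments of EnasTrainer.__init__ with the same defaults, so callers can tune them per instance. A minimal usage sketch follows; everything other than EnasTrainer and its parameters (the model, data arrays, metric and reward callables) is a hypothetical placeholder, not part of this commit. The dataset_train tuple is split internally into train/valid sets, as the constructor code in the diff shows.

    import tensorflow as tf
    from nni.nas.tensorflow.enas import EnasTrainer  # import path assumed from this repo layout

    def accuracy(y_true, logits):
        # Hypothetical metrics callable: returns a dict of scalars, as meters.update() expects.
        preds = tf.argmax(logits, axis=-1)
        correct = tf.cast(preds == tf.cast(y_true, preds.dtype), tf.float32)
        return {"acc": tf.reduce_mean(correct).numpy()}

    def reward(y_true, logits):
        # Hypothetical reward for the RL controller, e.g. batch accuracy.
        preds = tf.argmax(logits, axis=-1)
        return tf.reduce_mean(tf.cast(preds == tf.cast(y_true, preds.dtype), tf.float32))

    trainer = EnasTrainer(
        model=model,  # a search-space model built with NNI mutables (not shown here)
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=accuracy,
        reward_function=reward,
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.05),
        batch_size=64,
        num_epochs=10,
        dataset_train=(x_train, y_train),  # (x, y) arrays; the trainer splits off a valid set
        dataset_valid=(x_test, y_test),
        # Previously fixed module constants, now overridable per instance:
        child_steps=300,
        mutator_lr=0.001,
        test_arc_per_epoch=2,
    )
    trainer.train()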
...
@@ -13,21 +13,29 @@ from .mutator import EnasMutator
 logger = logging.getLogger(__name__)
 
-log_frequency = 100
-entropy_weight = 0.0001
-skip_weight = 0.8
-baseline_decay = 0.999
-child_steps = 500
-mutator_lr = 0.00035
-mutator_steps = 50
-mutator_steps_aggregate = 20
-aux_weight = 0.4
-test_arc_per_epoch = 1
-
 
 class EnasTrainer:
-    def __init__(self, model, loss, metrics, reward_function, optimizer, batch_size, num_epochs,
-                 dataset_train, dataset_valid):
+    def __init__(
+        self,
+        model,
+        loss,
+        metrics,
+        reward_function,
+        optimizer,
+        batch_size,
+        num_epochs,
+        dataset_train,
+        dataset_valid,
+        log_frequency=100,
+        entropy_weight=0.0001,
+        skip_weight=0.8,
+        baseline_decay=0.999,
+        child_steps=500,
+        mutator_lr=0.00035,
+        mutator_steps=50,
+        mutator_steps_aggregate=20,
+        aux_weight=0.4,
+        test_arc_per_epoch=1,
+    ):
         self.model = model
         self.loss = loss
         self.metrics = metrics
...
@@ -42,11 +50,21 @@ class EnasTrainer:
         self.valid_set = tf.data.Dataset.from_tensor_slices((x[split:], y[split:]))
         self.test_set = tf.data.Dataset.from_tensor_slices(dataset_valid)
 
-        self.mutator = EnasMutator(model)
-        self.mutator_optim = Adam(learning_rate=mutator_lr)
+        self.log_frequency = log_frequency
+        self.entropy_weight = entropy_weight
+        self.skip_weight = skip_weight
+        self.baseline_decay = baseline_decay
+        self.child_steps = child_steps
+        self.mutator_lr = mutator_lr
+        self.mutator_steps = mutator_steps
+        self.mutator_steps_aggregate = mutator_steps_aggregate
+        self.aux_weight = aux_weight
+        self.test_arc_per_epoch = test_arc_per_epoch
 
-        self.baseline = 0.
+        self.mutator = EnasMutator(model)
+        self.mutator_optim = Adam(learning_rate=self.mutator_lr)
+        self.baseline = 0.0
 
     def train(self, validate=True):
         for epoch in range(self.num_epochs):
...
@@ -58,14 +76,13 @@ class EnasTrainer:
     def validate(self):
         self.validate_one_epoch(-1)
 
     def train_one_epoch(self, epoch):
         train_loader, valid_loader = self._create_train_loader()
 
         # Sample model and train
         meters = AverageMeterGroup()
-        for step in range(1, child_steps + 1):
+        for step in range(1, self.child_steps + 1):
             x, y = next(train_loader)
             self.mutator.reset()
...
@@ -75,64 +92,88 @@ class EnasTrainer:
                     logits, aux_logits = logits
                     aux_loss = self.loss(aux_logits, y)
                 else:
-                    aux_loss = 0.
+                    aux_loss = 0.0
                 metrics = self.metrics(y, logits)
-                loss = self.loss(y, logits) + aux_weight * aux_loss
+                loss = self.loss(y, logits) + self.aux_weight * aux_loss
             grads = tape.gradient(loss, self.model.trainable_weights)
             grads = fill_zero_grads(grads, self.model.trainable_weights)
             grads, _ = tf.clip_by_global_norm(grads, 5.0)
             self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
 
-            metrics['loss'] = tf.reduce_mean(loss).numpy()
+            metrics["loss"] = tf.reduce_mean(loss).numpy()
             meters.update(metrics)
 
-            if log_frequency and step % log_frequency == 0:
-                logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
-                            self.num_epochs, step, child_steps, meters)
+            if self.log_frequency and step % self.log_frequency == 0:
+                logger.info(
+                    "Model Epoch [%d/%d] Step [%d/%d] %s",
+                    epoch + 1,
+                    self.num_epochs,
+                    step,
+                    self.child_steps,
+                    meters,
+                )
 
         # Train sampler (mutator)
         meters = AverageMeterGroup()
-        for mutator_step in range(1, mutator_steps + 1):
+        for mutator_step in range(1, self.mutator_steps + 1):
             grads_list = []
-            for step in range(1, mutator_steps_aggregate + 1):
+            for step in range(1, self.mutator_steps_aggregate + 1):
                 with tf.GradientTape() as tape:
                     x, y = next(valid_loader)
                     self.mutator.reset()
                     logits = self.model(x, training=False)
                     metrics = self.metrics(y, logits)
-                    reward = self.reward_function(y, logits) + entropy_weight * self.mutator.sample_entropy
-                    self.baseline = self.baseline * baseline_decay + reward * (1 - baseline_decay)
+                    reward = (
+                        self.reward_function(y, logits)
+                        + self.entropy_weight * self.mutator.sample_entropy
+                    )
+                    self.baseline = self.baseline * self.baseline_decay + reward * (
+                        1 - self.baseline_decay
+                    )
                     loss = self.mutator.sample_log_prob * (reward - self.baseline)
-                    loss += skip_weight * self.mutator.sample_skip_penalty
-                    meters.update({
-                        'reward': reward,
-                        'loss': tf.reduce_mean(loss).numpy(),
-                        'ent': self.mutator.sample_entropy.numpy(),
-                        'log_prob': self.mutator.sample_log_prob.numpy(),
-                        'baseline': self.baseline,
-                        'skip': self.mutator.sample_skip_penalty,
-                    })
-                    cur_step = step + (mutator_step - 1) * mutator_steps_aggregate
-                    if log_frequency and cur_step % log_frequency == 0:
-                        logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs,
-                                    mutator_step, mutator_steps, step, mutator_steps_aggregate,
-                                    meters)
+                    loss += self.skip_weight * self.mutator.sample_skip_penalty
+                    meters.update(
+                        {
+                            "reward": reward,
+                            "loss": tf.reduce_mean(loss).numpy(),
+                            "ent": self.mutator.sample_entropy.numpy(),
+                            "log_prob": self.mutator.sample_log_prob.numpy(),
+                            "baseline": self.baseline,
+                            "skip": self.mutator.sample_skip_penalty,
+                        }
+                    )
+                    cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
+                    if self.log_frequency and cur_step % self.log_frequency == 0:
+                        logger.info(
+                            "RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s",
+                            epoch + 1,
+                            self.num_epochs,
+                            mutator_step,
+                            self.mutator_steps,
+                            step,
+                            self.mutator_steps_aggregate,
+                            meters,
+                        )
 
                 grads = tape.gradient(loss, self.mutator.trainable_weights)
                 grads = fill_zero_grads(grads, self.mutator.trainable_weights)
                 grads_list.append(grads)
-            total_grads = [tf.math.add_n(weight_grads) for weight_grads in zip(*grads_list)]
+            total_grads = [
+                tf.math.add_n(weight_grads) for weight_grads in zip(*grads_list)
+            ]
             total_grads, _ = tf.clip_by_global_norm(total_grads, 5.0)
-            self.mutator_optim.apply_gradients(zip(total_grads, self.mutator.trainable_weights))
+            self.mutator_optim.apply_gradients(
+                zip(total_grads, self.mutator.trainable_weights)
+            )
 
     def validate_one_epoch(self, epoch):
         test_loader = self._create_validate_loader()
-        for arc_id in range(test_arc_per_epoch):
+        for arc_id in range(self.test_arc_per_epoch):
             meters = AverageMeterGroup()
             for x, y in test_loader:
                 self.mutator.reset()
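One behavioural note on the hunk above: the REINFORCE baseline update only changes scope (bare names to self.*), not semantics; it remains an exponential moving average of rewards. A toy sketch of that update, with made-up reward values:

    # EMA baseline exactly as in the trainer: decayed history plus a sliver of the new reward.
    decay = 0.999      # default baseline_decay
    baseline = 0.0
    for reward in [0.50, 0.52, 0.55]:  # hypothetical rewards
        baseline = baseline * decay + reward * (1 - decay)
    print(round(baseline, 5))  # 0.00157; effective horizon is about 1 / (1 - decay) = 1000 steps

The remaining hunk continues below.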
...
@@ -141,13 +182,17 @@ class EnasTrainer:
                     logits, _ = logits
                 metrics = self.metrics(y, logits)
                 loss = self.loss(y, logits)
-                metrics['loss'] = tf.reduce_mean(loss).numpy()
+                metrics["loss"] = tf.reduce_mean(loss).numpy()
                 meters.update(metrics)
 
-            logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
-                        epoch + 1, self.num_epochs, arc_id + 1, test_arc_per_epoch,
-                        meters.summary())
+            logger.info(
+                "Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
+                epoch + 1,
+                self.num_epochs,
+                arc_id + 1,
+                self.test_arc_per_epoch,
+                meters.summary(),
+            )
 
     def _create_train_loader(self):
         train_set = self.train_set.shuffle(1000000).repeat().batch(self.batch_size)
...
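Also preserved by this diff (only reformatted) is the controller's gradient aggregation: gradients from mutator_steps_aggregate sampled architectures are summed per variable via tf.math.add_n over zip(*grads_list), clipped once by global norm, and applied in a single Adam step. A self-contained sketch of that pattern with toy variables, not the trainer's real ones:

    import tensorflow as tf

    # Toy stand-ins for the mutator's trainable weights.
    w1 = tf.Variable([1.0, 2.0])
    w2 = tf.Variable([[3.0], [4.0]])
    variables = [w1, w2]

    grads_list = []
    for _ in range(4):  # plays the role of mutator_steps_aggregate
        with tf.GradientTape() as tape:
            loss = tf.reduce_sum(w1) + tf.reduce_sum(w2)
        grads_list.append(tape.gradient(loss, variables))

    # zip(*grads_list) groups gradients per variable; add_n sums each group element-wise.
    total_grads = [tf.math.add_n(g) for g in zip(*grads_list)]
    total_grads, _ = tf.clip_by_global_norm(total_grads, 5.0)
    tf.keras.optimizers.Adam().apply_gradients(zip(total_grads, variables))

Note that the gradients are summed, not averaged, across the aggregation window.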