Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
51fc02ae
Commit
51fc02ae
authored
Dec 20, 2018
by
Shining Sun
Browse files
Add skip_eval flag and change to new optimizer
parent
6f881f77
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
7 deletions
+11
-7
official/resnet/keras/keras_cifar_main.py
official/resnet/keras/keras_cifar_main.py
+4
-3
official/resnet/keras/keras_common.py
official/resnet/keras/keras_common.py
+3
-1
official/resnet/keras/keras_imagenet_main.py
official/resnet/keras/keras_imagenet_main.py
+4
-3
No files found.
official/resnet/keras/keras_cifar_main.py
View file @
51fc02ae
...
@@ -179,9 +179,10 @@ def run(flags_obj):
...
@@ -179,9 +179,10 @@ def run(flags_obj):
validation_data
=
eval_input_dataset
,
validation_data
=
eval_input_dataset
,
verbose
=
1
)
verbose
=
1
)
eval_output
=
model
.
evaluate
(
eval_input_dataset
,
if
not
flags_obj
.
skip_eval
:
steps
=
num_eval_steps
,
eval_output
=
model
.
evaluate
(
eval_input_dataset
,
verbose
=
1
)
steps
=
num_eval_steps
,
verbose
=
1
)
stats
=
keras_common
.
analyze_fit_and_eval_result
(
history
,
eval_output
)
stats
=
keras_common
.
analyze_fit_and_eval_result
(
history
,
eval_output
)
...
...
official/resnet/keras/keras_common.py
View file @
51fc02ae
...
@@ -105,7 +105,7 @@ def get_optimizer():
...
@@ -105,7 +105,7 @@ def get_optimizer():
learning_rate
=
BASE_LEARNING_RATE
*
FLAGS
.
batch_size
/
256
learning_rate
=
BASE_LEARNING_RATE
*
FLAGS
.
batch_size
/
256
optimizer
=
tf
.
train
.
MomentumOptimizer
(
learning_rate
=
learning_rate
,
momentum
=
0.9
)
optimizer
=
tf
.
train
.
MomentumOptimizer
(
learning_rate
=
learning_rate
,
momentum
=
0.9
)
else
:
else
:
optimizer
=
gradient_descent_v2
.
SGD
(
learning_rate
=
0.1
,
momentum
=
0.9
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
learning_rate
=
0.1
,
momentum
=
0.9
)
return
optimizer
return
optimizer
...
@@ -138,6 +138,8 @@ def analyze_fit_and_eval_result(history, eval_output):
...
@@ -138,6 +138,8 @@ def analyze_fit_and_eval_result(history, eval_output):
def
define_keras_flags
():
def
define_keras_flags
():
flags
.
DEFINE_boolean
(
name
=
'enable_eager'
,
default
=
False
,
help
=
'Enable eager?'
)
flags
.
DEFINE_boolean
(
name
=
'enable_eager'
,
default
=
False
,
help
=
'Enable eager?'
)
flags
.
DEFINE_boolean
(
name
=
'skip_eval'
,
default
=
False
,
help
=
'Skip evaluation?'
)
flags
.
DEFINE_integer
(
flags
.
DEFINE_integer
(
name
=
"train_steps"
,
default
=
None
,
name
=
"train_steps"
,
default
=
None
,
help
=
"The number of steps to run for training"
)
help
=
"The number of steps to run for training"
)
official/resnet/keras/keras_imagenet_main.py
View file @
51fc02ae
...
@@ -172,9 +172,10 @@ def run_imagenet_with_keras(flags_obj):
...
@@ -172,9 +172,10 @@ def run_imagenet_with_keras(flags_obj):
validation_data
=
eval_input_dataset
,
validation_data
=
eval_input_dataset
,
verbose
=
1
)
verbose
=
1
)
eval_output
=
model
.
evaluate
(
eval_input_dataset
,
if
not
flags_obj
.
skip_eval
:
steps
=
num_eval_steps
,
eval_output
=
model
.
evaluate
(
eval_input_dataset
,
verbose
=
1
)
steps
=
num_eval_steps
,
verbose
=
1
)
stats
=
keras_common
.
analyze_fit_and_eval_result
(
history
,
eval_output
)
stats
=
keras_common
.
analyze_fit_and_eval_result
(
history
,
eval_output
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment