Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
52ee9636
Commit
52ee9636
authored
Dec 07, 2018
by
Toby Boyd
Browse files
Merge branch 'cifar_keras' of github.com:tensorflow/models into cifar_keras
parents
1b3c9ba6
87c0e09d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
14 deletions
+9
-14
official/resnet/keras/keras_cifar_main.py
official/resnet/keras/keras_cifar_main.py
+4
-2
official/resnet/keras/keras_imagenet_main.py
official/resnet/keras/keras_imagenet_main.py
+5
-12
No files found.
official/resnet/keras/keras_cifar_main.py
View file @
52ee9636
...
...
@@ -83,7 +83,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
# (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
# ]
LR_SCHEDULE
=
[
# (multiplier, epoch to start) tuples
(
0.1
,
91
),
(
0.01
,
136
),
(
0.001
,
182
)
(
0.1
,
91
),
(
0.01
,
136
),
(
0.001
,
182
)
]
BASE_LEARNING_RATE
=
0.1
...
...
@@ -302,6 +302,8 @@ def run_cifar_with_keras(flags_obj):
lr_callback
,
tesorboard_callback
],
validation_steps
=
num_eval_steps
,
validation_data
=
eval_input_dataset
,
verbose
=
1
)
eval_output
=
model
.
evaluate
(
eval_input_dataset
,
...
...
official/resnet/keras/keras_imagenet_main.py
View file @
52ee9636
...
...
@@ -189,15 +189,6 @@ def run_imagenet_with_keras(flags_obj):
Raises:
ValueError: If fp16 is passed as it is not currently supported.
"""
# Set all random seeds to fixed values.
import
random
import
numpy
as
np
seed
=
87654321
random
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
tf
.
random
.
set_random_seed
(
seed
)
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
)
if
dtype
==
'fp16'
:
raise
ValueError
(
'dtype fp16 is not supported in Keras. Use the default '
...
...
@@ -276,8 +267,8 @@ def run_imagenet_with_keras(flags_obj):
time_callback
=
TimeHistory
(
flags_obj
.
batch_size
)
tesorboard_callback
=
tf
.
keras
.
callbacks
.
TensorBoard
(
log_dir
=
flags_obj
.
model_dir
,
update_freq
=
"batch"
)
# Add this if want per batch logging.
log_dir
=
flags_obj
.
model_dir
)
#
update_freq="batch") # Add this if want per batch logging.
lr_callback
=
LearningRateBatchScheduler
(
learning_rate_schedule
,
...
...
@@ -295,6 +286,8 @@ def run_imagenet_with_keras(flags_obj):
lr_callback
,
tesorboard_callback
],
validation_steps
=
num_eval_steps
,
validation_data
=
eval_input_dataset
,
verbose
=
1
)
eval_output
=
model
.
evaluate
(
eval_input_dataset
,
...
...
@@ -308,6 +301,6 @@ def main(_):
if
__name__
==
'__main__'
:
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
DEBUG
)
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
INFO
)
imagenet_main
.
define_imagenet_flags
()
absl_app
.
run
(
main
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment