Commit 52ee9636 authored by Toby Boyd's avatar Toby Boyd
Browse files

Merge branch 'cifar_keras' of github.com:tensorflow/models into cifar_keras

parents 1b3c9ba6 87c0e09d
...@@ -83,7 +83,7 @@ class TimeHistory(tf.keras.callbacks.Callback): ...@@ -83,7 +83,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
# (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80) # (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
# ] # ]
LR_SCHEDULE = [ # (multiplier, epoch to start) tuples LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
(0.1, 91), (0.01, 136), (0.001, 182) (0.1, 91), (0.01, 136), (0.001, 182)
] ]
BASE_LEARNING_RATE = 0.1 BASE_LEARNING_RATE = 0.1
...@@ -302,6 +302,8 @@ def run_cifar_with_keras(flags_obj): ...@@ -302,6 +302,8 @@ def run_cifar_with_keras(flags_obj):
lr_callback, lr_callback,
tesorboard_callback tesorboard_callback
], ],
validation_steps=num_eval_steps,
validation_data=eval_input_dataset,
verbose=1) verbose=1)
eval_output = model.evaluate(eval_input_dataset, eval_output = model.evaluate(eval_input_dataset,
......
...@@ -189,15 +189,6 @@ def run_imagenet_with_keras(flags_obj): ...@@ -189,15 +189,6 @@ def run_imagenet_with_keras(flags_obj):
Raises: Raises:
ValueError: If fp16 is passed as it is not currently supported. ValueError: If fp16 is passed as it is not currently supported.
""" """
# Set all random seeds to fixed values.
import random
import numpy as np
seed = 87654321
random.seed(seed)
np.random.seed(seed)
tf.random.set_random_seed(seed)
dtype = flags_core.get_tf_dtype(flags_obj) dtype = flags_core.get_tf_dtype(flags_obj)
if dtype == 'fp16': if dtype == 'fp16':
raise ValueError('dtype fp16 is not supported in Keras. Use the default ' raise ValueError('dtype fp16 is not supported in Keras. Use the default '
...@@ -276,8 +267,8 @@ def run_imagenet_with_keras(flags_obj): ...@@ -276,8 +267,8 @@ def run_imagenet_with_keras(flags_obj):
time_callback = TimeHistory(flags_obj.batch_size) time_callback = TimeHistory(flags_obj.batch_size)
tesorboard_callback = tf.keras.callbacks.TensorBoard( tesorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=flags_obj.model_dir, log_dir=flags_obj.model_dir)
update_freq="batch") # Add this if want per batch logging. # update_freq="batch") # Add this if want per batch logging.
lr_callback = LearningRateBatchScheduler( lr_callback = LearningRateBatchScheduler(
learning_rate_schedule, learning_rate_schedule,
...@@ -295,6 +286,8 @@ def run_imagenet_with_keras(flags_obj): ...@@ -295,6 +286,8 @@ def run_imagenet_with_keras(flags_obj):
lr_callback, lr_callback,
tesorboard_callback tesorboard_callback
], ],
validation_steps=num_eval_steps,
validation_data=eval_input_dataset,
verbose=1) verbose=1)
eval_output = model.evaluate(eval_input_dataset, eval_output = model.evaluate(eval_input_dataset,
...@@ -308,6 +301,6 @@ def main(_): ...@@ -308,6 +301,6 @@ def main(_):
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.DEBUG) tf.logging.set_verbosity(tf.logging.INFO)
imagenet_main.define_imagenet_flags() imagenet_main.define_imagenet_flags()
absl_app.run(main) absl_app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment