Commit cf723306 authored by Shining Sun

Bug fixes: pass the learning-rate schedule into get_fit_callbacks, use the FLAGS global instead of the undefined flags_obj inside keras_common, drop stale keras_common. prefixes within the module itself, and fix the misspelled tesorboard_callback reference.

parent 3fd9c7fe
@@ -175,7 +175,8 @@ def run_cifar_with_keras(flags_obj):
       optimizer=opt,
       metrics=[accuracy],
       distribute=strategy)
 
-  time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks()
+  time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks(
+      learning_rate_schedule)
 
   steps_per_epoch = cifar_main._NUM_IMAGES['train'] // flags_obj.batch_size
   num_eval_steps = (cifar_main._NUM_IMAGES['validation'] //
@@ -187,7 +188,7 @@ def run_cifar_with_keras(flags_obj):
       callbacks=[
           time_callback,
           lr_callback,
-          tesorboard_callback
+          tensorboard_callback
       ],
       validation_steps=num_eval_steps,
       validation_data=eval_input_dataset,
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Runs a ResNet model on the ImageNet dataset."""
+"""Common util functions and classes used by both Keras CIFAR and ImageNet."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -36,6 +36,9 @@ from official.utils.misc import distribution_utils
 
+from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
+
+FLAGS = flags.FLAGS
 
 
 class TimeHistory(tf.keras.callbacks.Callback):
   """Callback for Keras models."""
@@ -122,7 +125,7 @@ def get_optimizer_loss_and_metrics():
   opt = gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
   # TF Optimizer:
-  # learning_rate = BASE_LEARNING_RATE * flags_obj.batch_size / 256
+  # learning_rate = BASE_LEARNING_RATE * FLAGS.batch_size / 256
   # opt = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
 
   loss = 'categorical_crossentropy'
   accuracy = 'categorical_accuracy'
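The commented-out TF-optimizer path encodes the usual linear scaling rule: the learning rate grows proportionally with the global batch size, normalized to a reference batch of 256. A quick sketch of the arithmetic, assuming BASE_LEARNING_RATE = 0.1 (borrowed from the SGD call above; the constant's real value is not shown in this diff):

# Linear scaling rule: scale the base rate by (batch_size / 256).
# BASE_LEARNING_RATE = 0.1 is an assumption taken from the SGD call above.
BASE_LEARNING_RATE = 0.1

def scaled_learning_rate(batch_size):
  """Base learning rate scaled relative to a reference batch size of 256."""
  return BASE_LEARNING_RATE * batch_size / 256

print(scaled_learning_rate(256))   # 0.1 -- reference batch size, no scaling
print(scaled_learning_rate(1024))  # 0.4 -- 4x the batch, 4x the rate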
@@ -131,26 +134,26 @@
 
 def get_dist_strategy():
-  if flags_obj.num_gpus == 1 and flags_obj.dist_strat_off:
+  if FLAGS.num_gpus == 1 and FLAGS.dist_strat_off:
     print('Not using distribution strategies.')
     strategy = None
   else:
     strategy = distribution_utils.get_distribution_strategy(
-        num_gpus=flags_obj.num_gpus)
+        num_gpus=FLAGS.num_gpus)
   return strategy
 
 
-def get_fit_callbacks():
-  time_callback = keras_common.TimeHistory(flags_obj.batch_size)
+def get_fit_callbacks(learning_rate_schedule_fn):
+  time_callback = TimeHistory(FLAGS.batch_size)
 
   tensorboard_callback = tf.keras.callbacks.TensorBoard(
-      log_dir=flags_obj.model_dir)
+      log_dir=FLAGS.model_dir)
       # update_freq="batch")  # Add this if you want per-batch logging.
 
-  lr_callback = keras_common.LearningRateBatchScheduler(
-      learning_rate_schedule,
-      batch_size=flags_obj.batch_size,
+  lr_callback = LearningRateBatchScheduler(
+      learning_rate_schedule_fn,
+      batch_size=FLAGS.batch_size,
       num_images=imagenet_main._NUM_IMAGES['train'])
 
   return time_callback, tensorboard_callback, lr_callback
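For reference, a sketch of the caller side after this change. The schedule signature and the import path below are assumptions inferred from the arguments LearningRateBatchScheduler is constructed with; the diff itself only shows that a schedule function is now passed through:

# Hypothetical caller-side usage; import path and signature are assumed.
from official.resnet.keras import keras_common

def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch,
                           batch_size):
  # Toy schedule for illustration: cut the rate 10x every 30 epochs.
  return 0.1 * (0.1 ** (current_epoch // 30))

time_callback, tensorboard_callback, lr_callback = (
    keras_common.get_fit_callbacks(learning_rate_schedule))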
@@ -172,7 +172,8 @@ def run_imagenet_with_keras(flags_obj):
       metrics=[accuracy],
       distribute=strategy)
 
-  time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks()
+  time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks(
+      learning_rate_schedule)
 
   steps_per_epoch = imagenet_main._NUM_IMAGES['train'] // flags_obj.batch_size
   num_eval_steps = (imagenet_main._NUM_IMAGES['validation'] //