Commit cf723306 authored by Shining Sun

bug fixes

parent 3fd9c7fe
@@ -175,7 +175,8 @@ def run_cifar_with_keras(flags_obj):
       optimizer=opt,
       metrics=[accuracy],
       distribute=strategy)
 
-  time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks()
+  time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks(
+      learning_rate_schedule)
   steps_per_epoch = cifar_main._NUM_IMAGES['train'] // flags_obj.batch_size
   num_eval_steps = (cifar_main._NUM_IMAGES['validation'] //
@@ -187,7 +188,7 @@ def run_cifar_with_keras(flags_obj):
       callbacks=[
           time_callback,
           lr_callback,
-          tesorboard_callback
+          tensorboard_callback
       ],
       validation_steps=num_eval_steps,
       validation_data=eval_input_dataset,
...
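For context, the fit call in the cifar runner reads roughly as follows once both hunks are applied. This is a sketch, not the file's verbatim contents: the dataset variable, the epoch count, and anything else outside the hunks are assumptions.

history = model.fit(train_input_dataset,   # assumed name, not shown in the hunks
                    epochs=train_epochs,   # assumed name, not shown in the hunks
                    steps_per_epoch=steps_per_epoch,
                    callbacks=[
                        time_callback,
                        lr_callback,
                        tensorboard_callback  # the typo fix: was tesorboard_callback
                    ],
                    validation_steps=num_eval_steps,
                    validation_data=eval_input_dataset)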
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Runs a ResNet model on the ImageNet dataset."""
+"""Common util functions and classes used by both keras cifar and imagenet."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -36,6 +36,9 @@ from official.utils.misc import distribution_utils
 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
 
+FLAGS = flags.FLAGS
+
+
 class TimeHistory(tf.keras.callbacks.Callback):
   """Callback for Keras models."""
@@ -122,7 +125,7 @@ def get_optimizer_loss_and_metrics():
   opt = gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
 
   # TF Optimizer:
-  # learning_rate = BASE_LEARNING_RATE * flags_obj.batch_size / 256
+  # learning_rate = BASE_LEARNING_RATE * FLAGS.batch_size / 256
   # opt = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
 
   loss = 'categorical_crossentropy'
   accuracy = 'categorical_accuracy'
@@ -131,26 +134,26 @@ def get_optimizer_loss_and_metrics():
 
 def get_dist_strategy():
-  if flags_obj.num_gpus == 1 and flags_obj.dist_strat_off:
+  if FLAGS.num_gpus == 1 and FLAGS.dist_strat_off:
     print('Not using distribution strategies.')
     strategy = None
   else:
     strategy = distribution_utils.get_distribution_strategy(
-        num_gpus=flags_obj.num_gpus)
+        num_gpus=FLAGS.num_gpus)
   return strategy
 
 
-def get_fit_callbacks():
-  time_callback = keras_common.TimeHistory(flags_obj.batch_size)
+def get_fit_callbacks(learning_rate_schedule_fn):
+  time_callback = TimeHistory(FLAGS.batch_size)
 
   tensorboard_callback = tf.keras.callbacks.TensorBoard(
-      log_dir=flags_obj.model_dir)
+      log_dir=FLAGS.model_dir)
       # update_freq="batch")  # Add this if you want per-batch logging.
 
-  lr_callback = keras_common.LearningRateBatchScheduler(
-      learning_rate_schedule,
-      batch_size=flags_obj.batch_size,
+  lr_callback = LearningRateBatchScheduler(
+      learning_rate_schedule_fn,
+      batch_size=FLAGS.batch_size,
       num_images=imagenet_main._NUM_IMAGES['train'])
 
   return time_callback, tensorboard_callback, lr_callback
...
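Pieced together from the interleaved old/new columns above, get_fit_callbacks after this commit reads approximately as follows. This is a reconstruction from the hunk alone; blank-line placement and the added explanatory comment are not part of the commit.

def get_fit_callbacks(learning_rate_schedule_fn):
  # Callbacks shared by the cifar and imagenet runners: step timing,
  # TensorBoard logging, and a per-batch learning-rate scheduler.
  time_callback = TimeHistory(FLAGS.batch_size)

  tensorboard_callback = tf.keras.callbacks.TensorBoard(
      log_dir=FLAGS.model_dir)
      # update_freq="batch")  # Add this if you want per-batch logging.

  lr_callback = LearningRateBatchScheduler(
      learning_rate_schedule_fn,
      batch_size=FLAGS.batch_size,
      num_images=imagenet_main._NUM_IMAGES['train'])

  return time_callback, tensorboard_callback, lr_callback

Note that num_images is still hard-wired to imagenet_main._NUM_IMAGES['train'] even though the cifar runner now calls this helper too, so the scheduler's notion of batches per epoch would be wrong for cifar.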
@@ -172,7 +172,8 @@ def run_imagenet_with_keras(flags_obj):
       metrics=[accuracy],
       distribute=strategy)
 
-  time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks()
+  time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks(
+      learning_rate_schedule)
 
   steps_per_epoch = imagenet_main._NUM_IMAGES['train'] // flags_obj.batch_size
   num_eval_steps = (imagenet_main._NUM_IMAGES['validation'] //
...
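Both runners now pass their file-local learning_rate_schedule into the shared helper instead of relying on a name baked into keras_common. The actual schedule is not part of this diff; a minimal sketch of a function with a plausibly compatible shape, assuming LearningRateBatchScheduler invokes it with the current epoch and batch, could look like the following. The signature, the epoch boundaries, and the BASE_LEARNING_RATE value are all assumptions.

BASE_LEARNING_RATE = 0.1  # assumed; mirrors the commented-out TF optimizer above

def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch, batch_size):
  # Hypothetical stepwise decay: scale the base rate linearly with batch size,
  # then cut it by 10x at assumed epoch boundaries.
  lr = BASE_LEARNING_RATE * batch_size / 256
  for boundary_epoch, multiplier in [(30, 0.1), (60, 0.01), (80, 0.001)]:
    if current_epoch >= boundary_epoch:
      lr = BASE_LEARNING_RATE * batch_size / 256 * multiplier
  return lr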