"docs/vscode:/vscode.git/clone" did not exist on "daf3559a792a2f342cc6794a76d8093a322c71e0"
Commit cf723306 authored by Shining Sun

bug fixes: pass the learning rate schedule into keras_common.get_fit_callbacks(), fix the misspelled tensorboard_callback reference, and read flags through a module-level FLAGS in keras_common

parent 3fd9c7fe
@@ -175,7 +175,8 @@ def run_cifar_with_keras(flags_obj):
                 optimizer=opt,
                 metrics=[accuracy],
                 distribute=strategy)
-  time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks()
+  time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks(
+      learning_rate_schedule)
   steps_per_epoch = cifar_main._NUM_IMAGES['train'] // flags_obj.batch_size
   num_eval_steps = (cifar_main._NUM_IMAGES['validation'] //
@@ -187,7 +188,7 @@ def run_cifar_with_keras(flags_obj):
             callbacks=[
                 time_callback,
                 lr_callback,
-                tesorboard_callback
+                tensorboard_callback
             ],
             validation_steps=num_eval_steps,
             validation_data=eval_input_dataset,
...
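The CIFAR hunks above thread `learning_rate_schedule` into `keras_common.get_fit_callbacks()`, which forwards it to the batch-level LR callback (see the keras_common diff below). For context, here is a minimal sketch of the kind of schedule function that call expects; the signature mirrors what the callback invokes, but the decay boundaries and batch-size scaling are illustrative assumptions, not values from this commit.

```python
# Illustrative sketch only: the epoch boundaries and base rate below are
# assumptions, not values taken from this commit.
def learning_rate_schedule(current_epoch, current_batch,
                           batches_per_epoch, batch_size):
  """Returns the learning rate for a given point in training.

  Called once per batch by the LR callback; this sketch only varies the
  rate with the epoch, using a stepwise decay.
  """
  del current_batch, batches_per_epoch  # Unused: epoch granularity only.
  initial_lr = 0.1 * batch_size / 128  # Assumed linear batch-size scaling.
  if current_epoch < 91:
    return initial_lr
  if current_epoch < 136:
    return initial_lr * 0.1
  return initial_lr * 0.01
```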
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Runs a ResNet model on the ImageNet dataset."""
+"""Common util functions and classes used by both Keras CIFAR and ImageNet."""
 from __future__ import absolute_import
 from __future__ import division
@@ -36,6 +36,9 @@ from official.utils.misc import distribution_utils
 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
+
+FLAGS = flags.FLAGS
+
 class TimeHistory(tf.keras.callbacks.Callback):
   """Callback for Keras models."""
@@ -122,7 +125,7 @@ def get_optimizer_loss_and_metrics():
   opt = gradient_descent_v2.SGD(learning_rate=0.1, momentum=0.9)
   # TF Optimizer:
-  # learning_rate = BASE_LEARNING_RATE * flags_obj.batch_size / 256
+  # learning_rate = BASE_LEARNING_RATE * FLAGS.batch_size / 256
   # opt = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
   loss = 'categorical_crossentropy'
   accuracy = 'categorical_accuracy'
@@ -131,26 +134,26 @@ def get_optimizer_loss_and_metrics():
 def get_dist_strategy():
-  if flags_obj.num_gpus == 1 and flags_obj.dist_strat_off:
+  if FLAGS.num_gpus == 1 and FLAGS.dist_strat_off:
     print('Not using distribution strategies.')
     strategy = None
   else:
     strategy = distribution_utils.get_distribution_strategy(
-        num_gpus=flags_obj.num_gpus)
+        num_gpus=FLAGS.num_gpus)
   return strategy

-def get_fit_callbacks():
-  time_callback = keras_common.TimeHistory(flags_obj.batch_size)
+def get_fit_callbacks(learning_rate_schedule_fn):
+  time_callback = TimeHistory(FLAGS.batch_size)
   tensorboard_callback = tf.keras.callbacks.TensorBoard(
-      log_dir=flags_obj.model_dir)
+      log_dir=FLAGS.model_dir)
       # update_freq="batch")  # Add this if you want per-batch logging.
-  lr_callback = keras_common.LearningRateBatchScheduler(
-      learning_rate_schedule,
-      batch_size=flags_obj.batch_size,
+  lr_callback = LearningRateBatchScheduler(
+      learning_rate_schedule_fn,
+      batch_size=FLAGS.batch_size,
       num_images=imagenet_main._NUM_IMAGES['train'])
   return time_callback, tensorboard_callback, lr_callback
@@ -165,4 +168,4 @@ def analyze_eval_result(eval_output):
   print('top_1 accuracy:{}'.format(stats['accuracy_top_1']))
   print('top_1_training_accuracy:{}'.format(stats['training_accuracy_top_1']))
   return stats
\ No newline at end of file
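The recurring `flags_obj.*` to `FLAGS.*` edits in keras_common work because absl flags live in a single process-wide registry: once `FLAGS = flags.FLAGS` is bound at module level, this module can read flags defined anywhere else in the program without a `flags_obj` being threaded through every call. A minimal, self-contained sketch of that pattern (the file name and flag are illustrative, not from this commit):

```python
# flags_demo.py -- illustrative stand-alone example of the FLAGS pattern.
from absl import app
from absl import flags

flags.DEFINE_integer(name='batch_size', default=128, help='Batch size.')

# flags.FLAGS is a shared registry: every module that binds it sees the
# same object, so flags defined in one module are readable in another.
FLAGS = flags.FLAGS


def main(_):
  # Flag values are only safe to read after app.run() has parsed argv.
  print('batch_size =', FLAGS.batch_size)


if __name__ == '__main__':
  app.run(main)
```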
@@ -114,7 +114,7 @@ def run_imagenet_with_keras(flags_obj):
   """
   if flags_obj.enable_eager:
     tf.enable_eager_execution()
   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'fp16':
     raise ValueError('dtype fp16 is not supported in Keras. Use the default '
@@ -166,18 +166,19 @@ def run_imagenet_with_keras(flags_obj):
   model = resnet_model_tpu.ResNet50(num_classes=imagenet_main._NUM_CLASSES)
                                     weights=None)
   model.compile(loss=loss,
                 optimizer=opt,
                 metrics=[accuracy],
                 distribute=strategy)
-  time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks()
+  time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks(
+      learning_rate_schedule)
   steps_per_epoch = imagenet_main._NUM_IMAGES['train'] // flags_obj.batch_size
   num_eval_steps = (imagenet_main._NUM_IMAGES['validation'] //
                     flags_obj.batch_size)
   model.fit(train_input_dataset,
             epochs=flags_obj.train_epochs,
             steps_per_epoch=steps_per_epoch,
@@ -189,7 +190,7 @@ def run_imagenet_with_keras(flags_obj):
             validation_steps=num_eval_steps,
             validation_data=eval_input_dataset,
             verbose=1)
   eval_output = model.evaluate(eval_input_dataset,
                                steps=num_eval_steps,
                                verbose=1)
@@ -200,8 +201,8 @@ def run_imagenet_with_keras(flags_obj):
 def define_keras_imagenet_flags():
   flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')

 def main(_):
   with logger.benchmark_context(flags.FLAGS):
     run_imagenet_with_keras(flags.FLAGS)
...
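Both main scripts now depend on `get_fit_callbacks()` returning a batch-level scheduler that re-evaluates the supplied schedule function as training runs. The repo's `LearningRateBatchScheduler` class is not shown in this diff; the sketch below is an assumed approximation of its shape, included only to show how the `learning_rate_schedule_fn` argument gets used:

```python
import tensorflow as tf


class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
  """Assumed approximation of keras_common.LearningRateBatchScheduler."""

  def __init__(self, schedule_fn, batch_size, num_images):
    super(LearningRateBatchScheduler, self).__init__()
    self.schedule_fn = schedule_fn
    self.batch_size = batch_size
    self.batches_per_epoch = num_images // batch_size
    self.epoch = 0

  def on_epoch_begin(self, epoch, logs=None):
    self.epoch = epoch

  def on_batch_begin(self, batch, logs=None):
    # Re-evaluate the schedule at every batch and push the result into
    # the optimizer before the training step runs.
    lr = self.schedule_fn(self.epoch, batch,
                          self.batches_per_epoch, self.batch_size)
    tf.keras.backend.set_value(self.model.optimizer.lr, lr)
```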