Commit 1b3c9ba6 authored by Toby Boyd
Browse files

Add reporting of stats and turn off dist_strat.

parent 46cc8af1
...@@ -45,6 +45,7 @@ class TimeHistory(tf.keras.callbacks.Callback): ...@@ -45,6 +45,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
""" """
self._batch_size = batch_size self._batch_size = batch_size
self.last_exp_per_sec = 0
super(TimeHistory, self).__init__() super(TimeHistory, self).__init__()
def on_train_begin(self, logs=None): def on_train_begin(self, logs=None):
...@@ -69,6 +70,7 @@ class TimeHistory(tf.keras.callbacks.Callback): ...@@ -69,6 +70,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
last_n_batches = time.time() - self.batch_time_start last_n_batches = time.time() - self.batch_time_start
examples_per_second = (self._batch_size * n) / last_n_batches examples_per_second = (self._batch_size * n) / last_n_batches
self.batch_times_secs.append(last_n_batches) self.batch_times_secs.append(last_n_batches)
self.last_exp_per_sec = examples_per_second
self.record_batch = True self.record_batch = True
# TODO(anjalisridhar): add timestamp as well. # TODO(anjalisridhar): add timestamp as well.
if batch != 0: if batch != 0:
...@@ -242,7 +244,6 @@ def run_cifar_with_keras(flags_obj): ...@@ -242,7 +244,6 @@ def run_cifar_with_keras(flags_obj):
num_epochs=flags_obj.train_epochs, num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras) parse_record_fn=parse_record_keras)
# Use Keras ResNet50 applications model and native keras APIs # Use Keras ResNet50 applications model and native keras APIs
# initialize RMSprop optimizer # initialize RMSprop optimizer
# TODO(anjalisridhar): Move to using MomentumOptimizer. # TODO(anjalisridhar): Move to using MomentumOptimizer.
...@@ -262,11 +263,11 @@ def run_cifar_with_keras(flags_obj): ...@@ -262,11 +263,11 @@ def run_cifar_with_keras(flags_obj):
classes=cifar_main._NUM_CLASSES, classes=cifar_main._NUM_CLASSES,
weights=None) weights=None)
loss = 'categorical_crossentropy' loss = 'categorical_crossentropy'
accuracy = 'categorical_accuracy' accuracy = 'categorical_accuracy'
if flags_obj.num_gpus == 1: if flags_obj.num_gpus == 1 and flags_obj.dist_strat_off:
print('Not using distribution strategies.')
model.compile(loss=loss, model.compile(loss=loss,
optimizer=opt, optimizer=opt,
metrics=[accuracy]) metrics=[accuracy])
...@@ -282,7 +283,7 @@ def run_cifar_with_keras(flags_obj): ...@@ -282,7 +283,7 @@ def run_cifar_with_keras(flags_obj):
tesorboard_callback = tf.keras.callbacks.TensorBoard( tesorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=flags_obj.model_dir) log_dir=flags_obj.model_dir)
#update_freq="batch") # Add this if want per batch logging. # update_freq="batch") # Add this if want per batch logging.
lr_callback = LearningRateBatchScheduler( lr_callback = LearningRateBatchScheduler(
learning_rate_schedule, learning_rate_schedule,
...@@ -292,8 +293,8 @@ def run_cifar_with_keras(flags_obj): ...@@ -292,8 +293,8 @@ def run_cifar_with_keras(flags_obj):
num_eval_steps = (cifar_main._NUM_IMAGES['validation'] // num_eval_steps = (cifar_main._NUM_IMAGES['validation'] //
flags_obj.batch_size) flags_obj.batch_size)
print("Executing eagerly?:", tf.executing_eagerly()) print('Executing eagerly?:', tf.executing_eagerly())
model.fit(train_input_dataset, history = model.fit(train_input_dataset,
epochs=flags_obj.train_epochs, epochs=flags_obj.train_epochs,
steps_per_epoch=steps_per_epoch, steps_per_epoch=steps_per_epoch,
callbacks=[ callbacks=[
...@@ -306,16 +307,29 @@ def run_cifar_with_keras(flags_obj): ...@@ -306,16 +307,29 @@ def run_cifar_with_keras(flags_obj):
eval_output = model.evaluate(eval_input_dataset, eval_output = model.evaluate(eval_input_dataset,
steps=num_eval_steps, steps=num_eval_steps,
verbose=1) verbose=1)
print('Test loss:', eval_output[0])
def main(_): stats = {}
stats['accuracy_top_1'] = eval_output[1]
stats['eval_loss'] = eval_output[0]
stats['training_loss'] = history.history['loss'][-1]
stats['training_accuracy_top_1'] = history.history['categorical_accuracy'][-1]
print('top_1 accuracy:{}'.format(stats['accuracy_top_1']))
print('top_1_training_accuracy:{}'.format(stats['training_accuracy_top_1']))
return stats
def define_keras_cifar_flags():
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
def main(_):
with logger.benchmark_context(flags.FLAGS): with logger.benchmark_context(flags.FLAGS):
run_cifar_with_keras(flags.FLAGS) run_cifar_with_keras(flags.FLAGS)
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.DEBUG) tf.logging.set_verbosity(tf.logging.DEBUG)
define_keras_cifar_flags()
cifar_main.define_cifar_flags() cifar_main.define_cifar_flags()
flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
absl_app.run(main) absl_app.run(main)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment