Commit 9baefd8a authored by Priya Gupta

Merge branch 'cifar_keras' of https://github.com/tensorflow/models into cifar_keras

parents aa1d6176 52ee9636
"""Executes Keras benchmarks and accuracy tests."""
from __future__ import print_function
import os
import sys
from absl import app as absl_app
from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import cifar10_main as cifar_main
import official.resnet.keras.keras_cifar_main as keras_cifar_main
DATA_DIR = '/data/cifar10_data/'
class KerasCifar10BenchmarkTests():
def keras_resnet56_1_gpu(self):
self._setup()
flags.FLAGS.num_gpus = 1
flags.FLAGS.data_dir = DATA_DIR
flags.FLAGS.batch_size = 128
flags.FLAGS.train_epochs = 1
flags.FLAGS.model_dir = self._get_model_dir('keras_resnet56_1_gpu')
flags.FLAGS.resnet_size = 56
flags.FLAGS.dtype = 'fp32'
stats = keras_cifar_main.run_cifar_with_keras(flags.FLAGS)
report_info = {}
results = []
results.append(self._create_result(stats['accuracy_top_1'].item(),
'top_1',
'quality'))
results.append(self._create_result(stats['training_accuracy_top_1'].item(),
'top_1_train_accuracy',
'quality'))
report_info['results'] = results
return report_info
def keras_resnet56_4_gpu(self):
flags.FLAGS.num_gpus = 4
flags.FLAGS.data_dir = DATA_DIR
flags.FLAGS.batch_size = 128
flags.FLAGS.train_epochs = 182
flags.FLAGS.model_dir = ''
flags.FLAGS.resnet_size = 56
flags.FLAGS.dtype = 'fp32'
keras_cifar_main.run_cifar_with_keras(flags.FLAGS)
def keras_resnet56_no_dist_strat_1_gpu(self):
self._setup()
flags.dist_strat_off = True
flags.FLAGS.num_gpus = 1
flags.FLAGS.data_dir = DATA_DIR
flags.FLAGS.batch_size = 128
flags.FLAGS.train_epochs = 1
flags.FLAGS.model_dir = ''
flags.FLAGS.resnet_size = 56
flags.FLAGS.dtype = 'fp32'
stats = keras_cifar_main.run_cifar_with_keras(flags.FLAGS)
report_info = {}
results = []
results.append(self._create_result(stats['accuracy_top_1'].item(),
'top_1',
'quality'))
results.append(self._create_result(stats['training_accuracy_top_1'].item(),
'top_1_train_accuracy',
'quality'))
report_info['results'] = results
return report_info
def _create_result(self, result, result_name, result_unit):
res_dict = {}
res_dict['result'] = result
res_dict['result_name'] = result_name
res_dict['result_unit'] = result_unit
return res_dict
def _get_model_dir(self, folder_name):
return os.path.join('/workspace', folder_name)
def _setup(self):
tf.logging.set_verbosity(tf.logging.DEBUG)
keras_cifar_main.define_keras_cifar_flags()
cifar_main.define_cifar_flags()
flags.FLAGS(['foo'])
def run_tests(self, test_list):
keras_benchmark = KerasCifar10BenchmarkTests()
if test_list:
for t in test_list:
getattr(self, t)()
else:
print('Running all tests')
keras_benchmark.keras_resnet56_1_gpu()
keras_benchmark.keras_resnet56_no_dist_strat_1_gpu()
keras_benchmark.keras_resnet56_4_gpu()
def main(_):
keras_benchmark = KerasCifar10BenchmarkTests()
keras_benchmark.run_tests(['keras_resnet56_1_gpu'])
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.DEBUG)
cifar_main.define_cifar_flags()
absl_app.run(main)
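For orientation (not part of the commit), this is how the harness above is meant to be driven; the report structure comes straight from _create_result:

# Hypothetical usage sketch: run one benchmark by name and read its report.
benchmark = KerasCifar10BenchmarkTests()
report = benchmark.keras_resnet56_1_gpu()  # trains 1 epoch on 1 GPU
for entry in report['results']:
  # e.g. {'result': 0.49, 'result_name': 'top_1', 'result_unit': 'quality'}
  print(entry['result_name'], entry['result'], entry['result_unit'])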
@@ -45,6 +45,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
     """
     self._batch_size = batch_size
+    self.last_exp_per_sec = 0
     super(TimeHistory, self).__init__()
 
   def on_train_begin(self, logs=None):
@@ -69,6 +70,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
       last_n_batches = time.time() - self.batch_time_start
       examples_per_second = (self._batch_size * n) / last_n_batches
       self.batch_times_secs.append(last_n_batches)
+      self.last_exp_per_sec = examples_per_second
       self.record_batch = True
       # TODO(anjalisridhar): add timestamp as well.
     if batch != 0:
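The added lines above expose the most recent throughput measurement on the callback object so a benchmark harness can read it after training. A minimal self-contained sketch of the same bookkeeping (ThroughputLogger and log_steps are illustrative names, not from this commit):

import time
import tensorflow as tf

class ThroughputLogger(tf.keras.callbacks.Callback):
  """Records examples/sec over each window of `log_steps` batches."""

  def __init__(self, batch_size, log_steps=100):
    super(ThroughputLogger, self).__init__()
    self._batch_size = batch_size
    self._log_steps = log_steps
    self.last_exp_per_sec = 0  # readable by a harness after fit()

  def on_batch_begin(self, batch, logs=None):
    if batch % self._log_steps == 0:
      self._window_start = time.time()

  def on_batch_end(self, batch, logs=None):
    if (batch + 1) % self._log_steps == 0:
      elapsed = time.time() - self._window_start
      self.last_exp_per_sec = (self._batch_size * self._log_steps) / elapsed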
@@ -80,8 +82,8 @@ class TimeHistory(tf.keras.callbacks.Callback):
 # LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
 #     (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
 # ]
 LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
     (0.1, 91), (0.01, 136), (0.001, 182)
 ]
 BASE_LEARNING_RATE = 0.1
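The schedule entries above are (multiplier, epoch to start) pairs: the effective rate is BASE_LEARNING_RATE times the multiplier of the last entry whose start epoch has been reached. A simplified sketch of that resolution (the file's actual learning_rate_schedule, used by LearningRateBatchScheduler below, may also rescale by batch size; that part is not shown in this hunk):

def learning_rate_for_epoch(epoch):
  """Simplified: resolve LR_SCHEDULE at a given epoch."""
  lr = BASE_LEARNING_RATE  # 0.1 until the first boundary at epoch 91
  for multiplier, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      lr = BASE_LEARNING_RATE * multiplier
  return lr

# learning_rate_for_epoch(100) -> 0.01; learning_rate_for_epoch(140) -> 0.001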
@@ -229,19 +231,18 @@ def run_cifar_with_keras(flags_obj):
   else:
     train_input_dataset = cifar_main.input_fn(
         True,
         flags_obj.data_dir,
         batch_size=per_device_batch_size,
         num_epochs=flags_obj.train_epochs,
         parse_record_fn=parse_record_keras)
 
   eval_input_dataset = cifar_main.input_fn(
       False,
       flags_obj.data_dir,
       batch_size=per_device_batch_size,
       num_epochs=flags_obj.train_epochs,
       parse_record_fn=parse_record_keras)
 
   # Use Keras ResNet50 applications model and native keras APIs
   # initialize RMSprop optimizer
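Both input_fn calls above take per_device_batch_size rather than the global batch size. In the official models utilities this value is derived by splitting the global batch across GPUs; a sketch of that derivation (an assumption about a helper defined outside this hunk, not code from the commit):

def per_device_batch_size(batch_size, num_gpus):
  """Global batch size -> per-GPU batch size (sketch)."""
  if num_gpus <= 1:
    return batch_size
  if batch_size % num_gpus:
    raise ValueError(
        'batch_size (%d) must be divisible by num_gpus (%d)' %
        (batch_size, num_gpus))
  return batch_size // num_gpus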
@@ -253,20 +254,20 @@ def run_cifar_with_keras(flags_obj):
   # TF Optimizer:
   # opt = tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9)
 
   strategy = distribution_utils.get_distribution_strategy(
       num_gpus=flags_obj.num_gpus)
 
   model = keras_resnet_model.ResNet56(input_shape=(32, 32, 3),
                                       include_top=True,
                                       classes=cifar_main._NUM_CLASSES,
                                       weights=None)
 
   loss = 'categorical_crossentropy'
   accuracy = 'categorical_accuracy'
 
-  if flags_obj.num_gpus == 1:
+  if flags_obj.num_gpus == 1 and flags_obj.dist_strat_off:
+    print('Not using distribution strategies.')
     model.compile(loss=loss,
                   optimizer=opt,
                   metrics=[accuracy])
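The rewritten condition above only skips DistributionStrategy when running on one GPU with the new dist_strat_off flag set. As a sketch of the two compile paths (the else branch is an assumption based on the TF 1.x Keras API of this era, where compile accepted a distribute argument; it is not shown in the hunk):

if flags_obj.num_gpus == 1 and flags_obj.dist_strat_off:
  # Plain Keras compile, no strategy involved.
  model.compile(loss=loss, optimizer=opt, metrics=[accuracy])
else:
  # TF 1.x style: hand the strategy to compile (assumption).
  model.compile(loss=loss, optimizer=opt, metrics=[accuracy],
                distribute=strategy)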
@@ -281,43 +282,56 @@ def run_cifar_with_keras(flags_obj):
   time_callback = TimeHistory(flags_obj.batch_size)
 
   tesorboard_callback = tf.keras.callbacks.TensorBoard(
       log_dir=flags_obj.model_dir)
-      #update_freq="batch") # Add this if want per batch logging.
+      # update_freq="batch") # Add this if want per batch logging.
 
   lr_callback = LearningRateBatchScheduler(
       learning_rate_schedule,
       batch_size=flags_obj.batch_size,
       num_images=cifar_main._NUM_IMAGES['train'])
 
   num_eval_steps = (cifar_main._NUM_IMAGES['validation'] //
                     flags_obj.batch_size)
 
-  print("Executing eagerly?:", tf.executing_eagerly())
+  print('Executing eagerly?:', tf.executing_eagerly())
 
-  model.fit(train_input_dataset,
+  history = model.fit(train_input_dataset,
                       epochs=flags_obj.train_epochs,
                       steps_per_epoch=steps_per_epoch,
                       callbacks=[
                           time_callback,
                           lr_callback,
                           tesorboard_callback
                       ],
                       validation_steps=num_eval_steps,
                       validation_data=eval_input_dataset,
                       verbose=1)
 
   eval_output = model.evaluate(eval_input_dataset,
                                steps=num_eval_steps,
                                verbose=1)
-  print('Test loss:', eval_output[0])
 
+  stats = {}
+  stats['accuracy_top_1'] = eval_output[1]
+  stats['eval_loss'] = eval_output[0]
+  stats['training_loss'] = history.history['loss'][-1]
+  stats['training_accuracy_top_1'] = history.history['categorical_accuracy'][-1]
+  print('top_1 accuracy:{}'.format(stats['accuracy_top_1']))
+  print('top_1_training_accuracy:{}'.format(stats['training_accuracy_top_1']))
+  return stats
+
+
+def define_keras_cifar_flags():
+  flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
 
 
 def main(_):
   with logger.benchmark_context(flags.FLAGS):
     run_cifar_with_keras(flags.FLAGS)
 
 
 if __name__ == '__main__':
   tf.logging.set_verbosity(tf.logging.DEBUG)
+  define_keras_cifar_flags()
   cifar_main.define_cifar_flags()
-  flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
   absl_app.run(main)
@@ -628,7 +628,10 @@ def define_resnet_flags(resnet_size_choices=None):
           'the expense of image resize/cropping being done as part of model '
           'inference. Note, this flag only applies to ImageNet and cannot '
           'be used for CIFAR.'))
+  flags.DEFINE_boolean(
+      name='dist_strat_off', default=False,
+      help=flags_core.help_wrap('Set to true to not use distribution '
+                                'strategies.'))
 
   choice_kwargs = dict(
       name='resnet_size', short_name='rs', default='50',
       help=flags_core.help_wrap('The size of the ResNet model to use.'))