Commit 424c2045 authored by Shining Sun's avatar Shining Sun
Browse files

Before all the data related change

parent 53ff5d90
...@@ -29,13 +29,13 @@ from official.utils.logs import logger ...@@ -29,13 +29,13 @@ from official.utils.logs import logger
from official.resnet import resnet_model from official.resnet import resnet_model
from official.resnet import resnet_run_loop from official.resnet import resnet_run_loop
_HEIGHT = 32 HEIGHT = 32
_WIDTH = 32 WIDTH = 32
_NUM_CHANNELS = 3 NUM_CHANNELS = 3
_DEFAULT_IMAGE_BYTES = _HEIGHT * _WIDTH * _NUM_CHANNELS _DEFAULT_IMAGE_BYTES = HEIGHT * WIDTH * NUM_CHANNELS
# The record is the image plus a one-byte label # The record is the image plus a one-byte label
_RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1 _RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1
_NUM_CLASSES = 10 NUM_CLASSES = 10
_NUM_DATA_FILES = 5 _NUM_DATA_FILES = 5
# TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits. # TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
...@@ -79,7 +79,7 @@ def parse_record(raw_record, is_training, dtype): ...@@ -79,7 +79,7 @@ def parse_record(raw_record, is_training, dtype):
# The remaining bytes after the label represent the image, which we reshape # The remaining bytes after the label represent the image, which we reshape
# from [depth * height * width] to [depth, height, width]. # from [depth * height * width] to [depth, height, width].
depth_major = tf.reshape(record_vector[1:_RECORD_BYTES], depth_major = tf.reshape(record_vector[1:_RECORD_BYTES],
[_NUM_CHANNELS, _HEIGHT, _WIDTH]) [NUM_CHANNELS, HEIGHT, WIDTH])
# Convert from [depth, height, width] to [height, width, depth], and cast as # Convert from [depth, height, width] to [height, width, depth], and cast as
# float32. # float32.
...@@ -96,10 +96,10 @@ def preprocess_image(image, is_training): ...@@ -96,10 +96,10 @@ def preprocess_image(image, is_training):
if is_training: if is_training:
# Resize the image to add four extra pixels on each side. # Resize the image to add four extra pixels on each side.
image = tf.image.resize_image_with_crop_or_pad( image = tf.image.resize_image_with_crop_or_pad(
image, _HEIGHT + 8, _WIDTH + 8) image, HEIGHT + 8, WIDTH + 8)
# Randomly crop a [_HEIGHT, _WIDTH] section of the image. # Randomly crop a [HEIGHT, WIDTH] section of the image.
image = tf.random_crop(image, [_HEIGHT, _WIDTH, _NUM_CHANNELS]) image = tf.random_crop(image, [HEIGHT, WIDTH, NUM_CHANNELS])
# Randomly flip the image horizontally. # Randomly flip the image horizontally.
image = tf.image.random_flip_left_right(image) image = tf.image.random_flip_left_right(image)
...@@ -145,7 +145,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -145,7 +145,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
def get_synth_input_fn(dtype): def get_synth_input_fn(dtype):
return resnet_run_loop.get_synth_input_fn( return resnet_run_loop.get_synth_input_fn(
_HEIGHT, _WIDTH, _NUM_CHANNELS, _NUM_CLASSES, dtype=dtype) HEIGHT, WIDTH, NUM_CHANNELS, NUM_CLASSES, dtype=dtype)
############################################################################### ###############################################################################
...@@ -154,7 +154,7 @@ def get_synth_input_fn(dtype): ...@@ -154,7 +154,7 @@ def get_synth_input_fn(dtype):
class Cifar10Model(resnet_model.Model): class Cifar10Model(resnet_model.Model):
"""Model class with appropriate defaults for CIFAR-10 data.""" """Model class with appropriate defaults for CIFAR-10 data."""
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, def __init__(self, resnet_size, data_format=None, num_classes=NUM_CLASSES,
resnet_version=resnet_model.DEFAULT_VERSION, resnet_version=resnet_model.DEFAULT_VERSION,
dtype=resnet_model.DEFAULT_DTYPE): dtype=resnet_model.DEFAULT_DTYPE):
"""These are the parameters that work for CIFAR-10 data. """These are the parameters that work for CIFAR-10 data.
...@@ -196,7 +196,7 @@ class Cifar10Model(resnet_model.Model): ...@@ -196,7 +196,7 @@ class Cifar10Model(resnet_model.Model):
def cifar10_model_fn(features, labels, mode, params): def cifar10_model_fn(features, labels, mode, params):
"""Model function for CIFAR-10.""" """Model function for CIFAR-10."""
features = tf.reshape(features, [-1, _HEIGHT, _WIDTH, _NUM_CHANNELS]) features = tf.reshape(features, [-1, HEIGHT, WIDTH, NUM_CHANNELS])
# Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under. # Learning rate schedule follows arXiv:1512.03385 for ResNet-56 and under.
learning_rate_fn = resnet_run_loop.learning_rate_with_decay( learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
batch_size=params['batch_size'], batch_denom=128, batch_size=params['batch_size'], batch_denom=128,
...@@ -261,7 +261,7 @@ def run_cifar(flags_obj): ...@@ -261,7 +261,7 @@ def run_cifar(flags_obj):
input_fn) input_fn)
resnet_run_loop.resnet_main( resnet_run_loop.resnet_main(
flags_obj, cifar10_model_fn, input_function, DATASET_NAME, flags_obj, cifar10_model_fn, input_function, DATASET_NAME,
shape=[_HEIGHT, _WIDTH, _NUM_CHANNELS]) shape=[HEIGHT, WIDTH, NUM_CHANNELS])
def main(_): def main(_):
......
...@@ -30,11 +30,11 @@ from official.resnet import imagenet_preprocessing ...@@ -30,11 +30,11 @@ from official.resnet import imagenet_preprocessing
from official.resnet import resnet_model from official.resnet import resnet_model
from official.resnet import resnet_run_loop from official.resnet import resnet_run_loop
_DEFAULT_IMAGE_SIZE = 224 DEFAULT_IMAGE_SIZE = 224
_NUM_CHANNELS = 3 NUM_CHANNELS = 3
_NUM_CLASSES = 1001 NUM_CLASSES = 1001
_NUM_IMAGES = { NUM_IMAGES = {
'train': 1281167, 'train': 1281167,
'validation': 50000, 'validation': 50000,
} }
...@@ -149,9 +149,9 @@ def parse_record(raw_record, is_training, dtype): ...@@ -149,9 +149,9 @@ def parse_record(raw_record, is_training, dtype):
image = imagenet_preprocessing.preprocess_image( image = imagenet_preprocessing.preprocess_image(
image_buffer=image_buffer, image_buffer=image_buffer,
bbox=bbox, bbox=bbox,
output_height=_DEFAULT_IMAGE_SIZE, output_height=DEFAULT_IMAGE_SIZE,
output_width=_DEFAULT_IMAGE_SIZE, output_width=DEFAULT_IMAGE_SIZE,
num_channels=_NUM_CHANNELS, num_channels=NUM_CHANNELS,
is_training=is_training) is_training=is_training)
image = tf.cast(image, dtype) image = tf.cast(image, dtype)
...@@ -206,7 +206,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1, ...@@ -206,7 +206,7 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1,
def get_synth_input_fn(dtype): def get_synth_input_fn(dtype):
return resnet_run_loop.get_synth_input_fn( return resnet_run_loop.get_synth_input_fn(
_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES, DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE, NUM_CHANNELS, NUM_CLASSES,
dtype=dtype) dtype=dtype)
...@@ -216,7 +216,7 @@ def get_synth_input_fn(dtype): ...@@ -216,7 +216,7 @@ def get_synth_input_fn(dtype):
class ImagenetModel(resnet_model.Model): class ImagenetModel(resnet_model.Model):
"""Model class with appropriate defaults for Imagenet data.""" """Model class with appropriate defaults for Imagenet data."""
def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, def __init__(self, resnet_size, data_format=None, num_classes=NUM_CLASSES,
resnet_version=resnet_model.DEFAULT_VERSION, resnet_version=resnet_model.DEFAULT_VERSION,
dtype=resnet_model.DEFAULT_DTYPE): dtype=resnet_model.DEFAULT_DTYPE):
"""These are the parameters that work for Imagenet data. """These are the parameters that work for Imagenet data.
...@@ -303,7 +303,7 @@ def imagenet_model_fn(features, labels, mode, params): ...@@ -303,7 +303,7 @@ def imagenet_model_fn(features, labels, mode, params):
learning_rate_fn = resnet_run_loop.learning_rate_with_decay( learning_rate_fn = resnet_run_loop.learning_rate_with_decay(
batch_size=params['batch_size'], batch_denom=256, batch_size=params['batch_size'], batch_denom=256,
num_images=_NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90], num_images=NUM_IMAGES['train'], boundary_epochs=[30, 60, 80, 90],
decay_rates=[1, 0.1, 0.01, 0.001, 1e-4], warmup=warmup, base_lr=base_lr) decay_rates=[1, 0.1, 0.01, 0.001, 1e-4], warmup=warmup, base_lr=base_lr)
return resnet_run_loop.resnet_model_fn( return resnet_run_loop.resnet_model_fn(
...@@ -343,7 +343,7 @@ def run_imagenet(flags_obj): ...@@ -343,7 +343,7 @@ def run_imagenet(flags_obj):
resnet_run_loop.resnet_main( resnet_run_loop.resnet_main(
flags_obj, imagenet_model_fn, input_function, DATASET_NAME, flags_obj, imagenet_model_fn, input_function, DATASET_NAME,
shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS]) shape=[DEFAULT_IMAGE_SIZE, DEFAULT_IMAGE_SIZE, NUM_CHANNELS])
def main(_): def main(_):
......
...@@ -18,11 +18,8 @@ from __future__ import absolute_import ...@@ -18,11 +18,8 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import time
from absl import app as absl_app from absl import app as absl_app
from absl import flags from absl import flags
import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import cifar10_main as cifar_main from official.resnet import cifar10_main as cifar_main
...@@ -68,6 +65,8 @@ def parse_record_keras(raw_record, is_training, dtype): ...@@ -68,6 +65,8 @@ def parse_record_keras(raw_record, is_training, dtype):
The input record is parsed into a label and image, and the image is passed The input record is parsed into a label and image, and the image is passed
through preprocessing steps (cropping, flipping, and so on). through preprocessing steps (cropping, flipping, and so on).
This method converts the label to onhot to fit the loss function.
Args: Args:
raw_record: scalar Tensor tf.string containing a serialized raw_record: scalar Tensor tf.string containing a serialized
Example protocol buffer. Example protocol buffer.
...@@ -78,7 +77,7 @@ def parse_record_keras(raw_record, is_training, dtype): ...@@ -78,7 +77,7 @@ def parse_record_keras(raw_record, is_training, dtype):
Tuple with processed image tensor and one-hot-encoded label tensor. Tuple with processed image tensor and one-hot-encoded label tensor.
""" """
image, label = cifar_main.parse_record(raw_record, is_training, dtype) image, label = cifar_main.parse_record(raw_record, is_training, dtype)
label = tf.sparse_to_dense(label, (cifar_main._NUM_CLASSES,), 1) label = tf.sparse_to_dense(label, (cifar_main.NUM_CLASSES,), 1)
return image, label return image, label
...@@ -105,26 +104,26 @@ def run(flags_obj): ...@@ -105,26 +104,26 @@ def run(flags_obj):
# pylint: disable=protected-access # pylint: disable=protected-access
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
synth_input_fn = resnet_run_loop.get_synth_input_fn( synth_input_fn = resnet_run_loop.get_synth_input_fn(
cifar_main._HEIGHT, cifar_main._WIDTH, cifar_main.HEIGHT, cifar_main.WIDTH,
cifar_main._NUM_CHANNELS, cifar_main._NUM_CLASSES, cifar_main.NUM_CHANNELS, cifar_main.NUM_CLASSES,
dtype=flags_core.get_tf_dtype(flags_obj)) dtype=flags_core.get_tf_dtype(flags_obj))
train_input_dataset = synth_input_fn( train_input_dataset = synth_input_fn(
True, True,
flags_obj.data_dir, flags_obj.data_dir,
batch_size=per_device_batch_size, batch_size=per_device_batch_size,
height=cifar_main._HEIGHT, height=cifar_main.HEIGHT,
width=cifar_main._WIDTH, width=cifar_main.WIDTH,
num_channels=cifar_main._NUM_CHANNELS, num_channels=cifar_main.NUM_CHANNELS,
num_classes=cifar_main._NUM_CLASSES, num_classes=cifar_main.NUM_CLASSES,
dtype=dtype) dtype=dtype)
eval_input_dataset = synth_input_fn( eval_input_dataset = synth_input_fn(
False, False,
flags_obj.data_dir, flags_obj.data_dir,
batch_size=per_device_batch_size, batch_size=per_device_batch_size,
height=cifar_main._HEIGHT, height=cifar_main.HEIGHT,
width=cifar_main._WIDTH, width=cifar_main.WIDTH,
num_channels=cifar_main._NUM_CHANNELS, num_channels=cifar_main.NUM_CHANNELS,
num_classes=cifar_main._NUM_CLASSES, num_classes=cifar_main.NUM_CLASSES,
dtype=dtype) dtype=dtype)
# pylint: enable=protected-access # pylint: enable=protected-access
...@@ -144,20 +143,22 @@ def run(flags_obj): ...@@ -144,20 +143,22 @@ def run(flags_obj):
parse_record_fn=parse_record_keras) parse_record_fn=parse_record_keras)
optimizer = keras_common.get_optimizer() optimizer = keras_common.get_optimizer()
strategy = keras_common.get_dist_strategy() strategy = distribution_utils.get_distribution_strategy(
flags_obj.num_gpus, flags_obj.use_one_device_strategy)
model = resnet56.ResNet56(input_shape=(32, 32, 3), model = resnet56.ResNet56(input_shape=(32, 32, 3),
classes=cifar_main._NUM_CLASSES) classes=cifar_main.NUM_CLASSES)
model.compile(loss='categorical_crossentropy', model.compile(loss='categorical_crossentropy',
optimizer=optimizer, optimizer=optimizer,
metrics=['categorical_accuracy'], metrics=['categorical_accuracy'],
strategy=strategy)
time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks( time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
learning_rate_schedule) learning_rate_schedule, cifar_main.NUM_IMAGES['train'])
steps_per_epoch = cifar_main._NUM_IMAGES['train'] // flags_obj.batch_size steps_per_epoch = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size
num_eval_steps = (cifar_main._NUM_IMAGES['validation'] // num_eval_steps = (cifar_main.NUM_IMAGES['validation'] //
flags_obj.batch_size) flags_obj.batch_size)
history = model.fit(train_input_dataset, history = model.fit(train_input_dataset,
...@@ -176,7 +177,6 @@ def run(flags_obj): ...@@ -176,7 +177,6 @@ def run(flags_obj):
steps=num_eval_steps, steps=num_eval_steps,
verbose=1) verbose=1)
print('Test loss:', eval_output[0])
stats = keras_common.analyze_fit_and_eval_result(history, eval_output) stats = keras_common.analyze_fit_and_eval_result(history, eval_output)
return stats return stats
...@@ -188,6 +188,6 @@ def main(_): ...@@ -188,6 +188,6 @@ def main(_):
if __name__ == '__main__': if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.DEBUG) tf.logging.set_verbosity(tf.logging.INFO)
cifar_main.define_cifar_flags() cifar_main.define_cifar_flags()
absl_app.run(main) absl_app.run(main)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Common util functions an classes used by both keras cifar and imagenet.""" """Common util functions and classes used by both keras cifar and imagenet."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
...@@ -20,13 +20,10 @@ from __future__ import print_function ...@@ -20,13 +20,10 @@ from __future__ import print_function
import time import time
from absl import app as absl_app
from absl import flags from absl import flags
import numpy as np import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import imagenet_main
from official.utils.misc import distribution_utils
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
...@@ -37,7 +34,7 @@ class TimeHistory(tf.keras.callbacks.Callback): ...@@ -37,7 +34,7 @@ class TimeHistory(tf.keras.callbacks.Callback):
"""Callback for Keras models.""" """Callback for Keras models."""
def __init__(self, batch_size): def __init__(self, batch_size):
"""Callback for Keras models. """Callback for logging performance (# image/second).
Args: Args:
batch_size: Total batch size. batch_size: Total batch size.
...@@ -45,28 +42,22 @@ class TimeHistory(tf.keras.callbacks.Callback): ...@@ -45,28 +42,22 @@ class TimeHistory(tf.keras.callbacks.Callback):
""" """
self._batch_size = batch_size self._batch_size = batch_size
super(TimeHistory, self).__init__() super(TimeHistory, self).__init__()
self.log_batch_size = 100
def on_train_begin(self, logs=None): def on_train_begin(self, logs=None):
self.epoch_times_secs = []
self.batch_times_secs = [] self.batch_times_secs = []
self.record_batch = True self.record_batch = True
def on_epoch_begin(self, epoch, logs=None):
self.epoch_time_start = time.time()
def on_epoch_end(self, epoch, logs=None):
self.epoch_times_secs.append(time.time() - self.epoch_time_start)
def on_batch_begin(self, batch, logs=None): def on_batch_begin(self, batch, logs=None):
if self.record_batch: if self.record_batch:
self.batch_time_start = time.time() self.batch_time_start = time.time()
self.record_batch = False self.record_batch = False
def on_batch_end(self, batch, logs=None): def on_batch_end(self, batch, logs=None):
n = 100 if batch % self.log_batch_size == 0:
if batch % n == 0:
last_n_batches = time.time() - self.batch_time_start last_n_batches = time.time() - self.batch_time_start
examples_per_second = (self._batch_size * n) / last_n_batches examples_per_second =
(self._batch_size * self.log_batch_size) / last_n_batches
self.batch_times_secs.append(last_n_batches) self.batch_times_secs.append(last_n_batches)
self.record_batch = True self.record_batch = True
# TODO(anjalisridhar): add timestamp as well. # TODO(anjalisridhar): add timestamp as well.
...@@ -95,8 +86,8 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback): ...@@ -95,8 +86,8 @@ class LearningRateBatchScheduler(tf.keras.callbacks.Callback):
self.prev_lr = -1 self.prev_lr = -1
def on_epoch_begin(self, epoch, logs=None): def on_epoch_begin(self, epoch, logs=None):
#if not hasattr(self.model.optimizer, 'learning_rate'): if not hasattr(self.model.optimizer, 'learning_rate'):
# raise ValueError('Optimizer must have a "learning_rate" attribute.') raise ValueError('Optimizer must have a "learning_rate" attribute.')
self.epochs += 1 self.epochs += 1
def on_batch_begin(self, batch, logs=None): def on_batch_begin(self, batch, logs=None):
...@@ -120,31 +111,16 @@ def get_optimizer(): ...@@ -120,31 +111,16 @@ def get_optimizer():
return optimizer return optimizer
def get_dist_strategy(): def get_callbacks(learning_rate_schedule_fn, num_images):
if FLAGS.num_gpus == 1 and not FLAGS.use_one_device_strategy:
print('Not using distribution strategies.')
strategy = None
elif FLAGS.num_gpus > 1 and FLAGS.use_one_device_strategy:
rase ValueError("When %d GPUs are specified, use_one_device_strategy'
'flag cannot be set to True.")
else:
strategy = distribution_utils.get_distribution_strategy(
num_gpus=FLAGS.num_gpus)
return strategy
def get_fit_callbacks(learning_rate_schedule_fn):
time_callback = TimeHistory(FLAGS.batch_size) time_callback = TimeHistory(FLAGS.batch_size)
tensorboard_callback = tf.keras.callbacks.TensorBoard( tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=FLAGS.model_dir) log_dir=FLAGS.model_dir)
#update_freq="batch") # Add this if want per batch logging.
lr_callback = LearningRateBatchScheduler( lr_callback = LearningRateBatchScheduler(
learning_rate_schedule_fn, learning_rate_schedule_fn,
batch_size=FLAGS.batch_size, batch_size=FLAGS.batch_size,
num_images=imagenet_main._NUM_IMAGES['train']) num_images=num_images)
return time_callback, tensorboard_callback, lr_callback return time_callback, tensorboard_callback, lr_callback
...@@ -155,6 +131,7 @@ def analyze_fit_and_eval_result(history, eval_output): ...@@ -155,6 +131,7 @@ def analyze_fit_and_eval_result(history, eval_output):
stats['training_loss'] = history.history['loss'][-1] stats['training_loss'] = history.history['loss'][-1]
stats['training_accuracy_top_1'] = history.history['categorical_accuracy'][-1] stats['training_accuracy_top_1'] = history.history['categorical_accuracy'][-1]
print('Test loss:{}'.format(stats['']))
print('top_1 accuracy:{}'.format(stats['accuracy_top_1'])) print('top_1 accuracy:{}'.format(stats['accuracy_top_1']))
print('top_1_training_accuracy:{}'.format(stats['training_accuracy_top_1'])) print('top_1_training_accuracy:{}'.format(stats['training_accuracy_top_1']))
......
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -18,15 +18,11 @@ from __future__ import absolute_import ...@@ -18,15 +18,11 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import time
from absl import app as absl_app from absl import app as absl_app
from absl import flags from absl import flags
import numpy as np
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import imagenet_main from official.resnet import imagenet_main
from official.resnet import imagenet_preprocessing
from official.resnet import resnet_run_loop from official.resnet import resnet_run_loop
from official.resnet.keras import keras_common from official.resnet.keras import keras_common
from official.resnet.keras import resnet50 from official.resnet.keras import resnet50
...@@ -104,22 +100,22 @@ def run_imagenet_with_keras(flags_obj): ...@@ -104,22 +100,22 @@ def run_imagenet_with_keras(flags_obj):
# pylint: disable=protected-access # pylint: disable=protected-access
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
synth_input_fn = resnet_run_loop.get_synth_input_fn( synth_input_fn = resnet_run_loop.get_synth_input_fn(
imagenet_main._DEFAULT_IMAGE_SIZE, imagenet_main._DEFAULT_IMAGE_SIZE, imagenet_main.DEFAULT_IMAGE_SIZE, imagenet_main.DEFAULT_IMAGE_SIZE,
imagenet_main._NUM_CHANNELS, imagenet_main._NUM_CLASSES, imagenet_main.NUM_CHANNELS, imagenet_main.NUM_CLASSES,
dtype=flags_core.get_tf_dtype(flags_obj)) dtype=flags_core.get_tf_dtype(flags_obj))
train_input_dataset = synth_input_fn( train_input_dataset = synth_input_fn(
batch_size=per_device_batch_size, batch_size=per_device_batch_size,
height=imagenet_main._DEFAULT_IMAGE_SIZE, height=imagenet_main.DEFAULT_IMAGE_SIZE,
width=imagenet_main._DEFAULT_IMAGE_SIZE, width=imagenet_main.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_main._NUM_CHANNELS, num_channels=imagenet_main.NUM_CHANNELS,
num_classes=imagenet_main._NUM_CLASSES, num_classes=imagenet_main.NUM_CLASSES,
dtype=dtype) dtype=dtype)
eval_input_dataset = synth_input_fn( eval_input_dataset = synth_input_fn(
batch_size=per_device_batch_size, batch_size=per_device_batch_size,
height=imagenet_main._DEFAULT_IMAGE_SIZE, height=imagenet_main.DEFAULT_IMAGE_SIZE,
width=imagenet_main._DEFAULT_IMAGE_SIZE, width=imagenet_main.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_main._NUM_CHANNELS, num_channels=imagenet_main.NUM_CHANNELS,
num_classes=imagenet_main._NUM_CLASSES, num_classes=imagenet_main.NUM_CLASSES,
dtype=dtype) dtype=dtype)
# pylint: enable=protected-access # pylint: enable=protected-access
...@@ -140,20 +136,21 @@ def run_imagenet_with_keras(flags_obj): ...@@ -140,20 +136,21 @@ def run_imagenet_with_keras(flags_obj):
optimizer = keras_common.get_optimizer() optimizer = keras_common.get_optimizer()
strategy = keras_common.get_dist_strategy() strategy = distribution_utils.get_distribution_strategy(
flags_obj.num_gpus, flags_obj.use_one_device_strategy)
model = resnet50.ResNet50(num_classes=imagenet_main._NUM_CLASSES) model = resnet50.ResNet50(num_classes=imagenet_main.NUM_CLASSES)
model.compile(loss='categorical_crossentropy', model.compile(loss='categorical_crossentropy',
optimizer=optimizer, optimizer=optimizer,
metrics=['categorical_accuracy'], metrics=['categorical_accuracy'],
distribute=strategy) distribute=strategy)
time_callback, tensorboard_callback, lr_callback = keras_common.get_fit_callbacks( time_callback, tensorboard_callback, lr_callback = keras_common.get_callbacks(
learning_rate_schedule) learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])
steps_per_epoch = imagenet_main._NUM_IMAGES['train'] // flags_obj.batch_size steps_per_epoch = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
num_eval_steps = (imagenet_main._NUM_IMAGES['validation'] // num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] //
flags_obj.batch_size) flags_obj.batch_size)
history = model.fit(train_input_dataset, history = model.fit(train_input_dataset,
...@@ -172,7 +169,6 @@ def run_imagenet_with_keras(flags_obj): ...@@ -172,7 +169,6 @@ def run_imagenet_with_keras(flags_obj):
steps=num_eval_steps, steps=num_eval_steps,
verbose=1) verbose=1)
print('Test loss:', eval_output[0])
stats = keras_common.analyze_fit_and_eval_result(history, eval_output) stats = keras_common.analyze_fit_and_eval_result(history, eval_output)
return stats return stats
......
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment