Unverified Commit 23c0017f authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Merged commit includes the following changes: (#7368)

* Merged commit includes the following changes:
261380794  by haoyuzhang<haoyuzhang@google.com>:

    Internal change

261374439  by haoyuzhang<haoyuzhang@google.com>:

    Change Keras CTL dependencies from r1 ResNet to Keras ResNet.

--

PiperOrigin-RevId: 261380794

* Revert unintentional change

* Revert unintentional change
parent a76e250f
......@@ -21,19 +21,9 @@ from __future__ import print_function
from absl import flags
FLAGS = flags.FLAGS
def define_ctl_flags():
"""Define flags for CTL."""
flags.DEFINE_boolean(name='use_tf_function', default=True,
help='Wrap the train and test step inside a '
'tf.function.')
flags.DEFINE_boolean(name='skip_eval', default=False, help='Skip evaluation?')
flags.DEFINE_integer(
name='train_steps', default=None,
help='The number of steps to run for training. If it is larger than '
'# batches per epoch, then use # batches per epoch. When this flag is '
'set, only one epoch is going to run for training.')
\ No newline at end of file
......@@ -22,7 +22,7 @@ import time
from absl import flags
import tensorflow as tf
from official.resnet import imagenet_main
from official.resnet.keras import keras_common
from official.resnet.ctl import ctl_imagenet_main
from official.resnet.ctl import ctl_common
from official.utils.testing.perfzero_benchmark import PerfZeroBenchmark
......@@ -118,7 +118,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
flag_methods = [
ctl_common.define_ctl_flags,
lambda: imagenet_main.define_imagenet_flags()
keras_common.define_keras_flags
]
self.data_dir = os.path.join(root_data_dir, 'imagenet')
......@@ -162,7 +162,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
def __init__(self, output_dir=None, default_flags=None):
flag_methods = [
ctl_common.define_ctl_flags,
lambda: imagenet_main.define_imagenet_flags()
keras_common.define_keras_flags
]
super(Resnet50CtlBenchmarkBase, self).__init__(
......
......@@ -18,39 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tempfile
import datetime
import time
import numpy as np
from absl import app as absl_app
from absl import flags
from absl import logging
import tensorflow as tf
from official.resnet import imagenet_main
from official.resnet.ctl import ctl_common
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_common
from official.resnet.keras import keras_imagenet_main
from official.resnet.keras import resnet_model
from official.utils.flags import core as flags_core
from official.utils.logs import logger
from official.utils.misc import distribution_utils
from official.utils.misc import model_helpers
from official.resnet.ctl import ctl_common
from official.utils.misc import keras_utils
def parse_record_keras(raw_record, is_training, dtype):
"""Adjust the shape of label."""
image, label = imagenet_main.parse_record(raw_record, is_training, dtype)
# Subtract one so that labels are in [0, 1000), and cast to float32 for
# Keras model.
label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
dtype=tf.float32)
return image, label
from official.utils.misc import model_helpers
def build_stats(train_result, eval_result, time_callback):
......@@ -68,18 +50,18 @@ def build_stats(train_result, eval_result, time_callback):
stats = {}
if eval_result:
stats["eval_loss"] = eval_result[0]
stats["eval_acc"] = eval_result[1]
stats['eval_loss'] = eval_result[0]
stats['eval_acc'] = eval_result[1]
stats['train_loss'] = train_result[0]
stats['train_acc'] = train_result[1]
if time_callback:
timestamp_log = time_callback.timestamp_log
stats["step_timestamp_log"] = timestamp_log
stats["train_finish_time"] = time_callback.train_finish_time
stats['step_timestamp_log'] = timestamp_log
stats['train_finish_time'] = time_callback.train_finish_time
if len(timestamp_log) > 1:
stats["avg_exp_per_second"] = (
stats['avg_exp_per_second'] = (
time_callback.batch_size * time_callback.log_steps *
(len(time_callback.timestamp_log) - 1) /
(timestamp_log[-1].timestamp - timestamp_log[0].timestamp))
......@@ -92,20 +74,20 @@ def get_input_dataset(flags_obj, strategy):
dtype = flags_core.get_tf_dtype(flags_obj)
if flags_obj.use_synthetic_data:
input_fn = keras_common.get_synth_input_fn(
height=imagenet_main.DEFAULT_IMAGE_SIZE,
width=imagenet_main.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_main.NUM_CHANNELS,
num_classes=imagenet_main.NUM_CLASSES,
height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_preprocessing.NUM_CHANNELS,
num_classes=imagenet_preprocessing.NUM_CLASSES,
dtype=dtype,
drop_remainder=True)
else:
input_fn = imagenet_main.input_fn
input_fn = imagenet_preprocessing.input_fn
train_ds = input_fn(
is_training=True,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
parse_record_fn=parse_record_keras,
parse_record_fn=imagenet_preprocessing.parse_record,
datasets_num_private_threads=flags_obj.datasets_num_private_threads,
dtype=dtype)
......@@ -118,7 +100,7 @@ def get_input_dataset(flags_obj, strategy):
is_training=False,
data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size,
parse_record_fn=parse_record_keras,
parse_record_fn=imagenet_preprocessing.parse_record,
dtype=dtype)
if strategy:
......@@ -129,14 +111,16 @@ def get_input_dataset(flags_obj, strategy):
def get_num_train_iterations(flags_obj):
"""Returns the number of training stesps, train and test epochs."""
train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size
train_steps = (
imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
train_epochs = flags_obj.train_epochs
if flags_obj.train_steps:
train_steps = min(flags_obj.train_steps, train_steps)
train_epochs = 1
eval_steps = imagenet_main.NUM_IMAGES['validation'] // flags_obj.batch_size
eval_steps = (
imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)
return train_steps, train_epochs, eval_steps
......@@ -177,7 +161,8 @@ def run(flags_obj):
strategy_scope = distribution_utils.get_strategy_scope(strategy)
with strategy_scope:
model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
model = resnet_model.resnet50(
num_classes=imagenet_preprocessing.NUM_CLASSES,
dtype=dtype, batch_size=flags_obj.batch_size)
optimizer = tf.keras.optimizers.SGD(
......@@ -296,6 +281,6 @@ def main(_):
if __name__ == '__main__':
logging.set_verbosity(logging.INFO)
imagenet_main.define_imagenet_flags()
keras_common.define_keras_flags()
ctl_common.define_ctl_flags()
absl_app.run(main)
......@@ -21,14 +21,14 @@ from __future__ import print_function
from tempfile import mkdtemp
import tensorflow as tf
from official.resnet import imagenet_main
from official.resnet.ctl import ctl_imagenet_main
from tensorflow.python.eager import context
from tensorflow.python.platform import googletest
from official.resnet.ctl import ctl_common
from official.resnet.ctl import ctl_imagenet_main
from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_common
from official.utils.misc import keras_utils
from official.utils.testing import integration
# pylint: disable=ungrouped-imports
from tensorflow.python.eager import context
from tensorflow.python.platform import googletest
class CtlImagenetTest(googletest.TestCase):
......@@ -49,14 +49,14 @@ class CtlImagenetTest(googletest.TestCase):
@classmethod
def setUpClass(cls): # pylint: disable=invalid-name
super(CtlImagenetTest, cls).setUpClass()
imagenet_main.define_imagenet_flags()
keras_common.define_keras_flags()
ctl_common.define_ctl_flags()
def setUp(self):
super(CtlImagenetTest, self).setUp()
if not keras_utils.is_v2_0():
tf.compat.v1.enable_v2_behavior()
imagenet_main.NUM_IMAGES['validation'] = 4
imagenet_preprocessing.NUM_IMAGES['validation'] = 4
def tearDown(self):
super(CtlImagenetTest, self).tearDown()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment