Unverified Commit 1921a3b5 authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Merged commit includes the following changes: (#7365)

261339941  by haoyuzhang<haoyuzhang@google.com>:

    Own library functions in Keras ResNet models, and remove dependencies on v1 Estimator version of ResNet models.

    Most dependencies that the Keras version has are related to data input pipelines. Created dedicated files (cifar_preprocessing.py, imagenet_preprocessing.py) to collect all logic handling the CIFAR and ImageNet data input functions.

--
261339166  by haoyuzhang<haoyuzhang@google.com>:

    Internal change

261317601  by akuegel<akuegel@google.com>:

    Internal change

261218818  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Internal change

PiperOrigin-RevId: 261339941
parent 13e7c85d
...@@ -799,9 +799,9 @@ class Transformer(tf.keras.layers.Layer): ...@@ -799,9 +799,9 @@ class Transformer(tf.keras.layers.Layer):
name=("layer_%d" % i))) name=("layer_%d" % i)))
super(Transformer, self).build(unused_input_shapes) super(Transformer, self).build(unused_input_shapes)
def __call__(self, input_tensor, attention_mask=None): def __call__(self, input_tensor, attention_mask=None, **kwargs):
inputs = pack_inputs([input_tensor, attention_mask]) inputs = pack_inputs([input_tensor, attention_mask])
return super(Transformer, self).__call__(inputs=inputs) return super(Transformer, self).__call__(inputs=inputs, **kwargs)
def call(self, inputs): def call(self, inputs):
"""Implements call() for the layer.""" """Implements call() for the layer."""
......
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides utilities to Cifar-10 dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import logging
import tensorflow as tf
from official.resnet.keras import imagenet_preprocessing
# CIFAR-10 images are fixed-size 32x32 RGB.
HEIGHT = 32
WIDTH = 32
NUM_CHANNELS = 3
# Bytes needed to store one flattened image (height * width * channels).
_DEFAULT_IMAGE_BYTES = HEIGHT * WIDTH * NUM_CHANNELS
# The record is the image plus a one-byte label
_RECORD_BYTES = _DEFAULT_IMAGE_BYTES + 1
# TODO(tobyboyd): Change to best practice 45K(train)/5K(val)/10K(test) splits.
# Number of examples per split in the standard CIFAR-10 distribution.
NUM_IMAGES = {
    'train': 50000,
    'validation': 10000,
}
# The training set ships split across this many binary files
# (data_batch_1.bin ... data_batch_5.bin).
_NUM_DATA_FILES = 5
NUM_CLASSES = 10
def parse_record(raw_record, is_training, dtype):
  """Parses a record containing a training example of an image.

  The input record is parsed into a label and image, and the image is passed
  through preprocessing steps (cropping, flipping, and so on).

  This method converts the label to one hot to fit the loss function.

  Args:
    raw_record: scalar Tensor tf.string containing a serialized
      Example protocol buffer.
    is_training: A boolean denoting whether the input is for training.
    dtype: Data type to use for input images.

  Returns:
    Tuple with processed image tensor and one-hot-encoded label tensor.
  """
  # Convert bytes to a vector of uint8 that is record_bytes long.
  record_vector = tf.io.decode_raw(raw_record, tf.uint8)

  # The first byte represents the label, which we convert from uint8 to int32
  # and then to one-hot.
  label = tf.cast(record_vector[0], tf.int32)

  # The remaining bytes after the label represent the image, which we reshape
  # from [depth * height * width] to [depth, height, width].
  depth_major = tf.reshape(record_vector[1:_RECORD_BYTES],
                           [NUM_CHANNELS, HEIGHT, WIDTH])

  # Convert from [depth, height, width] to [height, width, depth], and cast as
  # float32.
  image = tf.cast(tf.transpose(a=depth_major, perm=[1, 2, 0]), tf.float32)

  image = preprocess_image(image, is_training)
  image = tf.cast(image, dtype)

  # One-hot encode with tf.one_hot instead of the deprecated V1
  # tf.compat.v1.sparse_to_dense. Integer on/off values keep the int32
  # output dtype the previous call produced.
  label = tf.one_hot(label, NUM_CLASSES, on_value=1, off_value=0)
  return image, label
def preprocess_image(image, is_training):
  """Preprocess a single image of layout [height, width, depth]."""
  if not is_training:
    # Evaluation path: no augmentation, just per-image standardization.
    return tf.image.per_image_standardization(image)

  # Training augmentation: pad by four pixels on each side, take a random
  # [HEIGHT, WIDTH] crop, then randomly mirror horizontally.
  padded = tf.image.resize_with_crop_or_pad(image, HEIGHT + 8, WIDTH + 8)
  cropped = tf.image.random_crop(padded, [HEIGHT, WIDTH, NUM_CHANNELS])
  flipped = tf.image.random_flip_left_right(cropped)

  # Subtract off the mean and divide by the variance of the pixels.
  return tf.image.per_image_standardization(flipped)
def get_filenames(is_training, data_dir):
  """Returns a list of CIFAR-10 data filenames.

  Args:
    is_training: A boolean denoting whether to return the training files
      (data_batch_1.bin ... data_batch_5.bin) or the test file.
    data_dir: Directory containing the extracted CIFAR-10 binaries.

  Returns:
    A list of absolute file paths.

  Raises:
    ValueError: if `data_dir` does not exist.
  """
  # Explicit check instead of `assert`: asserts are stripped when Python
  # runs with optimizations (-O), which would silently hide missing data.
  if not tf.io.gfile.exists(data_dir):
    raise ValueError(
        'Run cifar10_download_and_extract.py first to download and extract the '
        'CIFAR-10 data.')

  if is_training:
    return [
        os.path.join(data_dir, 'data_batch_%d.bin' % i)
        for i in range(1, _NUM_DATA_FILES + 1)
    ]
  return [os.path.join(data_dir, 'test_batch.bin')]
def input_fn(is_training,
             data_dir,
             batch_size,
             num_epochs=1,
             dtype=tf.float32,
             datasets_num_private_threads=None,
             parse_record_fn=parse_record,
             input_context=None,
             drop_remainder=False):
  """Input function which provides batches for train or eval.

  Args:
    is_training: A boolean denoting whether the input is for training.
    data_dir: The directory containing the input data.
    batch_size: The number of samples per batch.
    num_epochs: The number of epochs to repeat the dataset.
    dtype: Data type to use for images/features.
    datasets_num_private_threads: Number of private threads for tf.data.
    parse_record_fn: Function to use for parsing the records.
    input_context: A `tf.distribute.InputContext` object passed in by
      `tf.distribute.Strategy`.
    drop_remainder: A boolean indicates whether to drop the remainder of the
      batches. If True, the batch dimension will be static.

  Returns:
    A dataset that can be used for iteration.
  """
  record_files = get_filenames(is_training, data_dir)
  raw_dataset = tf.data.FixedLengthRecordDataset(record_files, _RECORD_BYTES)

  if input_context:
    # Under a multi-worker tf.distribute.Strategy, each input pipeline reads
    # only its own disjoint shard of the records.
    logging.info(
        'Sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
        input_context.input_pipeline_id, input_context.num_input_pipelines)
    raw_dataset = raw_dataset.shard(input_context.num_input_pipelines,
                                    input_context.input_pipeline_id)

  # Shared shuffle/repeat/map/batch pipeline lives in imagenet_preprocessing.
  return imagenet_preprocessing.process_record_dataset(
      dataset=raw_dataset,
      is_training=is_training,
      batch_size=batch_size,
      shuffle_buffer=NUM_IMAGES['train'],
      parse_record_fn=parse_record_fn,
      num_epochs=num_epochs,
      dtype=dtype,
      datasets_num_private_threads=datasets_num_private_threads,
      drop_remainder=drop_remainder)
This diff is collapsed.
...@@ -288,6 +288,17 @@ class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -288,6 +288,17 @@ class Resnet56KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.batch_size = 128 FLAGS.batch_size = 128
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_1_gpu_xla(self):
"""Test 1 gpu with xla enabled."""
self._setup()
FLAGS.num_gpus = 1
FLAGS.enable_eager = True
FLAGS.enable_xla = True
FLAGS.distribution_strategy = 'default'
FLAGS.model_dir = self._get_model_dir('benchmark_1_gpu_xla')
FLAGS.batch_size = 128
self._run_and_report_benchmark()
def benchmark_1_gpu_force_v2(self): def benchmark_1_gpu_force_v2(self):
"""Test 1 gpu using forced v2 execution path.""" """Test 1 gpu using forced v2 execution path."""
self._setup() self._setup()
......
...@@ -20,9 +20,9 @@ from __future__ import print_function ...@@ -20,9 +20,9 @@ from __future__ import print_function
from absl import app as absl_app from absl import app as absl_app
from absl import flags from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf
from official.resnet import cifar10_main as cifar_main from official.resnet.keras import cifar_preprocessing
from official.resnet.keras import keras_common from official.resnet.keras import keras_common
from official.resnet.keras import resnet_cifar_model from official.resnet.keras import resnet_cifar_model
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
...@@ -65,28 +65,6 @@ def learning_rate_schedule(current_epoch, ...@@ -65,28 +65,6 @@ def learning_rate_schedule(current_epoch,
return learning_rate return learning_rate
def parse_record_keras(raw_record, is_training, dtype):
"""Parses a record containing a training example of an image.
The input record is parsed into a label and image, and the image is passed
through preprocessing steps (cropping, flipping, and so on).
This method converts the label to one hot to fit the loss function.
Args:
raw_record: scalar Tensor tf.string containing a serialized
Example protocol buffer.
is_training: A boolean denoting whether the input is for training.
dtype: Data type to use for input images.
Returns:
Tuple with processed image tensor and one-hot-encoded label tensor.
"""
image, label = cifar_main.parse_record(raw_record, is_training, dtype)
label = tf.compat.v1.sparse_to_dense(label, (cifar_main.NUM_CLASSES,), 1)
return image, label
def run(flags_obj): def run(flags_obj):
"""Run ResNet Cifar-10 training and eval loop using native Keras APIs. """Run ResNet Cifar-10 training and eval loop using native Keras APIs.
...@@ -141,22 +119,22 @@ def run(flags_obj): ...@@ -141,22 +119,22 @@ def run(flags_obj):
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data() distribution_utils.set_up_synthetic_data()
input_fn = keras_common.get_synth_input_fn( input_fn = keras_common.get_synth_input_fn(
height=cifar_main.HEIGHT, height=cifar_preprocessing.HEIGHT,
width=cifar_main.WIDTH, width=cifar_preprocessing.WIDTH,
num_channels=cifar_main.NUM_CHANNELS, num_channels=cifar_preprocessing.NUM_CHANNELS,
num_classes=cifar_main.NUM_CLASSES, num_classes=cifar_preprocessing.NUM_CLASSES,
dtype=flags_core.get_tf_dtype(flags_obj), dtype=flags_core.get_tf_dtype(flags_obj),
drop_remainder=True) drop_remainder=True)
else: else:
distribution_utils.undo_set_up_synthetic_data() distribution_utils.undo_set_up_synthetic_data()
input_fn = cifar_main.input_fn input_fn = cifar_preprocessing.input_fn
train_input_dataset = input_fn( train_input_dataset = input_fn(
is_training=True, is_training=True,
data_dir=flags_obj.data_dir, data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size, batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs, num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras, parse_record_fn=cifar_preprocessing.parse_record,
datasets_num_private_threads=flags_obj.datasets_num_private_threads, datasets_num_private_threads=flags_obj.datasets_num_private_threads,
dtype=dtype, dtype=dtype,
# Setting drop_remainder to avoid the partial batch logic in normalization # Setting drop_remainder to avoid the partial batch logic in normalization
...@@ -171,11 +149,11 @@ def run(flags_obj): ...@@ -171,11 +149,11 @@ def run(flags_obj):
data_dir=flags_obj.data_dir, data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size, batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs, num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras) parse_record_fn=cifar_preprocessing.parse_record)
with strategy_scope: with strategy_scope:
optimizer = keras_common.get_optimizer() optimizer = keras_common.get_optimizer()
model = resnet_cifar_model.resnet56(classes=cifar_main.NUM_CLASSES) model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)
model.compile( model.compile(
loss='categorical_crossentropy', loss='categorical_crossentropy',
...@@ -186,16 +164,16 @@ def run(flags_obj): ...@@ -186,16 +164,16 @@ def run(flags_obj):
experimental_run_tf_function=flags_obj.force_v2_in_keras_compile) experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
callbacks = keras_common.get_callbacks( callbacks = keras_common.get_callbacks(
learning_rate_schedule, cifar_main.NUM_IMAGES['train']) learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train'])
train_steps = cifar_main.NUM_IMAGES['train'] // flags_obj.batch_size train_steps = cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size
train_epochs = flags_obj.train_epochs train_epochs = flags_obj.train_epochs
if flags_obj.train_steps: if flags_obj.train_steps:
train_steps = min(flags_obj.train_steps, train_steps) train_steps = min(flags_obj.train_steps, train_steps)
train_epochs = 1 train_epochs = 1
num_eval_steps = (cifar_main.NUM_IMAGES['validation'] // num_eval_steps = (cifar_preprocessing.NUM_IMAGES['validation'] //
flags_obj.batch_size) flags_obj.batch_size)
validation_data = eval_input_dataset validation_data = eval_input_dataset
......
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
from tempfile import mkdtemp from tempfile import mkdtemp
import tensorflow as tf import tensorflow as tf
from official.resnet import cifar10_main from official.resnet.keras import cifar_preprocessing
from official.resnet.keras import keras_cifar_main from official.resnet.keras import keras_cifar_main
from official.resnet.keras import keras_common from official.resnet.keras import keras_common
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
...@@ -53,7 +53,7 @@ class KerasCifarTest(googletest.TestCase): ...@@ -53,7 +53,7 @@ class KerasCifarTest(googletest.TestCase):
def setUp(self): def setUp(self):
super(KerasCifarTest, self).setUp() super(KerasCifarTest, self).setUp()
cifar10_main.NUM_IMAGES["validation"] = 4 cifar_preprocessing.NUM_IMAGES["validation"] = 4
def tearDown(self): def tearDown(self):
super(KerasCifarTest, self).tearDown() super(KerasCifarTest, self).tearDown()
......
...@@ -20,9 +20,10 @@ from __future__ import print_function ...@@ -20,9 +20,10 @@ from __future__ import print_function
from absl import app as absl_app from absl import app as absl_app
from absl import flags from absl import flags
from absl import logging
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from official.resnet import imagenet_main from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_common from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model from official.resnet.keras import resnet_model
from official.resnet.keras import trivial_model from official.resnet.keras import trivial_model
...@@ -70,17 +71,6 @@ def learning_rate_schedule(current_epoch, ...@@ -70,17 +71,6 @@ def learning_rate_schedule(current_epoch,
return learning_rate return learning_rate
def parse_record_keras(raw_record, is_training, dtype):
"""Adjust the shape of label."""
image, label = imagenet_main.parse_record(raw_record, is_training, dtype)
# Subtract one so that labels are in [0, 1000), and cast to float32 for
# Keras model.
label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
dtype=tf.float32)
return image, label
def run(flags_obj): def run(flags_obj):
"""Run ResNet ImageNet training and eval loop using native Keras APIs. """Run ResNet ImageNet training and eval loop using native Keras APIs.
...@@ -138,15 +128,15 @@ def run(flags_obj): ...@@ -138,15 +128,15 @@ def run(flags_obj):
if flags_obj.use_synthetic_data: if flags_obj.use_synthetic_data:
distribution_utils.set_up_synthetic_data() distribution_utils.set_up_synthetic_data()
input_fn = keras_common.get_synth_input_fn( input_fn = keras_common.get_synth_input_fn(
height=imagenet_main.DEFAULT_IMAGE_SIZE, height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
width=imagenet_main.DEFAULT_IMAGE_SIZE, width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_main.NUM_CHANNELS, num_channels=imagenet_preprocessing.NUM_CHANNELS,
num_classes=imagenet_main.NUM_CLASSES, num_classes=imagenet_preprocessing.NUM_CLASSES,
dtype=dtype, dtype=dtype,
drop_remainder=True) drop_remainder=True)
else: else:
distribution_utils.undo_set_up_synthetic_data() distribution_utils.undo_set_up_synthetic_data()
input_fn = imagenet_main.input_fn input_fn = imagenet_preprocessing.input_fn
# When `enable_xla` is True, we always drop the remainder of the batches # When `enable_xla` is True, we always drop the remainder of the batches
# in the dataset, as XLA-GPU doesn't support dynamic shapes. # in the dataset, as XLA-GPU doesn't support dynamic shapes.
...@@ -157,7 +147,7 @@ def run(flags_obj): ...@@ -157,7 +147,7 @@ def run(flags_obj):
data_dir=flags_obj.data_dir, data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size, batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs, num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras, parse_record_fn=imagenet_preprocessing.parse_record,
datasets_num_private_threads=flags_obj.datasets_num_private_threads, datasets_num_private_threads=flags_obj.datasets_num_private_threads,
dtype=dtype, dtype=dtype,
drop_remainder=drop_remainder, drop_remainder=drop_remainder,
...@@ -171,7 +161,7 @@ def run(flags_obj): ...@@ -171,7 +161,7 @@ def run(flags_obj):
data_dir=flags_obj.data_dir, data_dir=flags_obj.data_dir,
batch_size=flags_obj.batch_size, batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs, num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras, parse_record_fn=imagenet_preprocessing.parse_record,
dtype=dtype, dtype=dtype,
drop_remainder=drop_remainder) drop_remainder=drop_remainder)
...@@ -179,7 +169,7 @@ def run(flags_obj): ...@@ -179,7 +169,7 @@ def run(flags_obj):
if flags_obj.use_tensor_lr: if flags_obj.use_tensor_lr:
lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup( lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
batch_size=flags_obj.batch_size, batch_size=flags_obj.batch_size,
epoch_size=imagenet_main.NUM_IMAGES['train'], epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
warmup_epochs=LR_SCHEDULE[0][1], warmup_epochs=LR_SCHEDULE[0][1],
boundaries=list(p[1] for p in LR_SCHEDULE[1:]), boundaries=list(p[1] for p in LR_SCHEDULE[1:]),
multipliers=list(p[0] for p in LR_SCHEDULE), multipliers=list(p[0] for p in LR_SCHEDULE),
...@@ -195,11 +185,11 @@ def run(flags_obj): ...@@ -195,11 +185,11 @@ def run(flags_obj):
default_for_fp16=128)) default_for_fp16=128))
if flags_obj.use_trivial_model: if flags_obj.use_trivial_model:
model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES, dtype) model = trivial_model.trivial_model(
imagenet_preprocessing.NUM_CLASSES, dtype)
else: else:
model = resnet_model.resnet50( model = resnet_model.resnet50(
num_classes=imagenet_main.NUM_CLASSES, num_classes=imagenet_preprocessing.NUM_CLASSES, dtype=dtype)
dtype=dtype)
model.compile( model.compile(
loss='sparse_categorical_crossentropy', loss='sparse_categorical_crossentropy',
...@@ -210,17 +200,18 @@ def run(flags_obj): ...@@ -210,17 +200,18 @@ def run(flags_obj):
experimental_run_tf_function=flags_obj.force_v2_in_keras_compile) experimental_run_tf_function=flags_obj.force_v2_in_keras_compile)
callbacks = keras_common.get_callbacks( callbacks = keras_common.get_callbacks(
learning_rate_schedule, imagenet_main.NUM_IMAGES['train']) learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])
train_steps = imagenet_main.NUM_IMAGES['train'] // flags_obj.batch_size train_steps = (
imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
train_epochs = flags_obj.train_epochs train_epochs = flags_obj.train_epochs
if flags_obj.train_steps: if flags_obj.train_steps:
train_steps = min(flags_obj.train_steps, train_steps) train_steps = min(flags_obj.train_steps, train_steps)
train_epochs = 1 train_epochs = 1
num_eval_steps = (imagenet_main.NUM_IMAGES['validation'] // num_eval_steps = (
flags_obj.batch_size) imagenet_preprocessing.NUM_IMAGES['validation'] // flags_obj.batch_size)
validation_data = eval_input_dataset validation_data = eval_input_dataset
if flags_obj.skip_eval: if flags_obj.skip_eval:
...@@ -271,10 +262,10 @@ def main(_): ...@@ -271,10 +262,10 @@ def main(_):
model_helpers.apply_clean(flags.FLAGS) model_helpers.apply_clean(flags.FLAGS)
with logger.benchmark_context(flags.FLAGS): with logger.benchmark_context(flags.FLAGS):
stats = run(flags.FLAGS) stats = run(flags.FLAGS)
tf.compat.v1.logging.info('Run stats:\n%s' % stats) logging.info('Run stats:\n%s', stats)
if __name__ == '__main__': if __name__ == '__main__':
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) logging.set_verbosity(logging.INFO)
define_imagenet_keras_flags() define_imagenet_keras_flags()
absl_app.run(main) absl_app.run(main)
...@@ -21,7 +21,7 @@ from __future__ import print_function ...@@ -21,7 +21,7 @@ from __future__ import print_function
from tempfile import mkdtemp from tempfile import mkdtemp
import tensorflow as tf import tensorflow as tf
from official.resnet import imagenet_main from official.resnet.keras import imagenet_preprocessing
from official.resnet.keras import keras_imagenet_main from official.resnet.keras import keras_imagenet_main
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
from official.utils.testing import integration from official.utils.testing import integration
...@@ -52,7 +52,7 @@ class KerasImagenetTest(googletest.TestCase): ...@@ -52,7 +52,7 @@ class KerasImagenetTest(googletest.TestCase):
def setUp(self): def setUp(self):
super(KerasImagenetTest, self).setUp() super(KerasImagenetTest, self).setUp()
imagenet_main.NUM_IMAGES["validation"] = 4 imagenet_preprocessing.NUM_IMAGES["validation"] = 4
def tearDown(self): def tearDown(self):
super(KerasImagenetTest, self).tearDown() super(KerasImagenetTest, self).tearDown()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment