Commit 88253ce5 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 326286926
parent 52371ffe
@@ -30,7 +30,7 @@ from official.utils.misc import keras_utils

 FLAGS = flags.FLAGS
 BASE_LEARNING_RATE = 0.1  # This matches Jing's version.
 TRAIN_TOP_1 = 'training_accuracy_top_1'
-LR_SCHEDULE = [ # (multiplier, epoch to start) tuples
+LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
     (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
 ]
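For context, a hedged sketch of how (multiplier, start-epoch) tuples like these are typically applied on top of BASE_LEARNING_RATE; the linear batch-size scaling is an assumption here, and the real consumption happens elsewhere in this file:

# Illustrative only -- not part of the commit.
BASE_LEARNING_RATE = 0.1
LR_SCHEDULE = [(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)]

def learning_rate_for_epoch(epoch, batch_size=256):
  """Picks the multiplier of the last boundary already reached."""
  scaled = BASE_LEARNING_RATE * batch_size / 256  # assumed scaling rule
  lr = scaled
  for multiplier, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      lr = scaled * multiplier
  return lr

print(learning_rate_for_epoch(45))  # 0.1 * 0.1 -> ~0.01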
@@ -39,8 +39,14 @@ class PiecewiseConstantDecayWithWarmup(
     tf.keras.optimizers.schedules.LearningRateSchedule):
   """Piecewise constant decay with warmup schedule."""

-  def __init__(self, batch_size, epoch_size, warmup_epochs, boundaries,
-               multipliers, compute_lr_on_cpu=True, name=None):
+  def __init__(self,
+               batch_size,
+               epoch_size,
+               warmup_epochs,
+               boundaries,
+               multipliers,
+               compute_lr_on_cpu=True,
+               name=None):
     super(PiecewiseConstantDecayWithWarmup, self).__init__()
     if len(boundaries) != len(multipliers) - 1:
       raise ValueError('The length of boundaries must be 1 less than the '
@@ -77,14 +83,16 @@ class PiecewiseConstantDecayWithWarmup(
   def _get_learning_rate(self, step):
     """Compute learning rate at given step."""
     with tf.name_scope('PiecewiseConstantDecayWithWarmup'):

       def warmup_lr(step):
         return self.rescaled_lr * (
             tf.cast(step, tf.float32) / tf.cast(self.warmup_steps, tf.float32))

       def piecewise_lr(step):
-        return tf.compat.v1.train.piecewise_constant(
-            step, self.step_boundaries, self.lr_values)
-      return tf.cond(step < self.warmup_steps,
-                     lambda: warmup_lr(step),
+        return tf.compat.v1.train.piecewise_constant(step, self.step_boundaries,
+                                                     self.lr_values)
+
+      return tf.cond(step < self.warmup_steps, lambda: warmup_lr(step),
                      lambda: piecewise_lr(step))

   def get_config(self):
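For intuition, a minimal numeric sketch of the warmup branch above; rescaled_lr and warmup_steps are hypothetical stand-ins for the values __init__ derives from batch and epoch size:

# Illustrative only: linear ramp from 0 up to the rescaled base rate.
rescaled_lr = 0.1    # hypothetical
warmup_steps = 500   # hypothetical: 5 warmup epochs x 100 steps/epoch

def warmup_lr(step):
  return rescaled_lr * step / warmup_steps

print(warmup_lr(0))    # 0.0
print(warmup_lr(250))  # 0.05 -- halfway through warmup
print(warmup_lr(500))  # 0.1  -- hands off to the piecewise schedule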
@@ -104,10 +112,9 @@ def get_optimizer(learning_rate=0.1):
   return gradient_descent_v2.SGD(learning_rate=learning_rate, momentum=0.9)

-def get_callbacks(
-    pruning_method=None,
-    enable_checkpoint_and_export=False,
-    model_dir=None):
+def get_callbacks(pruning_method=None,
+                  enable_checkpoint_and_export=False,
+                  model_dir=None):
   """Returns common callbacks."""
   time_callback = keras_utils.TimeHistory(
       FLAGS.batch_size,
@@ -117,23 +124,23 @@ def get_callbacks(
   if FLAGS.enable_tensorboard:
     tensorboard_callback = tf.keras.callbacks.TensorBoard(
-        log_dir=FLAGS.model_dir,
-        profile_batch=FLAGS.profile_steps)
+        log_dir=FLAGS.model_dir, profile_batch=FLAGS.profile_steps)
     callbacks.append(tensorboard_callback)

   is_pruning_enabled = pruning_method is not None
   if is_pruning_enabled:
     callbacks.append(tfmot.sparsity.keras.UpdatePruningStep())
     if model_dir is not None:
-      callbacks.append(tfmot.sparsity.keras.PruningSummaries(
-          log_dir=model_dir, profile_batch=0))
+      callbacks.append(
+          tfmot.sparsity.keras.PruningSummaries(
+              log_dir=model_dir, profile_batch=0))

   if enable_checkpoint_and_export:
     if model_dir is not None:
       ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
       callbacks.append(
-          tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
-                                             save_weights_only=True))
+          tf.keras.callbacks.ModelCheckpoint(
+              ckpt_full_path, save_weights_only=True))
   return callbacks
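A hedged usage sketch of the callback helper above; `model` and `train_ds` are hypothetical stand-ins, and get_callbacks reads FLAGS, so absl flags must be parsed before this runs:

# Illustrative wiring only -- not the repo's training loop.
callbacks = get_callbacks(
    pruning_method=None,
    enable_checkpoint_and_export=True,
    model_dir='/tmp/resnet')  # hypothetical directory
model.fit(train_ds, epochs=90, callbacks=callbacks)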
@@ -182,28 +189,32 @@ def build_stats(history, eval_output, callbacks):
   return stats

-def define_keras_flags(
-    dynamic_loss_scale=True,
-    model=False,
-    optimizer=False,
-    pretrained_filepath=False):
+def define_keras_flags(dynamic_loss_scale=True,
+                       model=False,
+                       optimizer=False,
+                       pretrained_filepath=False):
   """Define flags for Keras models."""
-  flags_core.define_base(clean=True, num_gpu=True, run_eagerly=True,
-                         train_epochs=True, epochs_between_evals=True,
-                         distribution_strategy=True)
-  flags_core.define_performance(num_parallel_calls=False,
-                                synthetic_data=True,
-                                dtype=True,
-                                all_reduce_alg=True,
-                                num_packs=True,
-                                tf_gpu_thread_mode=True,
-                                datasets_num_private_threads=True,
-                                dynamic_loss_scale=dynamic_loss_scale,
-                                loss_scale=True,
-                                fp16_implementation=True,
-                                tf_data_experimental_slack=True,
-                                enable_xla=True,
-                                training_dataset_cache=True)
+  flags_core.define_base(
+      clean=True,
+      num_gpu=True,
+      run_eagerly=True,
+      train_epochs=True,
+      epochs_between_evals=True,
+      distribution_strategy=True)
+  flags_core.define_performance(
+      num_parallel_calls=False,
+      synthetic_data=True,
+      dtype=True,
+      all_reduce_alg=True,
+      num_packs=True,
+      tf_gpu_thread_mode=True,
+      datasets_num_private_threads=True,
+      dynamic_loss_scale=dynamic_loss_scale,
+      loss_scale=True,
+      fp16_implementation=True,
+      tf_data_experimental_slack=True,
+      enable_xla=True,
+      training_dataset_cache=True)
   flags_core.define_image()
   flags_core.define_benchmark()
   flags_core.define_distribution()
@@ -214,23 +225,33 @@ def define_keras_flags(
   # TODO(b/135607288): Remove this flag once we understand the root cause of
   # slowdown when setting the learning phase in Keras backend.
   flags.DEFINE_boolean(
-      name='set_learning_phase_to_train', default=True,
+      name='set_learning_phase_to_train',
+      default=True,
       help='If skip eval, also set Keras learning phase to 1 (training).')
   flags.DEFINE_boolean(
-      name='explicit_gpu_placement', default=False,
+      name='explicit_gpu_placement',
+      default=False,
       help='If not using distribution strategy, explicitly set device scope '
       'for the Keras training loop.')
-  flags.DEFINE_boolean(name='use_trivial_model', default=False,
-                       help='Whether to use a trivial Keras model.')
-  flags.DEFINE_boolean(name='report_accuracy_metrics', default=True,
-                       help='Report metrics during training and evaluation.')
-  flags.DEFINE_boolean(name='use_tensor_lr', default=True,
-                       help='Use learning rate tensor instead of a callback.')
   flags.DEFINE_boolean(
-      name='enable_tensorboard', default=False,
+      name='use_trivial_model',
+      default=False,
+      help='Whether to use a trivial Keras model.')
+  flags.DEFINE_boolean(
+      name='report_accuracy_metrics',
+      default=True,
+      help='Report metrics during training and evaluation.')
+  flags.DEFINE_boolean(
+      name='use_tensor_lr',
+      default=True,
+      help='Use learning rate tensor instead of a callback.')
+  flags.DEFINE_boolean(
+      name='enable_tensorboard',
+      default=False,
       help='Whether to enable Tensorboard callback.')
   flags.DEFINE_string(
-      name='profile_steps', default=None,
+      name='profile_steps',
+      default=None,
       help='Save profiling data to model dir at given range of global steps. The '
       'value must be a comma separated pair of positive integers, specifying '
       'the first and last step to profile. For example, "--profile_steps=2,4" '
@@ -238,21 +259,24 @@
       'Note that profiler has a non-trivial performance overhead, and the '
       'output file can be gigantic if profiling many steps.')
   flags.DEFINE_integer(
-      name='train_steps', default=None,
+      name='train_steps',
+      default=None,
       help='The number of steps to run for training. If it is larger than '
       '# batches per epoch, then use # batches per epoch. This flag will be '
       'ignored if train_epochs is set to be larger than 1. ')
   flags.DEFINE_boolean(
-      name='batchnorm_spatial_persistent', default=True,
+      name='batchnorm_spatial_persistent',
+      default=True,
       help='Enable the spacial persistent mode for CuDNN batch norm kernel.')
   flags.DEFINE_boolean(
-      name='enable_get_next_as_optional', default=False,
+      name='enable_get_next_as_optional',
+      default=False,
       help='Enable get_next_as_optional behavior in DistributedIterator.')
   flags.DEFINE_boolean(
-      name='enable_checkpoint_and_export', default=False,
+      name='enable_checkpoint_and_export',
+      default=False,
       help='Whether to enable a checkpoint callback and export the savedmodel.')
-  flags.DEFINE_string(
-      name='tpu', default='', help='TPU address to connect to.')
+  flags.DEFINE_string(name='tpu', default='', help='TPU address to connect to.')
   flags.DEFINE_integer(
       name='steps_per_loop',
       default=None,
@@ -270,20 +294,20 @@
     flags.DEFINE_string('model', 'resnet50_v1.5',
                         'Name of model preset. (mobilenet, resnet50_v1.5)')
   if optimizer:
-    flags.DEFINE_string('optimizer', 'resnet50_default',
-                        'Name of optimizer preset. '
-                        '(mobilenet_default, resnet50_default)')
+    flags.DEFINE_string(
+        'optimizer', 'resnet50_default', 'Name of optimizer preset. '
+        '(mobilenet_default, resnet50_default)')
     # TODO(kimjaehong): Replace as general hyper-params not only for mobilenet.
-    flags.DEFINE_float('initial_learning_rate_per_sample', 0.00007,
-                       'Initial value of learning rate per sample for '
-                       'mobilenet_default.')
+    flags.DEFINE_float(
+        'initial_learning_rate_per_sample', 0.00007,
+        'Initial value of learning rate per sample for '
+        'mobilenet_default.')
     flags.DEFINE_float('lr_decay_factor', 0.94,
                        'Learning rate decay factor for mobilenet_default.')
     flags.DEFINE_float('num_epochs_per_decay', 2.5,
                        'Number of epochs per decay for mobilenet_default.')
   if pretrained_filepath:
-    flags.DEFINE_string('pretrained_filepath', '',
-                        'Pretrained file path.')
+    flags.DEFINE_string('pretrained_filepath', '', 'Pretrained file path.')


 def get_synth_data(height, width, num_channels, num_classes, dtype):
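For readers less familiar with absl, a self-contained sketch of how flag definitions like those above are consumed at runtime (the flag mirrors enable_tensorboard; the script name is made up):

from absl import app
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_boolean(
    name='enable_tensorboard',
    default=False,
    help='Whether to enable Tensorboard callback.')

def main(_):
  # absl parses argv before calling main; values are then on FLAGS.
  print('TensorBoard enabled:', FLAGS.enable_tensorboard)

if __name__ == '__main__':
  app.run(main)  # e.g. python train.py --enable_tensorboard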
@@ -317,23 +341,24 @@ def get_synth_data(height, width, num_channels, num_classes, dtype):

 def define_pruning_flags():
   """Define flags for pruning methods."""
-  flags.DEFINE_string('pruning_method', None,
-                      'Pruning method.'
-                      'None (no pruning) or polynomial_decay.')
+  flags.DEFINE_string(
+      'pruning_method', None, 'Pruning method.'
+      'None (no pruning) or polynomial_decay.')
   flags.DEFINE_float('pruning_initial_sparsity', 0.0,
                      'Initial sparsity for pruning.')
   flags.DEFINE_float('pruning_final_sparsity', 0.5,
                      'Final sparsity for pruning.')
-  flags.DEFINE_integer('pruning_begin_step', 0,
-                       'Begin step for pruning.')
-  flags.DEFINE_integer('pruning_end_step', 100000,
-                       'End step for pruning.')
-  flags.DEFINE_integer('pruning_frequency', 100,
-                       'Frequency for pruning.')
+  flags.DEFINE_integer('pruning_begin_step', 0, 'Begin step for pruning.')
+  flags.DEFINE_integer('pruning_end_step', 100000, 'End step for pruning.')
+  flags.DEFINE_integer('pruning_frequency', 100, 'Frequency for pruning.')
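A hedged sketch of where these pruning flags typically land, assuming the tensorflow_model_optimization (tfmot) API used elsewhere in this file; the wiring is illustrative, not the repo's exact code:

import tensorflow_model_optimization as tfmot

# Illustrative: the flag defaults map onto a polynomial sparsity schedule.
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0,  # FLAGS.pruning_initial_sparsity
    final_sparsity=0.5,    # FLAGS.pruning_final_sparsity
    begin_step=0,          # FLAGS.pruning_begin_step
    end_step=100000,       # FLAGS.pruning_end_step
    frequency=100)         # FLAGS.pruning_frequency
# model = tfmot.sparsity.keras.prune_low_magnitude(
#     model, pruning_schedule=pruning_schedule)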
-def get_synth_input_fn(height, width, num_channels, num_classes,
-                       dtype=tf.float32, drop_remainder=True):
+def get_synth_input_fn(height,
+                       width,
+                       num_channels,
+                       num_classes,
+                       dtype=tf.float32,
+                       drop_remainder=True):
   """Returns an input function that returns a dataset with random data.

   This input_fn returns a data set that iterates over a set of random data and
@@ -355,14 +380,16 @@
     An input_fn that can be used in place of a real one to return a dataset
     that can be used for iteration.
   """

   # pylint: disable=unused-argument
   def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
     """Returns dataset filled with random data."""
-    inputs, labels = get_synth_data(height=height,
-                                    width=width,
-                                    num_channels=num_channels,
-                                    num_classes=num_classes,
-                                    dtype=dtype)
+    inputs, labels = get_synth_data(
+        height=height,
+        width=width,
+        num_channels=num_channels,
+        num_classes=num_classes,
+        dtype=dtype)
     # Cast to float32 for Keras model.
     labels = tf.cast(labels, dtype=tf.float32)
     data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
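A standalone sketch of the synthetic-input trick used above: materialize one random example, then repeat it forever so input processing never bottlenecks the benchmark (shapes are illustrative):

import tensorflow as tf

inputs = tf.random.truncated_normal([224, 224, 3], dtype=tf.float32)
labels = tf.random.uniform([1], minval=0, maxval=1000, dtype=tf.int32)
data = (tf.data.Dataset
        .from_tensors((inputs, tf.cast(labels, tf.float32)))
        .repeat()
        .batch(32, drop_remainder=True)
        .prefetch(tf.data.experimental.AUTOTUNE))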
......
@@ -36,6 +36,7 @@ from __future__ import division
 from __future__ import print_function

 import os

+from absl import logging
 import tensorflow as tf
@@ -78,17 +79,17 @@ def process_record_dataset(dataset,
     is_training: A boolean denoting whether the input is for training.
     batch_size: The number of samples per batch.
     shuffle_buffer: The buffer size to use when shuffling records. A larger
-      value results in better randomness, but smaller values reduce startup
-      time and use less memory.
+      value results in better randomness, but smaller values reduce startup time
+      and use less memory.
     parse_record_fn: A function that takes a raw record and returns the
       corresponding (image, label) pair.
     dtype: Data type to use for images/features.
-    datasets_num_private_threads: Number of threads for a private
-      threadpool created for all datasets computation.
+    datasets_num_private_threads: Number of threads for a private threadpool
+      created for all datasets computation.
     drop_remainder: A boolean indicates whether to drop the remainder of the
       batches. If True, the batch dimension will be static.
-    tf_data_experimental_slack: Whether to enable tf.data's
-      `experimental_slack` option.
+    tf_data_experimental_slack: Whether to enable tf.data's `experimental_slack`
+      option.

   Returns:
     Dataset of (image, label) pairs ready for iteration.
@@ -99,8 +100,8 @@
     options.experimental_threading.private_threadpool_size = (
         datasets_num_private_threads)
     dataset = dataset.with_options(options)
-    logging.info(
-        'datasets_num_private_threads: %s', datasets_num_private_threads)
+    logging.info('datasets_num_private_threads: %s',
+                 datasets_num_private_threads)

   if is_training:
     # Shuffles records before repeating to respect epoch boundaries.
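For context, the private-threadpool option shown above in a self-contained form (API names as in TF 2.x at the time of this commit):

import tensorflow as tf

dataset = tf.data.Dataset.range(10)
options = tf.data.Options()
# Give this pipeline its own threads instead of the shared global pool.
options.experimental_threading.private_threadpool_size = 8
dataset = dataset.with_options(options)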
@@ -134,11 +135,13 @@ def get_filenames(is_training, data_dir):
   if is_training:
     return [
         os.path.join(data_dir, 'train-%05d-of-01024' % i)
-        for i in range(_NUM_TRAIN_FILES)]
+        for i in range(_NUM_TRAIN_FILES)
+    ]
   else:
     return [
         os.path.join(data_dir, 'validation-%05d-of-00128' % i)
-        for i in range(128)]
+        for i in range(128)
+    ]


 def parse_example_proto(example_serialized):
@@ -165,8 +168,8 @@ def parse_example_proto(example_serialized):
     image/encoded: <JPEG encoded string>

   Args:
-    example_serialized: scalar Tensor tf.string containing a serialized
-      Example protocol buffer.
+    example_serialized: scalar Tensor tf.string containing a serialized Example
+      protocol buffer.

   Returns:
     image_buffer: Tensor tf.string containing the contents of a JPEG file.
@@ -177,22 +180,24 @@
   """
   # Dense features in Example proto.
   feature_map = {
-      'image/encoded': tf.io.FixedLenFeature([], dtype=tf.string,
-                                             default_value=''),
-      'image/class/label': tf.io.FixedLenFeature([], dtype=tf.int64,
-                                                 default_value=-1),
-      'image/class/text': tf.io.FixedLenFeature([], dtype=tf.string,
-                                                default_value=''),
+      'image/encoded':
+          tf.io.FixedLenFeature([], dtype=tf.string, default_value=''),
+      'image/class/label':
+          tf.io.FixedLenFeature([], dtype=tf.int64, default_value=-1),
+      'image/class/text':
+          tf.io.FixedLenFeature([], dtype=tf.string, default_value=''),
   }
   sparse_float32 = tf.io.VarLenFeature(dtype=tf.float32)
   # Sparse features in Example proto.
-  feature_map.update(
-      {k: sparse_float32 for k in [
+  feature_map.update({
+      k: sparse_float32 for k in [
           'image/object/bbox/xmin', 'image/object/bbox/ymin',
-          'image/object/bbox/xmax', 'image/object/bbox/ymax']})
+          'image/object/bbox/xmax', 'image/object/bbox/ymax'
+      ]
+  })

-  features = tf.io.parse_single_example(serialized=example_serialized,
-                                        features=feature_map)
+  features = tf.io.parse_single_example(
+      serialized=example_serialized, features=feature_map)
   label = tf.cast(features['image/class/label'], dtype=tf.int32)

   xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
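As a sanity check on the feature map, a small self-contained sketch that builds a fake Example and parses it with the same dense spec (values are made up):

import tensorflow as tf

example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b'fake-jpeg-bytes'])),
    'image/class/label': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[42])),
}))
feature_map = {
    'image/encoded':
        tf.io.FixedLenFeature([], dtype=tf.string, default_value=''),
    'image/class/label':
        tf.io.FixedLenFeature([], dtype=tf.int64, default_value=-1),
}
features = tf.io.parse_single_example(example.SerializeToString(), feature_map)
print(features['image/class/label'])  # tf.Tensor(42, shape=(), dtype=int64)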
@@ -218,8 +223,8 @@ def parse_record(raw_record, is_training, dtype):
   through preprocessing steps (cropping, flipping, and so on).

   Args:
-    raw_record: scalar Tensor tf.string containing a serialized
-      Example protocol buffer.
+    raw_record: scalar Tensor tf.string containing a serialized Example protocol
+      buffer.
     is_training: A boolean denoting whether the input is for training.
     dtype: data type to use for images/features.
@@ -240,8 +245,9 @@
   # Subtract one so that labels are in [0, 1000), and cast to float32 for
   # Keras model.
-  label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
-                  dtype=tf.float32)
+  label = tf.cast(
+      tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
+      dtype=tf.float32)

   return image, label
@@ -262,12 +268,14 @@ def get_parse_record_fn(use_keras_image_data_format=False):
   Returns:
     Function to use for parsing the records.
   """

   def parse_record_fn(raw_record, is_training, dtype):
     image, label = parse_record(raw_record, is_training, dtype)
     if use_keras_image_data_format:
       if tf.keras.backend.image_data_format() == 'channels_first':
         image = tf.transpose(image, perm=[2, 0, 1])
     return image, label

   return parse_record_fn
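The perm=[2, 0, 1] transpose above moves an HWC image to CHW for channels_first backends; a quick illustration:

import tensorflow as tf

image = tf.zeros([224, 224, 3])            # height, width, channels (HWC)
chw = tf.transpose(image, perm=[2, 0, 1])  # channels, height, width (CHW)
print(chw.shape)                           # (3, 224, 224)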
@@ -295,11 +303,11 @@ def input_fn(is_training,
       `tf.distribute.Strategy`.
     drop_remainder: A boolean indicates whether to drop the remainder of the
       batches. If True, the batch dimension will be static.
-    tf_data_experimental_slack: Whether to enable tf.data's
-      `experimental_slack` option.
+    tf_data_experimental_slack: Whether to enable tf.data's `experimental_slack`
+      option.
     training_dataset_cache: Whether to cache the training dataset on workers.
-       Typically used to improve training performance when training data is in
-       remote storage and can fit into worker memory.
+      Typically used to improve training performance when training data is in
+      remote storage and can fit into worker memory.
     filenames: Optional field for providing the file names of the TFRecords.

   Returns:
@@ -357,8 +365,8 @@ def _decode_crop_and_flip(image_buffer, bbox, num_channels):
   Args:
     image_buffer: scalar string Tensor representing the raw JPEG image buffer.
     bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
-      where each coordinate is [0, 1) and the coordinates are arranged as
-      [ymin, xmin, ymax, xmax].
+      where each coordinate is [0, 1) and the coordinates are arranged as [ymin,
+      xmin, ymax, xmax].
     num_channels: Integer depth of the image buffer for decoding.

   Returns:
@@ -414,8 +422,8 @@ def _central_crop(image, crop_height, crop_width):
   crop_top = amount_to_be_cropped_h // 2
   amount_to_be_cropped_w = (width - crop_width)
   crop_left = amount_to_be_cropped_w // 2
-  return tf.slice(
-      image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
+  return tf.slice(image, [crop_top, crop_left, 0],
+                  [crop_height, crop_width, -1])


 def _mean_image_subtraction(image, means, num_channels):
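A worked example of the crop arithmetic in _central_crop, with illustrative sizes: a 256x256 image cropped to 224x224 starts at offset (256 - 224) // 2 = 16 on each axis, and the -1 size keeps the full channel depth:

import tensorflow as tf

image = tf.zeros([256, 256, 3])
crop_height, crop_width = 224, 224
crop_top = (256 - crop_height) // 2   # 16
crop_left = (256 - crop_width) // 2   # 16
cropped = tf.slice(image, [crop_top, crop_left, 0],
                   [crop_height, crop_width, -1])
print(cropped.shape)  # (224, 224, 3)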
@@ -463,8 +471,8 @@ def _smallest_size_at_least(height, width, resize_min):
   Args:
     height: an int32 scalar tensor indicating the current height.
     width: an int32 scalar tensor indicating the current width.
-    resize_min: A python integer or scalar `Tensor` indicating the size of
-      the smallest side after resize.
+    resize_min: A python integer or scalar `Tensor` indicating the size of the
+      smallest side after resize.

   Returns:
     new_height: an int32 scalar tensor indicating the new height.
@@ -490,8 +498,8 @@ def _aspect_preserving_resize(image, resize_min):
   Args:
     image: A 3-D image `Tensor`.
-    resize_min: A python integer or scalar `Tensor` indicating the size of
-      the smallest side after resize.
+    resize_min: A python integer or scalar `Tensor` indicating the size of the
+      smallest side after resize.

   Returns:
     resized_image: A 3-D tensor containing the resized image.
@@ -520,12 +528,17 @@ def _resize_image(image, height, width):
     dimensions have the shape [height, width].
   """
   return tf.compat.v1.image.resize(
-      image, [height, width], method=tf.image.ResizeMethod.BILINEAR,
+      image, [height, width],
+      method=tf.image.ResizeMethod.BILINEAR,
       align_corners=False)

-def preprocess_image(image_buffer, bbox, output_height, output_width,
-                     num_channels, is_training=False):
+def preprocess_image(image_buffer,
+                     bbox,
+                     output_height,
+                     output_width,
+                     num_channels,
+                     is_training=False):
   """Preprocesses the given image.

   Preprocessing includes decoding, cropping, and resizing for both training
@@ -535,8 +548,8 @@ def preprocess_image(image_buffer, bbox, output_height, output_width,
   Args:
     image_buffer: scalar string Tensor representing the raw JPEG image buffer.
     bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
-      where each coordinate is [0, 1) and the coordinates are arranged as
-      [ymin, xmin, ymax, xmax].
+      where each coordinate is [0, 1) and the coordinates are arranged as [ymin,
+      xmin, ymax, xmax].
     output_height: The height of the image after preprocessing.
     output_width: The width of the image after preprocessing.
     num_channels: Integer depth of the image buffer for decoding.
......
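Pulling the preprocessing pieces above together, a worked sketch of the usual eval-time path: aspect-preserving resize so the short side hits 256, then a 224x224 central crop (sizes are illustrative):

import tensorflow as tf

image = tf.zeros([480, 640, 3])  # height x width x channels
new_h, new_w = 256, int(640 * 256 / 480)  # short side -> 256, width -> 341
resized = tf.compat.v1.image.resize(
    image, [new_h, new_w],
    method=tf.image.ResizeMethod.BILINEAR,
    align_corners=False)
crop = tf.slice(resized, [(new_h - 224) // 2, (new_w - 224) // 2, 0],
                [224, 224, -1])
print(crop.shape)  # (224, 224, 3)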
@@ -52,7 +52,4 @@ class ResNetModelConfig(base_configs.ModelConfig):
           boundaries=[30, 60, 80],
           warmup_epochs=5,
           scale_by_batch_size=1. / 256.,
-          multipliers=[0.1 / 256,
-                       0.01 / 256,
-                       0.001 / 256,
-                       0.0001 / 256]))
+          multipliers=[0.1 / 256, 0.01 / 256, 0.001 / 256, 0.0001 / 256]))
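A quick arithmetic check on these multipliers, under the reading that scale_by_batch_size=1/256 makes the effective rate multiplier * batch_size (an assumption from the field names, not verified against the trainer):

# Hedged sketch: for batch_size=256 the schedule would step through
# 0.1 -> 0.01 -> 0.001 -> 0.0001 at the epoch boundaries [30, 60, 80].
batch_size = 256
multipliers = [0.1 / 256, 0.01 / 256, 0.001 / 256, 0.0001 / 256]
print([m * batch_size for m in multipliers])  # [0.1, 0.01, 0.001, 0.0001]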
@@ -17,6 +17,7 @@
import math
import os
# Import libraries
from absl import app
from absl import flags
from absl import logging
......
@@ -21,6 +21,7 @@ from __future__ import print_function
import os
# Import libraries
from absl import app
from absl import flags
......