".github/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "07ab4d4a2df2ca28cfa82197e6a692a0f69c1d02"
Commit 88253ce5 authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 326286926
parent 52371ffe
@@ -30,7 +30,7 @@ from official.utils.misc import keras_utils
 FLAGS = flags.FLAGS
 BASE_LEARNING_RATE = 0.1  # This matches Jing's version.
 TRAIN_TOP_1 = 'training_accuracy_top_1'
 LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
     (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
 ]
@@ -39,8 +39,14 @@ class PiecewiseConstantDecayWithWarmup(
     tf.keras.optimizers.schedules.LearningRateSchedule):
   """Piecewise constant decay with warmup schedule."""
 
-  def __init__(self, batch_size, epoch_size, warmup_epochs, boundaries,
-               multipliers, compute_lr_on_cpu=True, name=None):
+  def __init__(self,
+               batch_size,
+               epoch_size,
+               warmup_epochs,
+               boundaries,
+               multipliers,
+               compute_lr_on_cpu=True,
+               name=None):
     super(PiecewiseConstantDecayWithWarmup, self).__init__()
     if len(boundaries) != len(multipliers) - 1:
       raise ValueError('The length of boundaries must be 1 less than the '
@@ -77,14 +83,16 @@ class PiecewiseConstantDecayWithWarmup(
   def _get_learning_rate(self, step):
     """Compute learning rate at given step."""
     with tf.name_scope('PiecewiseConstantDecayWithWarmup'):
+
       def warmup_lr(step):
         return self.rescaled_lr * (
             tf.cast(step, tf.float32) / tf.cast(self.warmup_steps, tf.float32))
+
       def piecewise_lr(step):
-        return tf.compat.v1.train.piecewise_constant(
-            step, self.step_boundaries, self.lr_values)
-      return tf.cond(step < self.warmup_steps,
-                     lambda: warmup_lr(step),
+        return tf.compat.v1.train.piecewise_constant(step, self.step_boundaries,
+                                                     self.lr_values)
+
+      return tf.cond(step < self.warmup_steps, lambda: warmup_lr(step),
                      lambda: piecewise_lr(step))
 
   def get_config(self):
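Note: the hunk above only reformats the schedule's lookup; the fields it reads (rescaled_lr, warmup_steps, step_boundaries, lr_values) are computed in code outside this hunk. As a minimal sketch of the overall math, assuming rescaled_lr scales BASE_LEARNING_RATE linearly with batch size and that warmup and boundaries are derived from LR_SCHEDULE:

# Sketch only -- the real derivations live in the __init__ not shown here.
BASE_LEARNING_RATE = 0.1
LR_SCHEDULE = [(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)]

def sketch_learning_rate(step, batch_size=256, epoch_size=1281167):
  rescaled_lr = BASE_LEARNING_RATE * batch_size / 256.0  # assumed scaling rule
  steps_per_epoch = epoch_size // batch_size
  warmup_steps = LR_SCHEDULE[0][1] * steps_per_epoch
  if step < warmup_steps:
    return rescaled_lr * step / warmup_steps  # linear warmup, as in warmup_lr()
  lr = rescaled_lr * LR_SCHEDULE[0][0]
  for multiplier, start_epoch in LR_SCHEDULE[1:]:
    if step >= start_epoch * steps_per_epoch:  # decay, as in piecewise_lr()
      lr = rescaled_lr * multiplier
  return lr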
@@ -104,10 +112,9 @@ def get_optimizer(learning_rate=0.1):
   return gradient_descent_v2.SGD(learning_rate=learning_rate, momentum=0.9)
 
 
-def get_callbacks(
-    pruning_method=None,
-    enable_checkpoint_and_export=False,
-    model_dir=None):
+def get_callbacks(pruning_method=None,
+                  enable_checkpoint_and_export=False,
+                  model_dir=None):
   """Returns common callbacks."""
   time_callback = keras_utils.TimeHistory(
       FLAGS.batch_size,
@@ -117,23 +124,23 @@ def get_callbacks(
   if FLAGS.enable_tensorboard:
     tensorboard_callback = tf.keras.callbacks.TensorBoard(
-        log_dir=FLAGS.model_dir,
-        profile_batch=FLAGS.profile_steps)
+        log_dir=FLAGS.model_dir, profile_batch=FLAGS.profile_steps)
     callbacks.append(tensorboard_callback)
 
   is_pruning_enabled = pruning_method is not None
   if is_pruning_enabled:
     callbacks.append(tfmot.sparsity.keras.UpdatePruningStep())
     if model_dir is not None:
-      callbacks.append(tfmot.sparsity.keras.PruningSummaries(
-          log_dir=model_dir, profile_batch=0))
+      callbacks.append(
+          tfmot.sparsity.keras.PruningSummaries(
+              log_dir=model_dir, profile_batch=0))
 
   if enable_checkpoint_and_export:
     if model_dir is not None:
       ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
       callbacks.append(
-          tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
-                                             save_weights_only=True))
+          tf.keras.callbacks.ModelCheckpoint(
+              ckpt_full_path, save_weights_only=True))
   return callbacks
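A typical call site for get_callbacks, sketched under assumptions: model and train_ds are placeholder names (not part of this change), and the FLAGS consulted inside (batch_size, enable_tensorboard, model_dir) must already be parsed.

# Hypothetical wiring of the callbacks into Keras training.
callbacks = get_callbacks(
    pruning_method=None,                 # or 'polynomial_decay' to prune
    enable_checkpoint_and_export=True,
    model_dir='/tmp/resnet')
model.fit(train_ds, epochs=90, callbacks=callbacks)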
@@ -182,28 +189,32 @@ def build_stats(history, eval_output, callbacks):
   return stats
 
 
-def define_keras_flags(
-    dynamic_loss_scale=True,
-    model=False,
-    optimizer=False,
-    pretrained_filepath=False):
+def define_keras_flags(dynamic_loss_scale=True,
+                       model=False,
+                       optimizer=False,
+                       pretrained_filepath=False):
   """Define flags for Keras models."""
-  flags_core.define_base(clean=True, num_gpu=True, run_eagerly=True,
-                         train_epochs=True, epochs_between_evals=True,
-                         distribution_strategy=True)
-  flags_core.define_performance(num_parallel_calls=False,
-                                synthetic_data=True,
-                                dtype=True,
-                                all_reduce_alg=True,
-                                num_packs=True,
-                                tf_gpu_thread_mode=True,
-                                datasets_num_private_threads=True,
-                                dynamic_loss_scale=dynamic_loss_scale,
-                                loss_scale=True,
-                                fp16_implementation=True,
-                                tf_data_experimental_slack=True,
-                                enable_xla=True,
-                                training_dataset_cache=True)
+  flags_core.define_base(
+      clean=True,
+      num_gpu=True,
+      run_eagerly=True,
+      train_epochs=True,
+      epochs_between_evals=True,
+      distribution_strategy=True)
+  flags_core.define_performance(
+      num_parallel_calls=False,
+      synthetic_data=True,
+      dtype=True,
+      all_reduce_alg=True,
+      num_packs=True,
+      tf_gpu_thread_mode=True,
+      datasets_num_private_threads=True,
+      dynamic_loss_scale=dynamic_loss_scale,
+      loss_scale=True,
+      fp16_implementation=True,
+      tf_data_experimental_slack=True,
+      enable_xla=True,
+      training_dataset_cache=True)
   flags_core.define_image()
   flags_core.define_benchmark()
   flags_core.define_distribution()
@@ -214,23 +225,33 @@ def define_keras_flags(
   # TODO(b/135607288): Remove this flag once we understand the root cause of
   # slowdown when setting the learning phase in Keras backend.
   flags.DEFINE_boolean(
-      name='set_learning_phase_to_train', default=True,
+      name='set_learning_phase_to_train',
+      default=True,
       help='If skip eval, also set Keras learning phase to 1 (training).')
   flags.DEFINE_boolean(
-      name='explicit_gpu_placement', default=False,
+      name='explicit_gpu_placement',
+      default=False,
       help='If not using distribution strategy, explicitly set device scope '
       'for the Keras training loop.')
-  flags.DEFINE_boolean(name='use_trivial_model', default=False,
-                       help='Whether to use a trivial Keras model.')
-  flags.DEFINE_boolean(name='report_accuracy_metrics', default=True,
-                       help='Report metrics during training and evaluation.')
-  flags.DEFINE_boolean(name='use_tensor_lr', default=True,
-                       help='Use learning rate tensor instead of a callback.')
   flags.DEFINE_boolean(
-      name='enable_tensorboard', default=False,
+      name='use_trivial_model',
+      default=False,
+      help='Whether to use a trivial Keras model.')
+  flags.DEFINE_boolean(
+      name='report_accuracy_metrics',
+      default=True,
+      help='Report metrics during training and evaluation.')
+  flags.DEFINE_boolean(
+      name='use_tensor_lr',
+      default=True,
+      help='Use learning rate tensor instead of a callback.')
+  flags.DEFINE_boolean(
+      name='enable_tensorboard',
+      default=False,
       help='Whether to enable Tensorboard callback.')
   flags.DEFINE_string(
-      name='profile_steps', default=None,
+      name='profile_steps',
+      default=None,
       help='Save profiling data to model dir at given range of global steps. The '
       'value must be a comma separated pair of positive integers, specifying '
       'the first and last step to profile. For example, "--profile_steps=2,4" '
@@ -238,21 +259,24 @@ def define_keras_flags(
       'Note that profiler has a non-trivial performance overhead, and the '
       'output file can be gigantic if profiling many steps.')
   flags.DEFINE_integer(
-      name='train_steps', default=None,
+      name='train_steps',
+      default=None,
       help='The number of steps to run for training. If it is larger than '
       '# batches per epoch, then use # batches per epoch. This flag will be '
       'ignored if train_epochs is set to be larger than 1. ')
   flags.DEFINE_boolean(
-      name='batchnorm_spatial_persistent', default=True,
+      name='batchnorm_spatial_persistent',
+      default=True,
       help='Enable the spacial persistent mode for CuDNN batch norm kernel.')
   flags.DEFINE_boolean(
-      name='enable_get_next_as_optional', default=False,
+      name='enable_get_next_as_optional',
+      default=False,
       help='Enable get_next_as_optional behavior in DistributedIterator.')
   flags.DEFINE_boolean(
-      name='enable_checkpoint_and_export', default=False,
+      name='enable_checkpoint_and_export',
+      default=False,
       help='Whether to enable a checkpoint callback and export the savedmodel.')
-  flags.DEFINE_string(
-      name='tpu', default='', help='TPU address to connect to.')
+  flags.DEFINE_string(name='tpu', default='', help='TPU address to connect to.')
   flags.DEFINE_integer(
       name='steps_per_loop',
       default=None,
@@ -270,20 +294,20 @@ def define_keras_flags(
     flags.DEFINE_string('model', 'resnet50_v1.5',
                         'Name of model preset. (mobilenet, resnet50_v1.5)')
   if optimizer:
-    flags.DEFINE_string('optimizer', 'resnet50_default',
-                        'Name of optimizer preset. '
-                        '(mobilenet_default, resnet50_default)')
+    flags.DEFINE_string(
+        'optimizer', 'resnet50_default', 'Name of optimizer preset. '
+        '(mobilenet_default, resnet50_default)')
     # TODO(kimjaehong): Replace as general hyper-params not only for mobilenet.
-    flags.DEFINE_float('initial_learning_rate_per_sample', 0.00007,
-                       'Initial value of learning rate per sample for '
-                       'mobilenet_default.')
+    flags.DEFINE_float(
+        'initial_learning_rate_per_sample', 0.00007,
+        'Initial value of learning rate per sample for '
+        'mobilenet_default.')
     flags.DEFINE_float('lr_decay_factor', 0.94,
                        'Learning rate decay factor for mobilenet_default.')
     flags.DEFINE_float('num_epochs_per_decay', 2.5,
                        'Number of epochs per decay for mobilenet_default.')
   if pretrained_filepath:
-    flags.DEFINE_string('pretrained_filepath', '',
-                        'Pretrained file path.')
+    flags.DEFINE_string('pretrained_filepath', '', 'Pretrained file path.')
 
 
 def get_synth_data(height, width, num_channels, num_classes, dtype):
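These helpers only register absl flags; a separate driver consumes them. A minimal sketch of the standard absl pattern (the driver itself is not part of this commit):

from absl import app
from absl import flags

FLAGS = flags.FLAGS

def main(_):
  # Hypothetical: read a couple of the flags registered above.
  print('train_steps:', FLAGS.train_steps)
  print('enable_tensorboard:', FLAGS.enable_tensorboard)

if __name__ == '__main__':
  define_keras_flags(model=True, optimizer=True)
  app.run(main)  # e.g. --train_steps=100 --enable_tensorboard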
@@ -317,23 +341,24 @@ def get_synth_data(height, width, num_channels, num_classes, dtype):
 
 def define_pruning_flags():
   """Define flags for pruning methods."""
-  flags.DEFINE_string('pruning_method', None,
-                      'Pruning method.'
-                      'None (no pruning) or polynomial_decay.')
+  flags.DEFINE_string(
+      'pruning_method', None, 'Pruning method.'
+      'None (no pruning) or polynomial_decay.')
   flags.DEFINE_float('pruning_initial_sparsity', 0.0,
                      'Initial sparsity for pruning.')
   flags.DEFINE_float('pruning_final_sparsity', 0.5,
                      'Final sparsity for pruning.')
-  flags.DEFINE_integer('pruning_begin_step', 0,
-                       'Begin step for pruning.')
-  flags.DEFINE_integer('pruning_end_step', 100000,
-                       'End step for pruning.')
-  flags.DEFINE_integer('pruning_frequency', 100,
-                       'Frequency for pruning.')
+  flags.DEFINE_integer('pruning_begin_step', 0, 'Begin step for pruning.')
+  flags.DEFINE_integer('pruning_end_step', 100000, 'End step for pruning.')
+  flags.DEFINE_integer('pruning_frequency', 100, 'Frequency for pruning.')
 
 
-def get_synth_input_fn(height, width, num_channels, num_classes,
-                       dtype=tf.float32, drop_remainder=True):
+def get_synth_input_fn(height,
+                       width,
+                       num_channels,
+                       num_classes,
+                       dtype=tf.float32,
+                       drop_remainder=True):
   """Returns an input function that returns a dataset with random data.
 
   This input_fn returns a data set that iterates over a set of random data and
@@ -355,14 +380,16 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
     An input_fn that can be used in place of a real one to return a dataset
     that can be used for iteration.
   """
+
   # pylint: disable=unused-argument
   def input_fn(is_training, data_dir, batch_size, *args, **kwargs):
     """Returns dataset filled with random data."""
-    inputs, labels = get_synth_data(height=height,
-                                    width=width,
-                                    num_channels=num_channels,
-                                    num_classes=num_classes,
-                                    dtype=dtype)
+    inputs, labels = get_synth_data(
+        height=height,
+        width=width,
+        num_channels=num_channels,
+        num_classes=num_classes,
+        dtype=dtype)
     # Cast to float32 for Keras model.
     labels = tf.cast(labels, dtype=tf.float32)
     data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
...
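For orientation, a hedged sketch of how get_synth_input_fn above would be exercised; the batching step lives in the part of input_fn elided from this diff, so the shape in the comment is an assumption.

input_fn = get_synth_input_fn(height=224, width=224, num_channels=3,
                              num_classes=1000)
dataset = input_fn(is_training=True, data_dir=None, batch_size=32)
images, labels = next(iter(dataset))  # assumed: images (32, 224, 224, 3)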
@@ -36,6 +36,7 @@ from __future__ import division
 from __future__ import print_function
 
 import os
+
 from absl import logging
 import tensorflow as tf
@@ -78,17 +79,17 @@ def process_record_dataset(dataset,
     is_training: A boolean denoting whether the input is for training.
     batch_size: The number of samples per batch.
     shuffle_buffer: The buffer size to use when shuffling records. A larger
-      value results in better randomness, but smaller values reduce startup
-      time and use less memory.
+      value results in better randomness, but smaller values reduce startup time
+      and use less memory.
     parse_record_fn: A function that takes a raw record and returns the
       corresponding (image, label) pair.
     dtype: Data type to use for images/features.
-    datasets_num_private_threads: Number of threads for a private
-      threadpool created for all datasets computation.
+    datasets_num_private_threads: Number of threads for a private threadpool
+      created for all datasets computation.
     drop_remainder: A boolean indicates whether to drop the remainder of the
       batches. If True, the batch dimension will be static.
-    tf_data_experimental_slack: Whether to enable tf.data's
-      `experimental_slack` option.
+    tf_data_experimental_slack: Whether to enable tf.data's `experimental_slack`
+      option.
 
   Returns:
     Dataset of (image, label) pairs ready for iteration.
@@ -99,8 +100,8 @@ def process_record_dataset(dataset,
     options.experimental_threading.private_threadpool_size = (
         datasets_num_private_threads)
     dataset = dataset.with_options(options)
-    logging.info(
-        'datasets_num_private_threads: %s', datasets_num_private_threads)
+    logging.info('datasets_num_private_threads: %s',
+                 datasets_num_private_threads)
 
   if is_training:
     # Shuffles records before repeating to respect epoch boundaries.
@@ -134,11 +135,13 @@ def get_filenames(is_training, data_dir):
   if is_training:
     return [
         os.path.join(data_dir, 'train-%05d-of-01024' % i)
-        for i in range(_NUM_TRAIN_FILES)]
+        for i in range(_NUM_TRAIN_FILES)
+    ]
   else:
     return [
         os.path.join(data_dir, 'validation-%05d-of-00128' % i)
-        for i in range(128)]
+        for i in range(128)
+    ]
 
 
 def parse_example_proto(example_serialized):
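As a quick check of the shard naming above (assuming _NUM_TRAIN_FILES is 1024, consistent with the 'of-01024' suffix):

get_filenames(is_training=True, data_dir='/data/imagenet')[0]
# -> '/data/imagenet/train-00000-of-01024'
get_filenames(is_training=False, data_dir='/data/imagenet')[127]
# -> '/data/imagenet/validation-00127-of-00128'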
@@ -165,8 +168,8 @@ def parse_example_proto(example_serialized):
     image/encoded: <JPEG encoded string>
 
   Args:
-    example_serialized: scalar Tensor tf.string containing a serialized
-      Example protocol buffer.
+    example_serialized: scalar Tensor tf.string containing a serialized Example
+      protocol buffer.
 
   Returns:
     image_buffer: Tensor tf.string containing the contents of a JPEG file.
@@ -177,22 +180,24 @@ def parse_example_proto(example_serialized):
   """
   # Dense features in Example proto.
   feature_map = {
-      'image/encoded': tf.io.FixedLenFeature([], dtype=tf.string,
-                                             default_value=''),
-      'image/class/label': tf.io.FixedLenFeature([], dtype=tf.int64,
-                                                 default_value=-1),
-      'image/class/text': tf.io.FixedLenFeature([], dtype=tf.string,
-                                                default_value=''),
+      'image/encoded':
+          tf.io.FixedLenFeature([], dtype=tf.string, default_value=''),
+      'image/class/label':
+          tf.io.FixedLenFeature([], dtype=tf.int64, default_value=-1),
+      'image/class/text':
+          tf.io.FixedLenFeature([], dtype=tf.string, default_value=''),
   }
   sparse_float32 = tf.io.VarLenFeature(dtype=tf.float32)
   # Sparse features in Example proto.
-  feature_map.update(
-      {k: sparse_float32 for k in [
-          'image/object/bbox/xmin', 'image/object/bbox/ymin',
-          'image/object/bbox/xmax', 'image/object/bbox/ymax']})
+  feature_map.update({
+      k: sparse_float32 for k in [
+          'image/object/bbox/xmin', 'image/object/bbox/ymin',
+          'image/object/bbox/xmax', 'image/object/bbox/ymax'
+      ]
+  })
 
-  features = tf.io.parse_single_example(serialized=example_serialized,
-                                        features=feature_map)
+  features = tf.io.parse_single_example(
+      serialized=example_serialized, features=feature_map)
   label = tf.cast(features['image/class/label'], dtype=tf.int32)
 
   xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
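For testing, a serialized Example matching feature_map can be built by hand with the standard tf.train.Example API; this is a sketch, and the bbox keys may be omitted since VarLen features parse as empty.

example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b'<jpeg bytes>'])),
    'image/class/label': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[7])),
    'image/class/text': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[b'tabby cat'])),
}))
parsed = parse_example_proto(tf.constant(example.SerializeToString()))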
@@ -218,8 +223,8 @@ def parse_record(raw_record, is_training, dtype):
   through preprocessing steps (cropping, flipping, and so on).
 
   Args:
-    raw_record: scalar Tensor tf.string containing a serialized
-      Example protocol buffer.
+    raw_record: scalar Tensor tf.string containing a serialized Example protocol
+      buffer.
     is_training: A boolean denoting whether the input is for training.
     dtype: data type to use for images/features.
 
@@ -240,8 +245,9 @@ def parse_record(raw_record, is_training, dtype):
   # Subtract one so that labels are in [0, 1000), and cast to float32 for
   # Keras model.
-  label = tf.cast(tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
-                  dtype=tf.float32)
+  label = tf.cast(
+      tf.cast(tf.reshape(label, shape=[1]), dtype=tf.int32) - 1,
+      dtype=tf.float32)
 
   return image, label
@@ -262,12 +268,14 @@ def get_parse_record_fn(use_keras_image_data_format=False):
   Returns:
     Function to use for parsing the records.
   """
+
   def parse_record_fn(raw_record, is_training, dtype):
     image, label = parse_record(raw_record, is_training, dtype)
     if use_keras_image_data_format:
       if tf.keras.backend.image_data_format() == 'channels_first':
         image = tf.transpose(image, perm=[2, 0, 1])
     return image, label
 
   return parse_record_fn
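The perm=[2, 0, 1] transpose converts a height-width-channels image to channels-first layout:

image = tf.zeros([224, 224, 3])            # HWC
tf.transpose(image, perm=[2, 0, 1]).shape  # TensorShape([3, 224, 224]), CHW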
@@ -295,11 +303,11 @@ def input_fn(is_training,
       `tf.distribute.Strategy`.
     drop_remainder: A boolean indicates whether to drop the remainder of the
       batches. If True, the batch dimension will be static.
-    tf_data_experimental_slack: Whether to enable tf.data's
-      `experimental_slack` option.
+    tf_data_experimental_slack: Whether to enable tf.data's `experimental_slack`
+      option.
     training_dataset_cache: Whether to cache the training dataset on workers.
       Typically used to improve training performance when training data is in
       remote storage and can fit into worker memory.
     filenames: Optional field for providing the file names of the TFRecords.
 
   Returns:
@@ -357,8 +365,8 @@ def _decode_crop_and_flip(image_buffer, bbox, num_channels):
   Args:
     image_buffer: scalar string Tensor representing the raw JPEG image buffer.
     bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
-      where each coordinate is [0, 1) and the coordinates are arranged as
-      [ymin, xmin, ymax, xmax].
+      where each coordinate is [0, 1) and the coordinates are arranged as [ymin,
+      xmin, ymax, xmax].
     num_channels: Integer depth of the image buffer for decoding.
 
   Returns:
@@ -414,8 +422,8 @@ def _central_crop(image, crop_height, crop_width):
   crop_top = amount_to_be_cropped_h // 2
   amount_to_be_cropped_w = (width - crop_width)
   crop_left = amount_to_be_cropped_w // 2
-  return tf.slice(
-      image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
+  return tf.slice(image, [crop_top, crop_left, 0],
+                  [crop_height, crop_width, -1])
 
 
 def _mean_image_subtraction(image, means, num_channels):
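The crop offsets in _central_crop are plain integer arithmetic; a worked example:

height, width, crop_height, crop_width = 256, 256, 224, 224
crop_top = (height - crop_height) // 2   # 16
crop_left = (width - crop_width) // 2    # 16
# tf.slice then keeps rows 16..239 and columns 16..239.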
@@ -463,8 +471,8 @@ def _smallest_size_at_least(height, width, resize_min):
   Args:
     height: an int32 scalar tensor indicating the current height.
     width: an int32 scalar tensor indicating the current width.
-    resize_min: A python integer or scalar `Tensor` indicating the size of
-      the smallest side after resize.
+    resize_min: A python integer or scalar `Tensor` indicating the size of the
+      smallest side after resize.
 
   Returns:
     new_height: an int32 scalar tensor indicating the new height.
@@ -490,8 +498,8 @@ def _aspect_preserving_resize(image, resize_min):
   Args:
     image: A 3-D image `Tensor`.
-    resize_min: A python integer or scalar `Tensor` indicating the size of
-      the smallest side after resize.
+    resize_min: A python integer or scalar `Tensor` indicating the size of the
+      smallest side after resize.
 
   Returns:
     resized_image: A 3-D tensor containing the resized image.
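The two docstrings above pin down the contract (the smallest side becomes resize_min, aspect ratio preserved) but the function bodies are elided from this diff, so the following is an assumed reconstruction of the math, not the library code:

def sketch_smallest_size_at_least(height, width, resize_min):
  scale = resize_min / min(height, width)
  return int(round(height * scale)), int(round(width * scale))

sketch_smallest_size_at_least(480, 640, 256)  # -> (256, 341)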
@@ -520,12 +528,17 @@ def _resize_image(image, height, width):
     dimensions have the shape [height, width].
   """
   return tf.compat.v1.image.resize(
-      image, [height, width], method=tf.image.ResizeMethod.BILINEAR,
+      image, [height, width],
+      method=tf.image.ResizeMethod.BILINEAR,
       align_corners=False)
 
 
-def preprocess_image(image_buffer, bbox, output_height, output_width,
-                     num_channels, is_training=False):
+def preprocess_image(image_buffer,
+                     bbox,
+                     output_height,
+                     output_width,
+                     num_channels,
+                     is_training=False):
   """Preprocesses the given image.
 
   Preprocessing includes decoding, cropping, and resizing for both training
@@ -535,8 +548,8 @@ def preprocess_image(image_buffer, bbox, output_height, output_width,
   Args:
     image_buffer: scalar string Tensor representing the raw JPEG image buffer.
     bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
-      where each coordinate is [0, 1) and the coordinates are arranged as
-      [ymin, xmin, ymax, xmax].
+      where each coordinate is [0, 1) and the coordinates are arranged as [ymin,
+      xmin, ymax, xmax].
     output_height: The height of the image after preprocessing.
     output_width: The width of the image after preprocessing.
     num_channels: Integer depth of the image buffer for decoding.
...
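Putting the pieces together, a hypothetical call to preprocess_image; image_buffer and bbox stand for the outputs of parse_example_proto above:

image = preprocess_image(
    image_buffer=image_buffer,  # JPEG bytes, e.g. from parse_example_proto
    bbox=bbox,
    output_height=224,
    output_width=224,
    num_channels=3,
    is_training=True)  # True: random crop/flip; False: central crop (assumed)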
@@ -52,7 +52,4 @@ class ResNetModelConfig(base_configs.ModelConfig):
               boundaries=[30, 60, 80],
               warmup_epochs=5,
               scale_by_batch_size=1. / 256.,
-              multipliers=[0.1 / 256,
-                           0.01 / 256,
-                           0.001 / 256,
-                           0.0001 / 256]))
+              multipliers=[0.1 / 256, 0.01 / 256, 0.001 / 256, 0.0001 / 256]))
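Assuming scale_by_batch_size multiplies these per-256 multipliers by the global batch size (the consuming code is not shown in this diff), the effective initial learning rate works out as:

batch_size = 1024
initial_lr = (0.1 / 256) * batch_size  # 0.4 for the first phase (assumed)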
@@ -17,6 +17,7 @@
 import math
 import os
 
+# Import libraries
 from absl import app
 from absl import flags
 from absl import logging
...
@@ -21,6 +21,7 @@ from __future__ import print_function
 import os
 
+# Import libraries
 from absl import app
 from absl import flags
...