Commit e08b6286 authored by rxsang, committed by Toby Boyd

Ensure static shapes when enabling XLA in Resnet Keras model in graph mode. (#6558)

* Revert "Revert " Ensure static shapes when enabling XLA in Resnet Keras model (#6508)" (#6517)"

This reverts commit cc9eef76.

* Set `batch_size` on keras.Input in non-eager mode.

Eager mode currently has an OOM problem.

* Add comments for enable_eager flag.

* Always set drop_remainder=True.

* Only set drop_remainder=True for XLA.
parent 80dde852
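The message above describes two shape-pinning levers: dropping partial batches in the input pipeline and fixing the batch dimension of the Keras input layer. A minimal sketch of the second lever, assuming TF 2.x Keras (the batch size of 32 is illustrative, not a value from this commit):

```python
import tensorflow as tf

# With batch_size unset, the leading dimension of the input (and of every
# downstream activation) is None, i.e. dynamic.
dynamic = tf.keras.Input(shape=(224, 224, 3))
print(dynamic.shape)  # (None, 224, 224, 3)

# Passing batch_size pins the leading dimension, so the graph is traced with
# fully static shapes -- which is what XLA-GPU needs.
static = tf.keras.Input(shape=(224, 224, 3), batch_size=32)
print(static.shape)   # (32, 224, 224, 3)
```

The first lever, `drop_remainder`, is shown against a toy dataset after the first hunks below.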
@@ -167,7 +167,8 @@ def input_fn(is_training,
              datasets_num_private_threads=None,
              num_parallel_batches=1,
              parse_record_fn=parse_record,
-             input_context=None):
+             input_context=None,
+             drop_remainder=False):
   """Input function which provides batches for train or eval.
   Args:
@@ -181,6 +182,8 @@ def input_fn(is_training,
     parse_record_fn: Function to use for parsing the records.
     input_context: A `tf.distribute.InputContext` object passed in by
       `tf.distribute.Strategy`.
+    drop_remainder: A boolean indicating whether to drop the remainder of the
+      batches. If True, the batch dimension will be static.
   Returns:
     A dataset that can be used for iteration.
@@ -217,7 +220,8 @@ def input_fn(is_training,
       num_epochs=num_epochs,
       dtype=dtype,
       datasets_num_private_threads=datasets_num_private_threads,
-      num_parallel_batches=num_parallel_batches
+      num_parallel_batches=num_parallel_batches,
+      drop_remainder=drop_remainder
   )
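As the new docstring states, `drop_remainder=True` is what makes the batch dimension static. A toy check, assuming TF 2.x (not code from this repo):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)

# Without drop_remainder the final batch may be smaller than 4, so the batch
# dimension is left unknown (None) at graph-construction time.
print(ds.batch(4, drop_remainder=False).element_spec.shape)  # (None,)

# With drop_remainder=True incomplete batches are discarded and the batch
# dimension becomes a static 4.
print(ds.batch(4, drop_remainder=True).element_spec.shape)   # (4,)
```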
@@ -286,7 +286,7 @@ def define_keras_flags():
 def get_synth_input_fn(height, width, num_channels, num_classes,
-                       dtype=tf.float32):
+                       dtype=tf.float32, drop_remainder=True):
   """Returns an input function that returns a dataset with random data.
   This input_fn returns a data set that iterates over a set of random data and
@@ -301,6 +301,8 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
     num_classes: Number of classes that should be represented in the fake labels
       tensor
     dtype: Data type for features/images.
+    drop_remainder: A boolean indicating whether to drop the remainder of the
+      batches. If True, the batch dimension will be static.
   Returns:
     An input_fn that can be used in place of a real one to return a dataset
@@ -327,7 +329,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
     data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
     # `drop_remainder` will make dataset produce outputs with known shapes.
-    data = data.batch(batch_size, drop_remainder=True)
+    data = data.batch(batch_size, drop_remainder=drop_remainder)
     data = data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
     return data
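The synthetic path follows the same from_tensors/repeat/batch pattern. A self-contained toy stand-in (illustrative shapes and constant data, not the repo's helper) shows the fully known output shapes it produces:

```python
import tensorflow as tf

def toy_synth_input_fn(batch_size, height=224, width=224, num_channels=3,
                       drop_remainder=True):
  """Simplified stand-in for get_synth_input_fn above; constant data only."""
  inputs = tf.zeros([height, width, num_channels], tf.float32)
  labels = tf.zeros([], tf.int32)
  data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
  # With drop_remainder=True both output shapes are fully known.
  data = data.batch(batch_size, drop_remainder=drop_remainder)
  return data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

for images, labels in toy_synth_input_fn(8).take(1):
  print(images.shape, labels.shape)  # (8, 224, 224, 3) (8,)
```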
@@ -134,11 +134,16 @@ def run(flags_obj):
         width=imagenet_main.DEFAULT_IMAGE_SIZE,
         num_channels=imagenet_main.NUM_CHANNELS,
         num_classes=imagenet_main.NUM_CLASSES,
-        dtype=dtype)
+        dtype=dtype,
+        drop_remainder=True)
   else:
     distribution_utils.undo_set_up_synthetic_data()
     input_fn = imagenet_main.input_fn
+  # When `enable_xla` is True, we always drop the remainder of the batches
+  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
+  drop_remainder = flags_obj.enable_xla
   train_input_dataset = input_fn(
       is_training=True,
       data_dir=flags_obj.data_dir,
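For reference on the comment above: `enable_xla` requests XLA JIT compilation, and XLA-GPU requires every tensor shape to be known at compile time. The flag is wired up elsewhere in this repo; one common way to request the same thing directly in TF 2.x (an assumption about the API, not this commit's code) is:

```python
import tensorflow as tf

# Globally enable XLA JIT compilation. The models repo configures this
# through its own enable_xla flag rather than this call.
tf.config.optimizer.set_jit(True)
```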
@@ -146,7 +151,8 @@ def run(flags_obj):
       num_epochs=flags_obj.train_epochs,
       parse_record_fn=parse_record_keras,
       datasets_num_private_threads=flags_obj.datasets_num_private_threads,
-      dtype=dtype)
+      dtype=dtype,
+      drop_remainder=drop_remainder)
   eval_input_dataset = None
   if not flags_obj.skip_eval:
@@ -156,7 +162,8 @@ def run(flags_obj):
       batch_size=flags_obj.batch_size,
       num_epochs=flags_obj.train_epochs,
       parse_record_fn=parse_record_keras,
-      dtype=dtype)
+      dtype=dtype,
+      drop_remainder=drop_remainder)
   with strategy_scope:
     optimizer = keras_common.get_optimizer()
@@ -166,11 +173,25 @@ def run(flags_obj):
       optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
           optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))
+    if flags_obj.enable_xla and not flags_obj.enable_eager:
+      # TODO(b/129861005): Fix OOM issue in eager mode when setting
+      # `batch_size` in keras.Input layer.
+      if strategy and strategy.num_replicas_in_sync > 1:
+        # TODO(b/129791381): Specify `per_replica_batch_size` value in
+        # DistributionStrategy multi-replica case.
+        per_replica_batch_size = None
+      else:
+        per_replica_batch_size = flags_obj.batch_size
+    else:
+      per_replica_batch_size = None
     if flags_obj.use_trivial_model:
       model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES)
     else:
-      model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
-                                    dtype=dtype)
+      model = resnet_model.resnet50(
+          num_classes=imagenet_main.NUM_CLASSES,
+          dtype=dtype,
+          batch_size=per_replica_batch_size)
     model.compile(loss='sparse_categorical_crossentropy',
                   optimizer=optimizer,
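Context for TODO(b/129791381) above: under a multi-replica DistributionStrategy the dataset carries the global batch, which the strategy splits across replicas, so the value that would pin the per-replica input shape is the global size divided by the replica count. A hedged sketch of that arithmetic (not what this commit does; it currently passes None in the multi-replica case, and `global_batch_size` here is illustrative):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
global_batch_size = 256  # illustrative; corresponds to flags_obj.batch_size

# Each replica processes an equal slice of the global batch, so this is the
# batch dimension each replica's graph would actually see.
per_replica_batch_size = global_batch_size // strategy.num_replicas_in_sync
print(per_replica_batch_size)
```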
@@ -174,7 +174,7 @@ def conv_block(input_tensor,
   return x
-def resnet50(num_classes, dtype='float32'):
+def resnet50(num_classes, dtype='float32', batch_size=None):
   # TODO(tfboyd): add training argument, just like resnet56.
   """Instantiates the ResNet50 architecture.
@@ -185,7 +185,8 @@ def resnet50(num_classes, dtype='float32'):
     A Keras model instance.
   """
   input_shape = (224, 224, 3)
-  img_input = layers.Input(shape=input_shape, dtype=dtype)
+  img_input = layers.Input(shape=input_shape, dtype=dtype,
+                           batch_size=batch_size)
   if backend.image_data_format() == 'channels_first':
     x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
@@ -53,7 +53,8 @@ def process_record_dataset(dataset,
                            num_epochs=1,
                            dtype=tf.float32,
                            datasets_num_private_threads=None,
-                           num_parallel_batches=1):
+                           num_parallel_batches=1,
+                           drop_remainder=False):
   """Given a Dataset with raw records, return an iterator over the records.
   Args:
@@ -70,6 +71,8 @@ def process_record_dataset(dataset,
     datasets_num_private_threads: Number of threads for a private
       threadpool created for all datasets computation.
     num_parallel_batches: Number of parallel batches for tf.data.
+    drop_remainder: A boolean indicating whether to drop the remainder of the
+      batches. If True, the batch dimension will be static.
   Returns:
     Dataset of (image, label) pairs ready for iteration.
@@ -102,7 +105,7 @@ def process_record_dataset(dataset,
   dataset = dataset.map(
       lambda value: parse_record_fn(value, is_training, dtype),
       num_parallel_calls=tf.data.experimental.AUTOTUNE)
-  dataset = dataset.batch(batch_size, drop_remainder=False)
+  dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
   # Operations between the final prefetch and the get_next call to the iterator
   # will happen synchronously during run time. We prefetch here again to