Unverified Commit 371645fc authored by rxsang's avatar rxsang Committed by GitHub
Browse files

Ensure static shapes when enabling XLA in Resnet Keras model (#6508)

* Update resnet_model.py

* Ensure static shapes when enabling XLA.

* Define `drop_remainder` as a variable.

* Handles per_replica_batch_size in non-XLA mode

* Remove trailing whitespace.
parent d9823dae
......@@ -167,7 +167,8 @@ def input_fn(is_training,
datasets_num_private_threads=None,
num_parallel_batches=1,
parse_record_fn=parse_record,
input_context=None):
input_context=None,
drop_remainder=False):
"""Input function which provides batches for train or eval.
Args:
......@@ -181,6 +182,8 @@ def input_fn(is_training,
parse_record_fn: Function to use for parsing the records.
input_context: A `tf.distribute.InputContext` object passed in by
`tf.distribute.Strategy`.
drop_remainder: A boolean indicating whether to drop the remainder of the
batches. If True, the batch dimension will be static.
Returns:
A dataset that can be used for iteration.
......@@ -217,7 +220,8 @@ def input_fn(is_training,
num_epochs=num_epochs,
dtype=dtype,
datasets_num_private_threads=datasets_num_private_threads,
num_parallel_batches=num_parallel_batches
num_parallel_batches=num_parallel_batches,
drop_remainder=drop_remainder
)
......
......@@ -218,7 +218,7 @@ def define_keras_flags():
def get_synth_input_fn(height, width, num_channels, num_classes,
dtype=tf.float32):
dtype=tf.float32, drop_remainder=True):
"""Returns an input function that returns a dataset with random data.
This input_fn returns a data set that iterates over a set of random data and
......@@ -233,6 +233,8 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
num_classes: Number of classes that should be represented in the fake labels
tensor
dtype: Data type for features/images.
drop_remainder: A boolean indicating whether to drop the remainder of the
batches. If True, the batch dimension will be static.
Returns:
An input_fn that can be used in place of a real one to return a dataset
......@@ -259,7 +261,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
# `drop_remainder` will make dataset produce outputs with known shapes.
data = data.batch(batch_size, drop_remainder=True)
data = data.batch(batch_size, drop_remainder=drop_remainder)
data = data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
return data
......
......@@ -134,11 +134,16 @@ def run(flags_obj):
width=imagenet_main.DEFAULT_IMAGE_SIZE,
num_channels=imagenet_main.NUM_CHANNELS,
num_classes=imagenet_main.NUM_CLASSES,
dtype=dtype)
dtype=dtype,
drop_remainder=True)
else:
distribution_utils.undo_set_up_synthetic_data()
input_fn = imagenet_main.input_fn
# When `enable_xla` is True, we always drop the remainder of the batches
# in the dataset, as XLA-GPU doesn't support dynamic shapes.
drop_remainder = flags_obj.enable_xla
train_input_dataset = input_fn(
is_training=True,
data_dir=flags_obj.data_dir,
......@@ -146,7 +151,8 @@ def run(flags_obj):
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras,
datasets_num_private_threads=flags_obj.datasets_num_private_threads,
dtype=dtype)
dtype=dtype,
drop_remainder=drop_remainder)
eval_input_dataset = None
if not flags_obj.skip_eval:
......@@ -156,7 +162,8 @@ def run(flags_obj):
batch_size=flags_obj.batch_size,
num_epochs=flags_obj.train_epochs,
parse_record_fn=parse_record_keras,
dtype=dtype)
dtype=dtype,
drop_remainder=drop_remainder)
with strategy_scope:
optimizer = keras_common.get_optimizer()
......@@ -166,11 +173,22 @@ def run(flags_obj):
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))
if flags_obj.enable_xla:
if strategy:
per_replica_batch_size = (
flags_obj.batch_size // strategy.num_replicas_in_sync)
else:
per_replica_batch_size = flags_obj.batch_size
else:
per_replica_batch_size = None
if flags_obj.use_trivial_model:
model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES)
else:
model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
dtype=dtype)
model = resnet_model.resnet50(
num_classes=imagenet_main.NUM_CLASSES,
dtype=dtype,
batch_size=per_replica_batch_size)
model.compile(loss='sparse_categorical_crossentropy',
optimizer=optimizer,
......
......@@ -174,7 +174,7 @@ def conv_block(input_tensor,
return x
def resnet50(num_classes, dtype='float32'):
def resnet50(num_classes, dtype='float32', batch_size=None):
# TODO(tfboyd): add training argument, just like resnet56.
"""Instantiates the ResNet50 architecture.
......@@ -185,7 +185,8 @@ def resnet50(num_classes, dtype='float32'):
A Keras model instance.
"""
input_shape = (224, 224, 3)
img_input = layers.Input(shape=input_shape, dtype=dtype)
img_input = layers.Input(shape=input_shape, dtype=dtype,
batch_size=batch_size)
if backend.image_data_format() == 'channels_first':
x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
......
......@@ -53,7 +53,8 @@ def process_record_dataset(dataset,
num_epochs=1,
dtype=tf.float32,
datasets_num_private_threads=None,
num_parallel_batches=1):
num_parallel_batches=1,
drop_remainder=False):
"""Given a Dataset with raw records, return an iterator over the records.
Args:
......@@ -70,6 +71,8 @@ def process_record_dataset(dataset,
datasets_num_private_threads: Number of threads for a private
threadpool created for all datasets computation.
num_parallel_batches: Number of parallel batches for tf.data.
drop_remainder: A boolean indicating whether to drop the remainder of the
batches. If True, the batch dimension will be static.
Returns:
Dataset of (image, label) pairs ready for iteration.
......@@ -102,7 +105,7 @@ def process_record_dataset(dataset,
dataset = dataset.map(
lambda value: parse_record_fn(value, is_training, dtype),
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=False)
dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
# Operations between the final prefetch and the get_next call to the iterator
# will happen synchronously during run time. We prefetch here again to
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment