"vscode:/vscode.git/clone" did not exist on "64350affc5767e7ce3fb211d8145b5c9d18017d8"
Commit e08b6286 authored by rxsang, committed by Toby Boyd

Ensure static shapes when enabling XLA in Resnet Keras model in graph mode. (#6558)

* Revert "Revert " Ensure static shapes when enabling XLA in Resnet Keras model (#6508)" (#6517)"

This reverts commit cc9eef76.

* Set `batch_size` to keras.Input in non-eager mode.

Eager mode currently has an OOM problem.

* Add comments for enable_eager flag.

* Always set drop_remainder=True.

* Only set drop_remainder=True for XLA.
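
For context, here is a minimal sketch (not part of this commit; it assumes a TF 2.x runtime where `Dataset.element_spec` is available) of why `drop_remainder=True` gives the static batch dimension that XLA-GPU requires:

```python
# Illustration only (not part of this commit); assumes a TF 2.x runtime
# where `Dataset.element_spec` is available.
import tensorflow as tf

dataset = tf.data.Dataset.range(10).map(lambda x: tf.fill([3], x))

# Without drop_remainder the last batch may be short, so the batch
# dimension is unknown (None) at graph-construction time.
dynamic = dataset.batch(4, drop_remainder=False)
print(dynamic.element_spec.shape)  # (None, 3)

# With drop_remainder every emitted batch has exactly 4 elements, so the
# batch dimension is static, which is what XLA-GPU needs in graph mode.
static = dataset.batch(4, drop_remainder=True)
print(static.element_spec.shape)   # (4, 3)
```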
parent 80dde852
@@ -167,7 +167,8 @@ def input_fn(is_training,
              datasets_num_private_threads=None,
              num_parallel_batches=1,
              parse_record_fn=parse_record,
-             input_context=None):
+             input_context=None,
+             drop_remainder=False):
   """Input function which provides batches for train or eval.

   Args:
@@ -181,6 +182,8 @@ def input_fn(is_training,
     parse_record_fn: Function to use for parsing the records.
     input_context: A `tf.distribute.InputContext` object passed in by
       `tf.distribute.Strategy`.
+    drop_remainder: A boolean indicates whether to drop the remainder of the
+      batches. If True, the batch dimension will be static.

   Returns:
     A dataset that can be used for iteration.
@@ -217,7 +220,8 @@ def input_fn(is_training,
       num_epochs=num_epochs,
       dtype=dtype,
       datasets_num_private_threads=datasets_num_private_threads,
-      num_parallel_batches=num_parallel_batches
+      num_parallel_batches=num_parallel_batches,
+      drop_remainder=drop_remainder
   )
...
@@ -286,7 +286,7 @@ def define_keras_flags():
 def get_synth_input_fn(height, width, num_channels, num_classes,
-                       dtype=tf.float32):
+                       dtype=tf.float32, drop_remainder=True):
   """Returns an input function that returns a dataset with random data.

   This input_fn returns a data set that iterates over a set of random data and
@@ -301,6 +301,8 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
     num_classes: Number of classes that should be represented in the fake labels
       tensor
     dtype: Data type for features/images.
+    drop_remainder: A boolean indicates whether to drop the remainder of the
+      batches. If True, the batch dimension will be static.

   Returns:
     An input_fn that can be used in place of a real one to return a dataset
@@ -327,7 +329,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
     data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()

     # `drop_remainder` will make dataset produce outputs with known shapes.
-    data = data.batch(batch_size, drop_remainder=True)
+    data = data.batch(batch_size, drop_remainder=drop_remainder)
     data = data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
     return data
...
@@ -134,11 +134,16 @@ def run(flags_obj):
         width=imagenet_main.DEFAULT_IMAGE_SIZE,
         num_channels=imagenet_main.NUM_CHANNELS,
         num_classes=imagenet_main.NUM_CLASSES,
-        dtype=dtype)
+        dtype=dtype,
+        drop_remainder=True)
   else:
     distribution_utils.undo_set_up_synthetic_data()
     input_fn = imagenet_main.input_fn

+  # When `enable_xla` is True, we always drop the remainder of the batches
+  # in the dataset, as XLA-GPU doesn't support dynamic shapes.
+  drop_remainder = flags_obj.enable_xla
+
   train_input_dataset = input_fn(
       is_training=True,
       data_dir=flags_obj.data_dir,
@@ -146,7 +151,8 @@ def run(flags_obj):
       num_epochs=flags_obj.train_epochs,
       parse_record_fn=parse_record_keras,
       datasets_num_private_threads=flags_obj.datasets_num_private_threads,
-      dtype=dtype)
+      dtype=dtype,
+      drop_remainder=drop_remainder)

   eval_input_dataset = None
   if not flags_obj.skip_eval:
@@ -156,7 +162,8 @@ def run(flags_obj):
         batch_size=flags_obj.batch_size,
         num_epochs=flags_obj.train_epochs,
         parse_record_fn=parse_record_keras,
-        dtype=dtype)
+        dtype=dtype,
+        drop_remainder=drop_remainder)

   with strategy_scope:
     optimizer = keras_common.get_optimizer()
@@ -166,11 +173,25 @@ def run(flags_obj):
       optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
           optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

+    if flags_obj.enable_xla and not flags_obj.enable_eager:
+      # TODO(b/129861005): Fix OOM issue in eager mode when setting
+      # `batch_size` in keras.Input layer.
+      if strategy and strategy.num_replicas_in_sync > 1:
+        # TODO(b/129791381): Specify `per_replica_batch_size` value in
+        # DistributionStrategy multi-replica case.
+        per_replica_batch_size = None
+      else:
+        per_replica_batch_size = flags_obj.batch_size
+    else:
+      per_replica_batch_size = None
+
     if flags_obj.use_trivial_model:
       model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES)
     else:
-      model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
-                                    dtype=dtype)
+      model = resnet_model.resnet50(
+          num_classes=imagenet_main.NUM_CLASSES,
+          dtype=dtype,
+          batch_size=per_replica_batch_size)

     model.compile(loss='sparse_categorical_crossentropy',
                   optimizer=optimizer,
...
@@ -174,7 +174,7 @@ def conv_block(input_tensor,
   return x


-def resnet50(num_classes, dtype='float32'):
+def resnet50(num_classes, dtype='float32', batch_size=None):
   # TODO(tfboyd): add training argument, just lik resnet56.
   """Instantiates the ResNet50 architecture.
@@ -185,7 +185,8 @@ def resnet50(num_classes, dtype='float32'):
     A Keras model instance.
   """
   input_shape = (224, 224, 3)
-  img_input = layers.Input(shape=input_shape, dtype=dtype)
+  img_input = layers.Input(shape=input_shape, dtype=dtype,
+                           batch_size=batch_size)

   if backend.image_data_format() == 'channels_first':
     x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
...
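
As an aside, a small illustration (not the repository code) of how fixing `batch_size` on the Keras Input layer removes the dynamic batch dimension from the model's input shape:

```python
# Illustration only (not the repository code); shows the effect of the new
# `batch_size` argument that this commit threads through to layers.Input.
import tensorflow as tf

# Default: the batch dimension is left dynamic.
dynamic_in = tf.keras.layers.Input(shape=(224, 224, 3))
print(dynamic_in.shape)  # (None, 224, 224, 3)

# With an explicit batch_size the batch dimension is baked into the graph,
# so every tensor shape downstream of the input is fully static.
static_in = tf.keras.layers.Input(shape=(224, 224, 3), batch_size=128)
print(static_in.shape)   # (128, 224, 224, 3)
```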
@@ -53,7 +53,8 @@ def process_record_dataset(dataset,
                            num_epochs=1,
                            dtype=tf.float32,
                            datasets_num_private_threads=None,
-                           num_parallel_batches=1):
+                           num_parallel_batches=1,
+                           drop_remainder=False):
   """Given a Dataset with raw records, return an iterator over the records.

   Args:
@@ -70,6 +71,8 @@ def process_record_dataset(dataset,
     datasets_num_private_threads: Number of threads for a private
       threadpool created for all datasets computation.
     num_parallel_batches: Number of parallel batches for tf.data.
+    drop_remainder: A boolean indicates whether to drop the remainder of the
+      batches. If True, the batch dimension will be static.

   Returns:
     Dataset of (image, label) pairs ready for iteration.
@@ -102,7 +105,7 @@ def process_record_dataset(dataset,
   dataset = dataset.map(
       lambda value: parse_record_fn(value, is_training, dtype),
       num_parallel_calls=tf.data.experimental.AUTOTUNE)
-  dataset = dataset.batch(batch_size, drop_remainder=False)
+  dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)

   # Operations between the final prefetch and the get_next call to the iterator
   # will happen synchronously during run time. We prefetch here again to
...
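
Finally, a quick sanity check one could run (a sketch under assumed TF 2.x APIs; `assert_static_batch` is a hypothetical helper, not part of this change) to confirm that a dataset built with `drop_remainder=True` exposes only static batch dimensions:

```python
# Sketch only (assumes TF 2.x `element_spec`); `assert_static_batch` is a
# hypothetical helper, not something defined in this repository.
import tensorflow as tf

def assert_static_batch(dataset):
  """Raises if any component of the dataset has an unknown batch dimension."""
  for spec in tf.nest.flatten(dataset.element_spec):
    if spec.shape[0] is None:
      raise ValueError("Dynamic batch dimension found: %s" % spec)

# Toy (images, labels) dataset shaped like the ImageNet pipeline's output.
images = tf.zeros([8, 224, 224, 3])
labels = tf.zeros([8], dtype=tf.int32)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(
    4, drop_remainder=True)
assert_static_batch(ds)  # passes: every element has batch dimension 4
```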