Unverified commit cc9eef76 authored by Haoyu Zhang, committed by GitHub

Revert " Ensure static shapes when enabling XLA in Resnet Keras model (#6508)" (#6517)

Reason: breaks the 1-GPU nightly test.

This reverts commit 371645fc.
parent 154d3ffa
@@ -167,8 +167,7 @@ def input_fn(is_training,
              datasets_num_private_threads=None,
              num_parallel_batches=1,
              parse_record_fn=parse_record,
-             input_context=None,
-             drop_remainder=False):
+             input_context=None):
   """Input function which provides batches for train or eval.

   Args:
@@ -182,8 +181,6 @@ def input_fn(is_training,
     parse_record_fn: Function to use for parsing the records.
     input_context: A `tf.distribute.InputContext` object passed in by
       `tf.distribute.Strategy`.
-    drop_remainder: A boolean indicates whether to drop the remainder of the
-      batches. If True, the batch dimension will be static.

   Returns:
     A dataset that can be used for iteration.
@@ -220,8 +217,7 @@ def input_fn(is_training,
       num_epochs=num_epochs,
       dtype=dtype,
       datasets_num_private_threads=datasets_num_private_threads,
-      num_parallel_batches=num_parallel_batches,
-      drop_remainder=drop_remainder
+      num_parallel_batches=num_parallel_batches
   )
...
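The reverted change threaded a drop_remainder flag from the Keras runner down to this input_fn so that dataset.batch() could produce a statically shaped batch dimension when XLA is enabled. A minimal standalone sketch of the effect that flag controls (not part of this diff; assumes TF 2.x for element_spec):

    import tensorflow as tf

    # With drop_remainder=False the last, possibly partial batch is kept, so
    # the batch dimension is unknown (None); with True it is statically 4.
    dynamic = tf.data.Dataset.range(10).batch(4, drop_remainder=False)
    static = tf.data.Dataset.range(10).batch(4, drop_remainder=True)

    print(dynamic.element_spec)  # TensorSpec(shape=(None,), dtype=tf.int64, name=None)
    print(static.element_spec)   # TensorSpec(shape=(4,), dtype=tf.int64, name=None)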
@@ -223,7 +223,7 @@ def define_keras_flags():
 def get_synth_input_fn(height, width, num_channels, num_classes,
-                       dtype=tf.float32, drop_remainder=True):
+                       dtype=tf.float32):
   """Returns an input function that returns a dataset with random data.

   This input_fn returns a data set that iterates over a set of random data and
@@ -238,8 +238,6 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
     num_classes: Number of classes that should be represented in the fake labels
       tensor
     dtype: Data type for features/images.
-    drop_remainder: A boolean indicates whether to drop the remainder of the
-      batches. If True, the batch dimension will be static.

   Returns:
     An input_fn that can be used in place of a real one to return a dataset
@@ -266,7 +264,7 @@ def get_synth_input_fn(height, width, num_channels, num_classes,
     data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()

     # `drop_remainder` will make dataset produce outputs with known shapes.
-    data = data.batch(batch_size, drop_remainder=drop_remainder)
+    data = data.batch(batch_size, drop_remainder=True)
     data = data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
     return data
...
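For context, get_synth_input_fn builds its dataset from a single random example repeated forever, then batches it; after this revert the synthetic path simply hard-codes drop_remainder=True again instead of taking it as a parameter. A rough, self-contained sketch of that structure (illustrative names and defaults, not the repository code):

    import tensorflow as tf

    def synthetic_input_fn(batch_size, height=224, width=224, num_channels=3,
                           num_classes=1000, dtype=tf.float32):
      # One random (image, label) pair, repeated forever, mirrors the synthetic
      # ImageNet pipeline: no disk I/O, so it isolates model/runtime throughput.
      inputs = tf.random.uniform([height, width, num_channels], dtype=dtype)
      labels = tf.random.uniform([1], minval=0, maxval=num_classes, dtype=tf.int32)
      data = tf.data.Dataset.from_tensors((inputs, labels)).repeat()
      # drop_remainder=True keeps the batch dimension statically known.
      data = data.batch(batch_size, drop_remainder=True)
      return data.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)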
@@ -135,16 +135,11 @@ def run(flags_obj):
         width=imagenet_main.DEFAULT_IMAGE_SIZE,
         num_channels=imagenet_main.NUM_CHANNELS,
         num_classes=imagenet_main.NUM_CLASSES,
-        dtype=dtype,
-        drop_remainder=True)
+        dtype=dtype)
   else:
     distribution_utils.undo_set_up_synthetic_data()
     input_fn = imagenet_main.input_fn
-    # When `enable_xla` is True, we always drop the remainder of the batches
-    # in the dataset, as XLA-GPU doesn't support dynamic shapes.
-    drop_remainder = flags_obj.enable_xla

   train_input_dataset = input_fn(
       is_training=True,
       data_dir=flags_obj.data_dir,
@@ -152,8 +147,7 @@ def run(flags_obj):
       num_epochs=flags_obj.train_epochs,
       parse_record_fn=parse_record_keras,
       datasets_num_private_threads=flags_obj.datasets_num_private_threads,
-      dtype=dtype,
-      drop_remainder=drop_remainder)
+      dtype=dtype)

   eval_input_dataset = None
   if not flags_obj.skip_eval:
@@ -163,8 +157,7 @@ def run(flags_obj):
         batch_size=flags_obj.batch_size,
         num_epochs=flags_obj.train_epochs,
         parse_record_fn=parse_record_keras,
-        dtype=dtype,
-        drop_remainder=drop_remainder)
+        dtype=dtype)

   with strategy_scope:
     optimizer = keras_common.get_optimizer()
@@ -174,23 +167,11 @@ def run(flags_obj):
       optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
           optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))

-    if flags_obj.enable_xla:
-      if strategy and strategy.num_replicas_in_sync > 1:
-        # TODO(b/129791381): Specify `per_replica_batch_size` value in
-        # DistributionStrategy multi-replica case.
-        per_replica_batch_size = None
-      else:
-        per_replica_batch_size = flags_obj.batch_size
-    else:
-      per_replica_batch_size = None
-
     if flags_obj.use_trivial_model:
       model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES)
     else:
-      model = resnet_model.resnet50(
-          num_classes=imagenet_main.NUM_CLASSES,
-          dtype=dtype,
-          batch_size=per_replica_batch_size)
+      model = resnet_model.resnet50(num_classes=imagenet_main.NUM_CLASSES,
+                                    dtype=dtype)

     model.compile(loss='sparse_categorical_crossentropy',
                   optimizer=optimizer,
...
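The removed block only set a concrete per_replica_batch_size in the single-replica XLA case, leaving a TODO for the multi-replica case. The general computation would be an even split of the global batch across replicas; a hypothetical helper (not present in this diff) might look like:

    def per_replica_batch_size(global_batch_size, num_replicas):
      # Hypothetical helper: split the global batch evenly so each replica can
      # build its model with a static per-replica batch dimension.
      if global_batch_size % num_replicas != 0:
        raise ValueError(
            'Global batch size %d is not divisible by %d replicas.' %
            (global_batch_size, num_replicas))
      return global_batch_size // num_replicas

    # Example: a global batch of 256 spread over 8 replicas is 32 per replica.
    assert per_replica_batch_size(256, 8) == 32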
@@ -174,7 +174,7 @@ def conv_block(input_tensor,
   return x


-def resnet50(num_classes, dtype='float32', batch_size=None):
+def resnet50(num_classes, dtype='float32'):
   # TODO(tfboyd): add training argument, just lik resnet56.
   """Instantiates the ResNet50 architecture.
@@ -185,8 +185,7 @@ def resnet50(num_classes, dtype='float32', batch_size=None):
     A Keras model instance.
   """
   input_shape = (224, 224, 3)
-  img_input = layers.Input(shape=input_shape, dtype=dtype,
-                           batch_size=batch_size)
+  img_input = layers.Input(shape=input_shape, dtype=dtype)

   if backend.image_data_format() == 'channels_first':
     x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
...
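Dropping the batch_size argument here returns layers.Input to a dynamic batch dimension. A short illustrative snippet (not part of the diff) showing what pinning the batch size does to the model's input shape:

    import tensorflow as tf

    # Without batch_size the leading dimension is None (dynamic); with it, the
    # whole input shape is static, which is what XLA-GPU compilation expects.
    dynamic_in = tf.keras.layers.Input(shape=(224, 224, 3))
    static_in = tf.keras.layers.Input(shape=(224, 224, 3), batch_size=8)

    print(dynamic_in.shape)  # (None, 224, 224, 3)
    print(static_in.shape)   # (8, 224, 224, 3)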
@@ -53,8 +53,7 @@ def process_record_dataset(dataset,
                            num_epochs=1,
                            dtype=tf.float32,
                            datasets_num_private_threads=None,
-                           num_parallel_batches=1,
-                           drop_remainder=False):
+                           num_parallel_batches=1):
   """Given a Dataset with raw records, return an iterator over the records.

   Args:
@@ -71,8 +70,6 @@ def process_record_dataset(dataset,
     datasets_num_private_threads: Number of threads for a private
       threadpool created for all datasets computation.
     num_parallel_batches: Number of parallel batches for tf.data.
-    drop_remainder: A boolean indicates whether to drop the remainder of the
-      batches. If True, the batch dimension will be static.

   Returns:
     Dataset of (image, label) pairs ready for iteration.
@@ -105,7 +102,7 @@ def process_record_dataset(dataset,
   dataset = dataset.map(
       lambda value: parse_record_fn(value, is_training, dtype),
       num_parallel_calls=tf.data.experimental.AUTOTUNE)
-  dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+  dataset = dataset.batch(batch_size, drop_remainder=False)

   # Operations between the final prefetch and the get_next call to the iterator
   # will happen synchronously during run time. We prefetch here again to
...
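After the revert, process_record_dataset always keeps the remainder batch, so the batch dimension is dynamic again. A condensed sketch of the tail of that pipeline (function and parameter names are illustrative; parse_record_fn is supplied by the caller):

    import tensorflow as tf

    def to_training_batches(dataset, parse_record_fn, batch_size,
                            is_training=True, dtype=tf.float32):
      # Parse each raw record, batch without dropping the remainder (dynamic
      # batch dimension), then prefetch so input prep overlaps with training.
      dataset = dataset.map(
          lambda value: parse_record_fn(value, is_training, dtype),
          num_parallel_calls=tf.data.experimental.AUTOTUNE)
      dataset = dataset.batch(batch_size, drop_remainder=False)
      return dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)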