Commit a8b6963c authored by Jing Li, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 272915002
parent b045ce7d
@@ -71,6 +71,20 @@ def build_stats(train_result, eval_result, time_callback):
 def get_input_dataset(flags_obj, strategy):
   """Returns the test and train input datasets."""
   dtype = flags_core.get_tf_dtype(flags_obj)
+  use_dataset_fn = isinstance(strategy, tf.distribute.experimental.TPUStrategy)
+  batch_size = flags_obj.batch_size
+  if use_dataset_fn:
+    if batch_size % strategy.num_replicas_in_sync != 0:
+      raise ValueError(
+          'Batch size must be divisible by number of replicas : {}'.format(
+              strategy.num_replicas_in_sync))
+
+    # As auto rebatching is not supported in
+    # `experimental_distribute_datasets_from_function()` API, which is
+    # required when cloning dataset to multiple workers in eager mode,
+    # we use per-replica batch size.
+    batch_size = int(batch_size / strategy.num_replicas_in_sync)
+
   if flags_obj.use_synthetic_data:
     input_fn = common.get_synth_input_fn(
         height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
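For reference, the per-replica batch size computed in the hunk above can also be obtained from `tf.distribute.InputContext`, which performs the same divisibility check. A minimal sketch with made-up sizes (not part of this change):

```python
import tensorflow as tf

# Made-up sizes: a global batch of 1024 split across 8 replicas.
GLOBAL_BATCH_SIZE = 1024
ctx = tf.distribute.InputContext(
    num_input_pipelines=1, input_pipeline_id=0, num_replicas_in_sync=8)

# Equivalent to the manual division above; raises ValueError when the
# global batch size is not divisible by the number of replicas.
per_replica_batch_size = ctx.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
assert per_replica_batch_size == 128
```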
@@ -82,34 +96,51 @@ def get_input_dataset(flags_obj, strategy):
   else:
     input_fn = imagenet_preprocessing.input_fn
 
-  train_ds = input_fn(
-      is_training=True,
-      data_dir=flags_obj.data_dir,
-      batch_size=flags_obj.batch_size,
-      parse_record_fn=imagenet_preprocessing.parse_record,
-      datasets_num_private_threads=flags_obj.datasets_num_private_threads,
-      dtype=dtype)
+  def _train_dataset_fn(ctx=None):
+    train_ds = input_fn(
+        is_training=True,
+        data_dir=flags_obj.data_dir,
+        batch_size=batch_size,
+        parse_record_fn=imagenet_preprocessing.parse_record,
+        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
+        dtype=dtype,
+        input_context=ctx,
+        drop_remainder=True)
+    return train_ds
 
   if strategy:
-    train_ds = strategy.experimental_distribute_dataset(train_ds)
+    if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
+      train_ds = strategy.experimental_distribute_datasets_from_function(_train_dataset_fn)
+    else:
+      train_ds = strategy.experimental_distribute_dataset(_train_dataset_fn())
+  else:
+    train_ds = _train_dataset_fn()
 
   test_ds = None
   if not flags_obj.skip_eval:
-    test_ds = input_fn(
-        is_training=False,
-        data_dir=flags_obj.data_dir,
-        batch_size=flags_obj.batch_size,
-        parse_record_fn=imagenet_preprocessing.parse_record,
-        dtype=dtype)
+    def _test_data_fn(ctx=None):
+      test_ds = input_fn(
+          is_training=False,
+          data_dir=flags_obj.data_dir,
+          batch_size=batch_size,
+          parse_record_fn=imagenet_preprocessing.parse_record,
+          dtype=dtype,
+          input_context=ctx)
+      return test_ds
 
     if strategy:
-      test_ds = strategy.experimental_distribute_dataset(test_ds)
+      if isinstance(strategy, tf.distribute.experimental.TPUStrategy):
+        test_ds = strategy.experimental_distribute_datasets_from_function(_test_data_fn)
+      else:
+        test_ds = strategy.experimental_distribute_dataset(_test_data_fn())
+    else:
+      test_ds = _test_data_fn()
 
   return train_ds, test_ds
 
 
 def get_num_train_iterations(flags_obj):
-  """Returns the number of training stesps, train and test epochs."""
+  """Returns the number of training steps, train and test epochs."""
   train_steps = (
       imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
   train_epochs = flags_obj.train_epochs
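`_train_dataset_fn` and `_test_data_fn` above follow the shape that `experimental_distribute_datasets_from_function()` expects: the strategy calls the function once per input pipeline and hands it a `tf.distribute.InputContext`. A standalone sketch of such a function, with a hypothetical file pattern and no record parsing, just to show where the context is used:

```python
import tensorflow as tf

def make_dataset_fn(file_pattern, global_batch_size):
  """Builds a dataset_fn for experimental_distribute_datasets_from_function."""

  def dataset_fn(ctx):
    # `ctx` is the tf.distribute.InputContext supplied by the strategy.
    per_replica_batch = ctx.get_per_replica_batch_size(global_batch_size)
    ds = tf.data.Dataset.list_files(file_pattern, shuffle=False)
    # Shard input files across the input pipelines (e.g. TPU hosts).
    ds = ds.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)
    ds = ds.interleave(tf.data.TFRecordDataset, cycle_length=4)
    # Per-replica batching; drop_remainder keeps shapes static for TPUs.
    return ds.batch(per_replica_batch, drop_remainder=True)

  return dataset_fn

# Usage sketch, assuming `strategy` is a TPUStrategy:
# train_ds = strategy.experimental_distribute_datasets_from_function(
#     make_dataset_fn('/data/train-*', flags_obj.batch_size))
```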
@@ -124,6 +155,15 @@ def get_num_train_iterations(flags_obj):
   return train_steps, train_epochs, eval_steps
 
 
+def _steps_to_run(steps_in_current_epoch, steps_per_epoch, steps_per_loop):
+  """Calculates steps to run on device."""
+  if steps_per_loop <= 0:
+    raise ValueError('steps_per_loop should be positive integer.')
+  if steps_per_loop == 1:
+    return steps_per_loop
+  return min(steps_per_loop, steps_per_epoch - steps_in_current_epoch)
+
+
 def run(flags_obj):
   """Run ResNet ImageNet training and eval loop using custom training loops.
@@ -152,33 +192,45 @@ def run(flags_obj):
       num_gpus=flags_obj.num_gpus,
       num_workers=distribution_utils.configure_cluster(),
       all_reduce_alg=flags_obj.all_reduce_alg,
-      num_packs=flags_obj.num_packs)
+      num_packs=flags_obj.num_packs,
+      tpu_address=flags_obj.tpu)
 
   train_ds, test_ds = get_input_dataset(flags_obj, strategy)
-  train_steps, train_epochs, eval_steps = get_num_train_iterations(flags_obj)
+  per_epoch_steps, train_epochs, eval_steps = get_num_train_iterations(
+      flags_obj)
+  steps_per_loop = min(flags_obj.steps_per_loop, per_epoch_steps)
+  logging.info("Training %d epochs, each epoch has %d steps, "
+               "total steps: %d; Eval %d steps",
+               train_epochs, per_epoch_steps, train_epochs * per_epoch_steps,
+               eval_steps)
 
   time_callback = keras_utils.TimeHistory(flags_obj.batch_size,
                                           flags_obj.log_steps)
 
-  strategy_scope = distribution_utils.get_strategy_scope(strategy)
-  with strategy_scope:
+  with distribution_utils.get_strategy_scope(strategy):
     model = resnet_model.resnet50(
         num_classes=imagenet_preprocessing.NUM_CLASSES,
         batch_size=flags_obj.batch_size,
         use_l2_regularizer=not flags_obj.single_l2_loss_op)
 
-    optimizer = tf.keras.optimizers.SGD(
-        learning_rate=common.BASE_LEARNING_RATE, momentum=0.9,
-        nesterov=True)
+    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
+        batch_size=flags_obj.batch_size,
+        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
+        warmup_epochs=common.LR_SCHEDULE[0][1],
+        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
+        multipliers=list(p[0] for p in common.LR_SCHEDULE),
+        compute_lr_on_cpu=True)
+    optimizer = common.get_optimizer(lr_schedule)
 
-    if flags_obj.fp16_implementation == "graph_rewrite":
+    if flags_obj.fp16_implementation == 'graph_rewrite':
       if not flags_obj.use_tf_function:
-        raise ValueError("--fp16_implementation=graph_rewrite requires "
-                         "--use_tf_function to be true")
+        raise ValueError('--fp16_implementation=graph_rewrite requires '
+                         '--use_tf_function to be true')
       loss_scale = flags_core.get_loss_scale(flags_obj, default_for_fp16=128)
       optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
           optimizer, loss_scale)
 
+    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
     training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
         'training_accuracy', dtype=tf.float32)
     test_loss = tf.keras.metrics.Mean('test_loss', dtype=tf.float32)
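`PiecewiseConstantDecayWithWarmup`, `LR_SCHEDULE`, and `get_optimizer` live in the repository's `common` module and are not shown in this diff. A rough standalone sketch of the schedule's intended shape, assuming a linear warmup to a batch-size-scaled base rate followed by step decays (the constants below are placeholders, not necessarily the repository's values):

```python
def warmup_piecewise_lr(global_step, steps_per_epoch, batch_size,
                        base_lr=0.1, base_batch=256,
                        schedule=((1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80))):
  """Sketch of a warmup + piecewise-constant learning-rate schedule.

  `schedule` is a sequence of (multiplier, start_epoch) pairs; the first
  entry also defines the warmup length in epochs.
  """
  scaled_lr = base_lr * batch_size / base_batch
  epoch = global_step / steps_per_epoch
  warmup_epochs = schedule[0][1]
  if epoch < warmup_epochs:
    # Linear warmup from 0 to the batch-size-scaled base rate.
    return scaled_lr * epoch / warmup_epochs
  lr = scaled_lr
  for multiplier, start_epoch in schedule:
    if epoch >= start_epoch:
      lr = scaled_lr * multiplier
  return lr
```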
@@ -187,55 +239,56 @@ def run(flags_obj):
     trainable_variables = model.trainable_variables
 
-    def train_step(train_ds_inputs):
-      """Training StepFn."""
-      def step_fn(inputs):
-        """Per-Replica StepFn."""
-        images, labels = inputs
-        with tf.GradientTape() as tape:
-          logits = model(images, training=True)
-
-          prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
-              labels, logits)
-          loss = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
-          num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
-
-          if flags_obj.single_l2_loss_op:
-            filtered_variables = [
-                tf.reshape(v, (-1,))
-                for v in trainable_variables
-                if 'bn' not in v.name
-            ]
-            l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
-                tf.concat(filtered_variables, axis=0))
-            loss += (l2_loss / num_replicas)
-          else:
-            loss += (tf.reduce_sum(model.losses) / num_replicas)
-
-          # Scale the loss
-          if flags_obj.dtype == "fp16":
-            loss = optimizer.get_scaled_loss(loss)
-        grads = tape.gradient(loss, trainable_variables)
-        # Unscale the grads
-        if flags_obj.dtype == "fp16":
-          grads = optimizer.get_unscaled_gradients(grads)
-        optimizer.apply_gradients(zip(grads, trainable_variables))
-        training_accuracy.update_state(labels, logits)
-        return loss
+    def step_fn(inputs):
+      """Per-Replica StepFn."""
+      images, labels = inputs
+      with tf.GradientTape() as tape:
+        logits = model(images, training=True)
+
+        prediction_loss = tf.keras.losses.sparse_categorical_crossentropy(
+            labels, logits)
+        loss = tf.reduce_sum(prediction_loss) * (1.0/ flags_obj.batch_size)
+        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
+
+        if flags_obj.single_l2_loss_op:
+          filtered_variables = [
+              tf.reshape(v, (-1,))
+              for v in trainable_variables
+              if 'bn' not in v.name
+          ]
+          l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
+              tf.concat(filtered_variables, axis=0))
+          loss += (l2_loss / num_replicas)
+        else:
+          loss += (tf.reduce_sum(model.losses) / num_replicas)
+
+        # Scale the loss
+        if flags_obj.dtype == "fp16":
+          loss = optimizer.get_scaled_loss(loss)
+      grads = tape.gradient(loss, trainable_variables)
+      # Unscale the grads
+      if flags_obj.dtype == "fp16":
+        grads = optimizer.get_unscaled_gradients(grads)
+      optimizer.apply_gradients(zip(grads, trainable_variables))
+      train_loss.update_state(loss)
+      training_accuracy.update_state(labels, logits)
+
+    @tf.function
+    def train_steps(iterator, steps):
+      """Performs distributed training steps in a loop."""
+      for _ in tf.range(steps):
+        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
 
+    def train_single_step(iterator):
       if strategy:
-        per_replica_losses = strategy.experimental_run_v2(
-            step_fn, args=(train_ds_inputs,))
-        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
-                               axis=None)
+        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
       else:
-        return step_fn(train_ds_inputs)
+        return step_fn(next(iterator))
 
-    def test_step(test_ds_inputs):
+    def test_step(iterator):
       """Evaluation StepFn."""
       def step_fn(inputs):
         images, labels = inputs
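In `step_fn` above, the per-example losses are summed and divided by the global batch size, while the L2 term is divided by the number of replicas; summing every replica's contribution (effectively what cross-replica gradient aggregation does) then yields the global-batch mean loss plus a single copy of the regularizer. A plain-Python check of that arithmetic with made-up numbers:

```python
# Made-up numbers: 2 replicas, 4 examples each, global batch of 8.
per_replica_example_losses = [[1., 2., 3., 4.], [5., 6., 7., 8.]]
global_batch_size = 8
num_replicas = 2
l2_loss = 0.5  # the same value is computed on every replica

replica_contributions = [
    sum(losses) / global_batch_size + l2_loss / num_replicas
    for losses in per_replica_example_losses
]

# Summing across replicas reproduces: mean example loss + one L2 term.
expected = sum(map(sum, per_replica_example_losses)) / global_batch_size + l2_loss
assert abs(sum(replica_contributions) - expected) < 1e-9
```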
@@ -247,34 +300,39 @@ def run(flags_obj):
         test_accuracy.update_state(labels, logits)
 
       if strategy:
-        strategy.experimental_run_v2(step_fn, args=(test_ds_inputs,))
+        strategy.experimental_run_v2(step_fn, args=(next(iterator),))
       else:
-        step_fn(test_ds_inputs)
+        step_fn(next(iterator))
 
     if flags_obj.use_tf_function:
-      train_step = tf.function(train_step)
+      train_single_step = tf.function(train_single_step)
       test_step = tf.function(test_step)
 
+    train_iter = iter(train_ds)
     time_callback.on_train_begin()
     for epoch in range(train_epochs):
-
-      train_iter = iter(train_ds)
-      total_loss = 0.0
+      train_loss.reset_states()
      training_accuracy.reset_states()
 
-      for step in range(train_steps):
-        optimizer.lr = common.learning_rate_schedule(
-            epoch, step, train_steps, flags_obj.batch_size)
-
-        time_callback.on_batch_begin(step+epoch*train_steps)
-        total_loss += train_step(next(train_iter))
-        time_callback.on_batch_end(step+epoch*train_steps)
-
-      train_loss = total_loss / train_steps
-      logging.info('Training loss: %s, accuracy: %s%% at epoch: %d',
-                   train_loss.numpy(),
+      steps_in_current_epoch = 0
+      while steps_in_current_epoch < per_epoch_steps:
+        time_callback.on_batch_begin(
+            steps_in_current_epoch+epoch*per_epoch_steps)
+        steps = _steps_to_run(steps_in_current_epoch, per_epoch_steps,
+                              steps_per_loop)
+        if steps == 1:
+          train_single_step(train_iter)
+        else:
+          # Converts steps to a Tensor to avoid tf.function retracing.
+          train_steps(train_iter, tf.convert_to_tensor(steps, dtype=tf.int32))
+        time_callback.on_batch_end(
+            steps_in_current_epoch+epoch*per_epoch_steps)
+        steps_in_current_epoch += steps
+
+      logging.info('Training loss: %s, accuracy: %s%% at epoch %d',
+                   train_loss.result().numpy(),
                    training_accuracy.result().numpy(),
-                   epoch)
+                   epoch + 1)
 
       if (not flags_obj.skip_eval and
           (epoch + 1) % flags_obj.epochs_between_evals == 0):
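The comment in the loop above converts `steps` to a tensor before calling the `tf.function`-decorated `train_steps`: a Python int argument is baked into the trace as a constant, so every distinct value would trigger a fresh trace, while a scalar tensor argument reuses one trace. A small sketch of that behavior:

```python
import tensorflow as tf

trace_count = 0

@tf.function
def run_n(n):
  global trace_count
  trace_count += 1  # Python side effect: runs only while tracing.
  total = tf.constant(0)
  for i in tf.range(n):
    total += i
  return total

run_n(4)
run_n(6)                               # two Python ints -> two traces
assert trace_count == 2
run_n(tf.constant(8, dtype=tf.int32))  # tensor argument -> one more trace...
run_n(tf.constant(2, dtype=tf.int32))  # ...reused for other int32 scalars
assert trace_count == 3
```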
@@ -283,12 +341,12 @@ def run(flags_obj):
       test_iter = iter(test_ds)
       for _ in range(eval_steps):
-        test_step(next(test_iter))
+        test_step(test_iter)
 
       logging.info('Test loss: %s, accuracy: %s%% at epoch: %d',
                    test_loss.result().numpy(),
                    test_accuracy.result().numpy(),
-                   epoch)
+                   epoch + 1)
 
     time_callback.on_train_end()
@@ -297,7 +355,7 @@ def run(flags_obj):
   if not flags_obj.skip_eval:
     eval_result = [test_loss.result().numpy(),
                    test_accuracy.result().numpy()]
-    train_result = [train_loss.numpy(),
+    train_result = [train_loss.result().numpy(),
                     training_accuracy.result().numpy()]
 
   stats = build_stats(train_result, eval_result, time_callback)
@@ -307,7 +365,8 @@ def run(flags_obj):
 def main(_):
   model_helpers.apply_clean(flags.FLAGS)
   with logger.benchmark_context(flags.FLAGS):
-    return run(flags.FLAGS)
+    stats = run(flags.FLAGS)
+  logging.info('Run stats:\n%s', stats)
 
 
 if __name__ == '__main__':
@@ -353,6 +353,13 @@ def define_keras_flags(dynamic_loss_scale=True):
   flags.DEFINE_boolean(
       name='enable_checkpoint_and_export', default=False,
       help='Whether to enable a checkpoint callback and export the savedmodel.')
+  flags.DEFINE_string(
+      name='tpu', default='', help='TPU address to connect to.')
+  flags.DEFINE_integer(
+      name='steps_per_loop', default=1,
+      help='Number of steps per graph-mode loop. Only training step happens '
+      'inside the loop. Callbacks will not be called inside. Will be capped at '
+      'steps per epoch.')
 
 
 def get_synth_input_fn(height, width, num_channels, num_classes,
@@ -115,9 +115,9 @@ def process_record_dataset(dataset,
   if is_training:
     # Shuffles records before repeating to respect epoch boundaries.
     dataset = dataset.shuffle(buffer_size=shuffle_buffer)
-
-  # Repeats the dataset for the number of epochs to train.
-  dataset = dataset.repeat(num_epochs)
+    # Repeats the dataset for the number of epochs to train.
+    dataset = dataset.repeat()
 
   # Parses the raw records into images and labels.
   dataset = dataset.map(
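Dropping the repeat count makes the training dataset effectively unbounded, which matches drawing all `train_epochs * per_epoch_steps` batches from the single `train_iter` created in the training loop above, while a counted repeat stops after that many passes. A tiny illustration:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(3)

# Bounded repeat: two full passes, then the iterator is exhausted.
print([int(x) for x in ds.repeat(2)])         # [0, 1, 2, 0, 1, 2]

# Unbounded repeat: take() keeps the loop finite here.
print([int(x) for x in ds.repeat().take(4)])  # [0, 1, 2, 0]
```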
@@ -133,10 +133,10 @@ def process_record_dataset(dataset,
   # on how many devices are present.
   dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
 
-  if tf_data_experimental_slack:
-    options = tf.data.Options()
-    options.experimental_slack = True
-    dataset = dataset.with_options(options)
+  options = tf.data.Options()
+  options.experimental_slack = tf_data_experimental_slack
+  options.experimental_allow_stateful = True
+  dataset = dataset.with_options(options)
 
   return dataset
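`experimental_slack` adds slack to the pipeline's final prefetch, which can reduce CPU contention with host-side work at the start of a step; after this change the options are applied unconditionally, with the slack setting taken straight from the flag and `experimental_allow_stateful` enabled. A minimal sketch of attaching such options to a small pipeline:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(8).batch(2)
ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

# tf.data options are attached to a pipeline via with_options().
options = tf.data.Options()
options.experimental_slack = True
ds = ds.with_options(options)

for batch in ds:
  print(batch.numpy())
```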