Commit e8b0e950 authored by Vighnesh Birodkar, committed by TF Object Detection Team

Log loss tensor summaries after they are reduced on the coordinator.

PiperOrigin-RevId: 377187393
parent 5247a17e
@@ -20,11 +20,11 @@ from __future__ import print_function
import copy
import os
import pprint
import time
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow.compat.v2 as tf2
from object_detection import eval_util
from object_detection import inputs
@@ -183,6 +183,22 @@ def _ensure_model_is_built(model, input_dataset, unpad_groundtruth_tensors):
))
def normalize_dict(values_dict, num_replicas):
  num_replicas = tf.constant(num_replicas, dtype=tf.float32)
  return {key: tf.math.divide(loss, num_replicas)
          for key, loss in values_dict.items()}
def reduce_dict(strategy, reduction_dict, reduction_op):
  # TODO(anjalisridhar): explore if it is safe to remove the num_replicas
  # scaling of the loss and switch this to a ReduceOp.MEAN.
  return {
      name: strategy.reduce(reduction_op, loss, axis=None)
      for name, loss in reduction_dict.items()
  }
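For context, here is a minimal, self-contained sketch of how these two new helpers compose under a distribution strategy: each replica returns a losses dict pre-divided by `num_replicas`, so a `ReduceOp.SUM` across replicas yields the cross-replica mean. The toy `step_fn` and its constant loss are illustrative stand-ins, not part of this change.

```python
import tensorflow.compat.v2 as tf

strategy = tf.distribute.MirroredStrategy()

def normalize_dict(values_dict, num_replicas):
  num_replicas = tf.constant(num_replicas, dtype=tf.float32)
  return {key: tf.math.divide(loss, num_replicas)
          for key, loss in values_dict.items()}

def reduce_dict(strategy, reduction_dict, reduction_op):
  return {name: strategy.reduce(reduction_op, loss, axis=None)
          for name, loss in reduction_dict.items()}

def step_fn():
  # Stand-in for the per-replica losses a detection model would return.
  return normalize_dict({'Loss/total_loss': tf.constant(4.0)},
                        strategy.num_replicas_in_sync)

per_replica = strategy.run(step_fn)
reduced = reduce_dict(strategy, per_replica, tf.distribute.ReduceOp.SUM)
print(float(reduced['Loss/total_loss']))  # 4.0, independent of replica count
```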
# TODO(kaftan): Explore removing learning_rate from this method & returning
## The full losses dict instead of just total_loss, then doing all summaries
## saving in a utility method called by the outer training loop.
@@ -192,10 +208,8 @@ def eager_train_step(detection_model,
labels,
unpad_groundtruth_tensors,
optimizer,
learning_rate,
add_regularization_loss=True,
clip_gradients_value=None,
global_step=None,
num_replicas=1.0):
"""Process a single training batch.
@@ -266,16 +280,10 @@ def eager_train_step(detection_model,
float32 tensor containing the weights of the keypoint depth feature.
unpad_groundtruth_tensors: A parameter passed to unstack_batch.
optimizer: The training optimizer that will update the variables.
learning_rate: The learning rate tensor for the current training step.
This is used only for TensorBoard logging purposes; it does not affect
model training.
add_regularization_loss: Whether or not to include the model's
regularization loss in the losses dictionary.
clip_gradients_value: If this is present, clip the gradients global norm
at this value using `tf.clip_by_global_norm`.
global_step: The current training step. Used for TensorBoard logging
purposes. This step is not updated by this function and must be
incremented separately.
num_replicas: The number of replicas in the current distribution strategy.
This is used to scale the total loss so that training in a distribution
strategy works correctly.
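As an aside on the `clip_gradients_value` behavior documented above, here is a hedged, standalone illustration of `tf.clip_by_global_norm` with toy constants (not code from this change):

```python
import tensorflow.compat.v2 as tf

gradients = [tf.constant([3.0, 4.0]), tf.constant([0.0, 12.0])]
# Global norm is sqrt(3^2 + 4^2 + 0^2 + 12^2) = 13.0.
clipped, global_norm = tf.clip_by_global_norm(gradients, clip_norm=5.0)
print(float(global_norm))   # 13.0
# Each gradient was rescaled by clip_norm / global_norm = 5 / 13.
print(clipped[0].numpy())   # [~1.154, ~1.538]
```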
@@ -296,31 +304,18 @@ def eager_train_step(detection_model,
losses_dict, _ = _compute_losses_and_predictions_dicts(
detection_model, features, labels, add_regularization_loss)
total_loss = losses_dict['Loss/total_loss']
# Normalize loss for num replicas
total_loss = tf.math.divide(total_loss,
tf.constant(num_replicas, dtype=tf.float32))
losses_dict['Loss/normalized_total_loss'] = total_loss
for loss_type in losses_dict:
tf.compat.v2.summary.scalar(
loss_type, losses_dict[loss_type], step=global_step)
losses_dict = normalize_dict(losses_dict, num_replicas)
trainable_variables = detection_model.trainable_variables
total_loss = losses_dict['Loss/total_loss']
gradients = tape.gradient(total_loss, trainable_variables)
if clip_gradients_value:
gradients, _ = tf.clip_by_global_norm(gradients, clip_gradients_value)
optimizer.apply_gradients(zip(gradients, trainable_variables))
tf.compat.v2.summary.scalar('learning_rate', learning_rate, step=global_step)
tf.compat.v2.summary.image(
name='train_input_images',
step=global_step,
data=features[fields.InputDataFields.image],
max_outputs=3)
return total_loss
return losses_dict
def validate_tf_v2_checkpoint_restore_map(checkpoint_restore_map):
@@ -479,7 +474,12 @@ def train_loop(
Checkpoint every n training steps.
checkpoint_max_to_keep:
int, the number of most recent checkpoints to keep in the model directory.
record_summaries: Boolean, whether or not to record summaries.
record_summaries: Boolean, whether or not to record summaries defined by
the model or the training pipeline. This does not impact the summaries
of the loss values, which are always recorded. Examples of summaries
that are controlled by this flag include:
- Image summaries of training images.
- Intermediate tensors which may be logged by meta architectures.
performance_summary_exporter: function for exporting performance metrics.
num_steps_per_iteration: int, The number of training steps to perform
in each iteration.
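The flag described above is wired through `add_summaries` on the model builder (see the hunk below). As a generic illustration of the "losses always, model summaries gated" split, one standard TF2 mechanism is `tf.summary.record_if`; this sketch is an analogy, not the exact mechanism the OD API uses internally:

```python
import tensorflow.compat.v2 as tf

record_summaries = False  # hypothetical flag value
writer = tf.summary.create_file_writer('/tmp/train')
step = tf.constant(0, dtype=tf.int64)

with writer.as_default():
  # Loss scalars: written unconditionally, as described above.
  tf.summary.scalar('Loss/total_loss', 1.23, step=step)
  # Model/pipeline summaries: gated on the flag.
  with tf.summary.record_if(record_summaries):
    tf.summary.image('train_input_images',
                     tf.zeros([1, 64, 64, 3]), step=step)
```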
@@ -538,7 +538,8 @@ def train_loop(
strategy = tf.compat.v2.distribute.get_strategy()
with strategy.scope():
detection_model = MODEL_BUILD_UTIL_MAP['detection_model_fn_base'](
model_config=model_config, is_training=True)
model_config=model_config, is_training=True,
add_summaries=record_summaries)
def train_dataset_fn(input_context):
"""Callable to create train input."""
@@ -581,11 +582,9 @@ def train_loop(
# is the chief.
summary_writer_filepath = get_filepath(strategy,
os.path.join(model_dir, 'train'))
if record_summaries:
summary_writer = tf.compat.v2.summary.create_file_writer(
summary_writer_filepath)
else:
summary_writer = tf2.summary.create_noop_writer()
with summary_writer.as_default():
with strategy.scope():
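The removed `else` branch relied on a no-op writer to suppress all summaries; since loss summaries are now always recorded, the file writer is created unconditionally. For reference, a standalone sketch of what the removed branch did:

```python
import tensorflow.compat.v2 as tf2

writer = tf2.summary.create_noop_writer()
with writer.as_default():
  # Silently discarded: a no-op writer drops every summary, including losses.
  tf2.summary.scalar('Loss/total_loss', 0.5, step=0)
```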
@@ -619,32 +618,37 @@ def train_loop(
def train_step_fn(features, labels):
"""Single train step."""
loss = eager_train_step(
if record_summaries:
tf.compat.v2.summary.image(
name='train_input_images',
step=global_step,
data=features[fields.InputDataFields.image],
max_outputs=3)
losses_dict = eager_train_step(
detection_model,
features,
labels,
unpad_groundtruth_tensors,
optimizer,
learning_rate=learning_rate_fn(),
add_regularization_loss=add_regularization_loss,
clip_gradients_value=clip_gradients_value,
global_step=global_step,
num_replicas=strategy.num_replicas_in_sync)
global_step.assign_add(1)
return loss
return losses_dict
def _sample_and_train(strategy, train_step_fn, data_iterator):
features, labels = data_iterator.next()
if hasattr(tf.distribute.Strategy, 'run'):
per_replica_losses = strategy.run(
per_replica_losses_dict = strategy.run(
train_step_fn, args=(features, labels))
else:
per_replica_losses = strategy.experimental_run_v2(
train_step_fn, args=(features, labels))
# TODO(anjalisridhar): explore if it is safe to remove the
## num_replicas scaling of the loss and switch this to a ReduceOp.Mean
return strategy.reduce(tf.distribute.ReduceOp.SUM,
per_replica_losses, axis=None)
per_replica_losses_dict = (
strategy.experimental_run_v2(
train_step_fn, args=(features, labels)))
return reduce_dict(
strategy, per_replica_losses_dict, tf.distribute.ReduceOp.SUM)
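The `hasattr` branch above is a version-compatibility shim: `Strategy.experimental_run_v2` was renamed to `Strategy.run` in later TF 2.x releases. A condensed, hedged restatement of the pattern:

```python
import tensorflow.compat.v2 as tf

def _run(strategy, fn, args):
  # Prefer the stable name; fall back for older TF 2.x releases where only
  # the experimental alias exists.
  if hasattr(tf.distribute.Strategy, 'run'):
    return strategy.run(fn, args=args)
  return strategy.experimental_run_v2(fn, args=args)
```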
@tf.function
def _dist_train_step(data_iterator):
@@ -670,7 +674,7 @@ def train_loop(
for _ in range(global_step.value(), train_steps,
num_steps_per_iteration):
loss = _dist_train_step(train_input_iter)
losses_dict = _dist_train_step(train_input_iter)
time_taken = time.time() - last_step_time
last_step_time = time.time()
@@ -681,11 +685,19 @@ def train_loop(
steps_per_sec_list.append(steps_per_sec)
logged_dict = losses_dict.copy()
logged_dict['learning_rate'] = learning_rate_fn()
for key, val in logged_dict.items():
tf.compat.v2.summary.scalar(key, val, step=global_step)
if global_step.value() - logged_step >= 100:
logged_dict_np = {name: value.numpy() for name, value in
logged_dict.items()}
tf.logging.info(
'Step {} per-step time {:.3f}s loss={:.3f}'.format(
global_step.value(), time_taken / num_steps_per_iteration,
loss))
'Step {} per-step time {:.3f}s'.format(
global_step.value(), time_taken / num_steps_per_iteration))
tf.logging.info(pprint.pformat(logged_dict_np, width=40))
logged_step = global_step.value()
if ((int(global_step.value()) - checkpointed_step) >=
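For a sense of the new log output: `pprint.pformat(..., width=40)` prints roughly one key per line. The values below are made up for illustration; the keys follow the `Loss/...` naming used elsewhere in this diff.

```python
import pprint

logged_dict_np = {
    'Loss/classification_loss': 0.412,
    'Loss/localization_loss': 0.087,
    'Loss/regularization_loss': 0.233,
    'Loss/total_loss': 0.732,
    'learning_rate': 0.004,
}
print(pprint.pformat(logged_dict_np, width=40))
# {'Loss/classification_loss': 0.412,
#  'Loss/localization_loss': 0.087,
#  ... one entry per line ...
```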
@@ -704,7 +716,7 @@ def train_loop(
'steps_per_sec': np.mean(steps_per_sec_list),
'steps_per_sec_p50': np.median(steps_per_sec_list),
'steps_per_sec_max': max(steps_per_sec_list),
'last_batch_loss': float(loss)
'last_batch_loss': float(losses_dict['Loss/total_loss'])
}
mixed_precision = 'bf16' if kwargs['use_bfloat16'] else 'fp32'
performance_summary_exporter(metrics, mixed_precision)
@@ -65,8 +65,10 @@ flags.DEFINE_integer(
flags.DEFINE_integer(
'checkpoint_every_n', 1000, 'Integer defining how often we checkpoint.')
flags.DEFINE_boolean('record_summaries', True,
('Whether or not to record summaries during'
' training.'))
('Whether or not to record summaries defined by the model'
' or the training pipeline. This does not impact the'
' summaries of the loss values, which are always'
' recorded.'))
FLAGS = flags.FLAGS