Commit 8deba73f authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Add classification_loss and localiztion_loss metrics for TPU jobs.

PiperOrigin-RevId: 191957195
parent aedfa2e4
......@@ -582,8 +582,8 @@ class SSDMetaArch(model.DetectionModel):
name='classification_loss')
loss_dict = {
localization_loss.op.name: localization_loss,
classification_loss.op.name: classification_loss
str(localization_loss.op.name): localization_loss,
str(classification_loss.op.name): classification_loss
}
return loss_dict
......
......@@ -297,8 +297,7 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
regularization_loss = tf.add_n(regularization_losses,
name='regularization_loss')
losses.append(regularization_loss)
if not use_tpu:
tf.summary.scalar('regularization_loss', regularization_loss)
losses_dict['Loss/regularization_loss'] = regularization_loss
total_loss = tf.add_n(losses, name='total_loss')
if mode == tf.estimator.ModeKeys.TRAIN:
......@@ -380,6 +379,8 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
eval_metrics, category_index.values(), eval_dict,
include_metrics_per_category=False)
for loss_key, loss_tensor in iter(losses_dict.items()):
eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
if img_summary is not None:
eval_metric_ops['Detections_Left_Groundtruth_Right'] = (
img_summary, tf.no_op())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment