Commit 8deba73f authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Add classification_loss and localiztion_loss metrics for TPU jobs.

PiperOrigin-RevId: 191957195
parent aedfa2e4
...@@ -582,8 +582,8 @@ class SSDMetaArch(model.DetectionModel): ...@@ -582,8 +582,8 @@ class SSDMetaArch(model.DetectionModel):
name='classification_loss') name='classification_loss')
loss_dict = { loss_dict = {
localization_loss.op.name: localization_loss, str(localization_loss.op.name): localization_loss,
classification_loss.op.name: classification_loss str(classification_loss.op.name): classification_loss
} }
return loss_dict return loss_dict
......
...@@ -297,8 +297,7 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False): ...@@ -297,8 +297,7 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
regularization_loss = tf.add_n(regularization_losses, regularization_loss = tf.add_n(regularization_losses,
name='regularization_loss') name='regularization_loss')
losses.append(regularization_loss) losses.append(regularization_loss)
if not use_tpu: losses_dict['Loss/regularization_loss'] = regularization_loss
tf.summary.scalar('regularization_loss', regularization_loss)
total_loss = tf.add_n(losses, name='total_loss') total_loss = tf.add_n(losses, name='total_loss')
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf.estimator.ModeKeys.TRAIN:
...@@ -380,6 +379,8 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False): ...@@ -380,6 +379,8 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators( eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
eval_metrics, category_index.values(), eval_dict, eval_metrics, category_index.values(), eval_dict,
include_metrics_per_category=False) include_metrics_per_category=False)
for loss_key, loss_tensor in iter(losses_dict.items()):
eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
if img_summary is not None: if img_summary is not None:
eval_metric_ops['Detections_Left_Groundtruth_Right'] = ( eval_metric_ops['Detections_Left_Groundtruth_Right'] = (
img_summary, tf.no_op()) img_summary, tf.no_op())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment